Split a Spark DataFrame string column into multiple columns

I've seen various people suggest that Dataframe.explode is a useful way to do this, but it results in more rows than the original DataFrame, which isn't what I want at all. I simply want the DataFrame equivalent of something very simple like:

rdd.map(lambda row: row + [row.my_str_col.split('-')])

which takes something that looks like:

col1 | my_str_col
-----+-----------
  18 | 856-yygrm
 201 | 777-psgdg

and converts it into this:

col1 | my_str_col | _col3 | _col4
-----+------------+-------+------
  18 | 856-yygrm  |   856 | yygrm
 201 | 777-psgdg  |   777 | psgdg

I am aware of pyspark.sql.functions.split(), but it results in a nested array column instead of the two top-level columns I want.

Ideally, I'd like these new columns to be named as well.


pyspark.sql.functions.split() is the right approach here - you simply need to flatten the nested ArrayType column into multiple top-level columns. In this case, where each array only contains 2 items, it's very easy. You simply use Column.getItem() to retrieve each part of the array as a column itself:

split_col = pyspark.sql.functions.split(df['my_str_col'], '-')
df = df.withColumn('NAME1', split_col.getItem(0))
df = df.withColumn('NAME2', split_col.getItem(1))

The result will be:

col1 | my_str_col | NAME1 | NAME2
-----+------------+-------+------
  18 | 856-yygrm  |   856 | yygrm
 201 | 777-psgdg  |   777 | psgdg
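
For what it's worth, the two withColumn calls can also be collapsed into a single select (a sketch, equivalent to the code above):

split_col = pyspark.sql.functions.split(df['my_str_col'], '-')
df = df.select('*',
               split_col.getItem(0).alias('NAME1'),
               split_col.getItem(1).alias('NAME2'))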

I am not sure how I would solve this in a general case where the nested arrays were not the same size from Row to Row.
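One possible workaround for that general case (a sketch, not part of the original answer): spend one extra action to find the longest array, then generate that many getItem() columns; rows with shorter arrays get null in the missing positions. This assumes the same '-' delimiter and my_str_col column as above; the part_ prefix is made up for illustration.

import pyspark.sql.functions as F

split_col = F.split(df['my_str_col'], '-')
# One extra Spark action to learn the maximum number of parts in any row.
max_len = df.select(F.max(F.size(split_col)).alias('n')).first()['n']
# getItem(i) yields null when a row's array has fewer than i + 1 elements.
df = df.select('*', *[split_col.getItem(i).alias(f'part_{i}') for i in range(max_len)])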

Here's a solution to the general case that doesn't require knowing the length of the array ahead of time, using collect, or using UDFs. Unfortunately it only works for Spark 2.1 and above, because it requires the posexplode function.

Suppose you had the following DataFrame:

df = spark.createDataFrame(
    [
        [1, 'A, B, C, D'],
        [2, 'E, F, G'],
        [3, 'H, I'],
        [4, 'J']
    ],
    ["num", "letters"]
)
df.show()
#+---+----------+
#|num|   letters|
#+---+----------+
#|  1|A, B, C, D|
#|  2|   E, F, G|
#|  3|      H, I|
#|  4|         J|
#+---+----------+

Split the letters column and then use posexplode to explode the resultant array along with the position in the array. Next use pyspark.sql.functions.expr to grab the element at index pos in this array.

import pyspark.sql.functions as f


df.select(
    "num",
    f.split("letters", ", ").alias("letters"),
    f.posexplode(f.split("letters", ", ")).alias("pos", "val")
).show()
#+---+------------+---+---+
#|num|     letters|pos|val|
#+---+------------+---+---+
#|  1|[A, B, C, D]|  0|  A|
#|  1|[A, B, C, D]|  1|  B|
#|  1|[A, B, C, D]|  2|  C|
#|  1|[A, B, C, D]|  3|  D|
#|  2|   [E, F, G]|  0|  E|
#|  2|   [E, F, G]|  1|  F|
#|  2|   [E, F, G]|  2|  G|
#|  3|      [H, I]|  0|  H|
#|  3|      [H, I]|  1|  I|
#|  4|         [J]|  0|  J|
#+---+------------+---+---+

Now we create two new columns from this result. The first is the name of our new column, which will be a concatenation of letter and the index in the array. The second column will be the value at the corresponding index in the array. We get the latter by exploiting the functionality of pyspark.sql.functions.expr, which allows us to use column values as parameters.

df.select(
    "num",
    f.split("letters", ", ").alias("letters"),
    f.posexplode(f.split("letters", ", ")).alias("pos", "val")
)\
.drop("val")\
.select(
    "num",
    f.concat(f.lit("letter"), f.col("pos").cast("string")).alias("name"),
    f.expr("letters[pos]").alias("val")
)\
.show()
#+---+-------+---+
#|num|   name|val|
#+---+-------+---+
#|  1|letter0|  A|
#|  1|letter1|  B|
#|  1|letter2|  C|
#|  1|letter3|  D|
#|  2|letter0|  E|
#|  2|letter1|  F|
#|  2|letter2|  G|
#|  3|letter0|  H|
#|  3|letter1|  I|
#|  4|letter0|  J|
#+---+-------+---+

Now we can just groupBy the num and pivot the DataFrame. Putting that all together, we get:

df.select(
    "num",
    f.split("letters", ", ").alias("letters"),
    f.posexplode(f.split("letters", ", ")).alias("pos", "val")
)\
.drop("val")\
.select(
    "num",
    f.concat(f.lit("letter"), f.col("pos").cast("string")).alias("name"),
    f.expr("letters[pos]").alias("val")
)\
.groupBy("num").pivot("name").agg(f.first("val"))\
.show()
#+---+-------+-------+-------+-------+
#|num|letter0|letter1|letter2|letter3|
#+---+-------+-------+-------+-------+
#|  1|      A|      B|      C|      D|
#|  3|      H|      I|   null|   null|
#|  2|      E|      F|      G|   null|
#|  4|      J|   null|   null|   null|
#+---+-------+-------+-------+-------+
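
If you need this more than once, the recipe can be wrapped in a small helper. This is a hypothetical sketch, not from the answer above; note that posexplode already returns the value at each position, so the intermediate expr step (included above mainly to demonstrate expr) can be dropped:

import pyspark.sql.functions as f

def split_to_columns(df, id_col, str_col, delim, prefix):
    # Hypothetical helper: split str_col on delim, explode with positions,
    # then pivot each position into its own prefixed column.
    return (
        df.select(
            id_col,
            f.posexplode(f.split(str_col, delim)).alias("pos", "val")
        )
        .select(
            id_col,
            f.concat(f.lit(prefix), f.col("pos").cast("string")).alias("name"),
            "val"
        )
        .groupBy(id_col).pivot("name").agg(f.first("val"))
    )

# split_to_columns(df, "num", "letters", ", ", "letter").show() reproduces the table above.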

Here's another approach, in case you want to split a string on a delimiter.

import pyspark.sql.functions as f


df = spark.createDataFrame([("1:a:2001",),("2:b:2002",),("3:c:2003",)],["value"])
df.show()
+--------+
|   value|
+--------+
|1:a:2001|
|2:b:2002|
|3:c:2003|
+--------+


df_split = df.select(f.split(df.value, ":")).rdd.flatMap(
    lambda x: x).toDF(schema=["col1", "col2", "col3"])


df_split.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1|   a|2001|
|   2|   b|2002|
|   3|   c|2003|
+----+----+----+

I don't think this transition back and forth to RDDs is going to slow you down... Also, don't worry about the schema specification at the end: it's optional; you can omit it to generalize the solution to data with an unknown number of columns.
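For reference, the schema-free variant would look like this (a sketch; column names then default to _1, _2, ..., and it assumes every row splits into the same number of parts):

df_split = df.select(f.split(df.value, ":")).rdd.flatMap(lambda x: x).toDF()
df_split.show()

+--+--+----+
|_1|_2|  _3|
+--+--+----+
| 1| a|2001|
| 2| b|2002|
| 3| c|2003|
+--+--+----+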

I understand your pain. Using split() can work, but can also lead to breaks.

Let's take your df and make a slight change to it:

df = spark.createDataFrame([('1:"a:3":2001',),('2:"b":2002',),('3:"c":2003',)],["value"])


df.show()


+------------+
|       value|
+------------+
|1:"a:3":2001|
|  2:"b":2002|
|  3:"c":2003|
+------------+

If you try to apply split() to this as outlined above:

df_split = df.select(f.split(df.value, ":")).rdd.flatMap(
    lambda x: x).toDF(schema=["col1", "col2", "col3"]).show()

you will get

IllegalStateException: Input row doesn't have expected number of values required by the schema. 4 fields are required while 3 values are provided.

So, is there a more elegant way of addressing this? I was so happy to have it pointed out to me. pyspark.sql.functions.from_csv() is your friend.

Taking my above example df:

from pyspark.sql.functions import from_csv


# Define a column schema to apply with from_csv()
col_schema = ["col1 INTEGER","col2 STRING","col3 INTEGER"]
schema_str = ",".join(col_schema)


# define the separator because it isn't a ','
options = {'sep': ":"}


# create a df from the value column using schema and options
df_csv = df.select(from_csv(df.value, schema_str, options).alias("value_parsed"))
df_csv.show()


+--------------+
|  value_parsed|
+--------------+
|[1, a:3, 2001]|
|  [2, b, 2002]|
|  [3, c, 2003]|
+--------------+

Then we can easily flatten the df to put the values in columns:

df2 = df_csv.select("value_parsed.*").toDF("col1","col2","col3")
df2.show()


+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1| a:3|2001|
|   2|   b|2002|
|   3|   c|2003|
+----+----+----+

No breaks. Data correctly parsed. Life is good. Have a beer.
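
As a side note, from_csv() only exists in Spark 3.0 and later. Applied back to the original question's data, the same trick would look something like this (a hedged sketch, reusing the question's column names):

from pyspark.sql.functions import from_csv

df = spark.createDataFrame([(18, '856-yygrm'), (201, '777-psgdg')],
                           ["col1", "my_str_col"])
df2 = df.select(
    "col1", "my_str_col",
    from_csv("my_str_col", "_col3 STRING, _col4 STRING", {"sep": "-"}).alias("p")
).select("col1", "my_str_col", "p.*")
df2.show()

+----+----------+-----+-----+
|col1|my_str_col|_col3|_col4|
+----+----------+-----+-----+
|  18| 856-yygrm|  856|yygrm|
| 201| 777-psgdg|  777|psgdg|
+----+----------+-----+-----+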

Instead of Column.getItem(i) we can use Column[i].
Also, enumerate is handy when there are many new columns to generate.

from pyspark.sql import functions as F
  • Keep parent column:

    for i, c in enumerate(['new_1', 'new_2']):
        df = df.withColumn(c, F.split('my_str_col', '-')[i])
    

    or

    new_cols = ['new_1', 'new_2']
    df = df.select('*', *[F.split('my_str_col', '-')[i].alias(c) for i, c in enumerate(new_cols)])
    
  • Replace parent column:

    for i, c in enumerate(['new_1', 'new_2']):
        df = df.withColumn(c, F.split('my_str_col', '-')[i])
    df = df.drop('my_str_col')
    

    or

    new_cols = ['new_1', 'new_2']
    df = df.select(
        *[c for c in df.columns if c != 'my_str_col'],
        *[F.split('my_str_col', '-')[i].alias(c) for i, c in enumerate(new_cols)]
    )
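
Applied to the question's sample data, the replace-parent variant gives (a sketch of the expected output):

df = spark.createDataFrame([(18, '856-yygrm'), (201, '777-psgdg')],
                           ['col1', 'my_str_col'])
new_cols = ['new_1', 'new_2']
df.select(
    *[c for c in df.columns if c != 'my_str_col'],
    *[F.split('my_str_col', '-')[i].alias(c) for i, c in enumerate(new_cols)]
).show()
# +----+-----+-----+
# |col1|new_1|new_2|
# +----+-----+-----+
# |  18|  856|yygrm|
# | 201|  777|psgdg|
# +----+-----+-----+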