Consider the following DataFrame:
+------+-----------------------+
|type |names |
+------+-----------------------+
|person|[john, sam, jane] |
|pet |[whiskers, rover, fido]|
+------+-----------------------+
Which can be created with the following code:
import pyspark.sql.functions as f

data = [
    ('person', ['john', 'sam', 'jane']),
    ('pet', ['whiskers', 'rover', 'fido'])
]

df = sqlCtx.createDataFrame(data, ["type", "names"])
df.show(truncate=False)
Is there a way to directly modify the ArrayType() column "names" by applying a function to each element, without using a udf?
For example, suppose I wanted to apply the function foo to the "names" column. (I will use the example where foo is str.upper just for illustrative purposes, but my question is regarding any valid function that can be applied to the elements of an iterable.)
foo = lambda x: x.upper() # defining it as str.upper as an example
df.withColumn('X', [foo(x) for x in f.col("names")]).show()
TypeError: Column is not iterable
I could do this using a udf:
from pyspark.sql.types import ArrayType, StringType

foo_udf = f.udf(lambda row: [foo(x) for x in row], ArrayType(StringType()))
df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)
#+------+-----------------------+
#|type |names |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE] |
#|pet |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+
In this specific example, I could avoid the udf by exploding the column, calling pyspark.sql.functions.upper(), and then doing groupBy and collect_list:
df.select('type', f.explode('names').alias('name'))\
    .withColumn('name', f.upper(f.col('name')))\
    .groupBy('type')\
    .agg(f.collect_list('name').alias('names'))\
    .show(truncate=False)
#+------+-----------------------+
#|type |names |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE] |
#|pet |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+
But this is a lot of code to do something simple. Is there a more direct way to iterate over the elements of an ArrayType() using spark-dataframe functions?
In Spark < 2.4 you can use a user-defined function:
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DataType, StringType
def transform(f, t=StringType()):
    if not isinstance(t, DataType):
        raise TypeError("Invalid type {}".format(type(t)))
    @udf(ArrayType(t))
    def _(xs):
        if xs is not None:
            return [f(x) for x in xs]
    return _
foo_udf = transform(str.upper)
df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)
+------+-----------------------+
|type |names |
+------+-----------------------+
|person|[JOHN, SAM, JANE] |
|pet |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+
Considering the high cost of the explode + collect_list idiom, this approach is almost always preferred, despite its own intrinsic cost.
In Spark 2.4 or later you can use transform* with upper (see SPARK-23909):
from pyspark.sql.functions import expr
df.withColumn(
'names', expr('transform(names, x -> upper(x))')
).show(truncate=False)
+------+-----------------------+
|type |names |
+------+-----------------------+
|person|[JOHN, SAM, JANE] |
|pet |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+
It is also possible to use pandas_udf:
from pyspark.sql.functions import pandas_udf, PandasUDFType
def transform_pandas(f, t=StringType()):
    if not isinstance(t, DataType):
        raise TypeError("Invalid type {}".format(type(t)))
    @pandas_udf(ArrayType(t), PandasUDFType.SCALAR)
    def _(xs):
        return xs.apply(lambda xs: [f(x) for x in xs] if xs is not None else xs)
    return _
foo_udf_pandas = transform_pandas(str.upper)

df.withColumn('names', foo_udf_pandas(f.col('names'))).show(truncate=False)
+------+-----------------------+
|type |names |
+------+-----------------------+
|person|[JOHN, SAM, JANE] |
|pet |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+
although only the latest Arrow / PySpark combinations support handling ArrayType columns (SPARK-24259, SPARK-21187). Nonetheless this option should be more efficient than a standard UDF (especially with lower serde overhead) while supporting arbitrary Python functions.
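As an aside (an assumption on my part, worth checking against your exact version): in Spark 3.0 and later, pandas_udf can be declared with Python type hints instead of PandasUDFType, which reads a bit cleaner. A sketch along the same lines, where upper_names is just an illustrative name of my choosing:
import pandas as pd
from pyspark.sql.functions import pandas_udf

# Series-to-Series pandas UDF (Spark >= 3.0 style); each element of the input
# Series is a Python list coming from the array<string> column.
@pandas_udf("array<string>")
def upper_names(xs: pd.Series) -> pd.Series:
    return xs.apply(lambda arr: [x.upper() for x in arr] if arr is not None else arr)

df.withColumn('names', upper_names('names')).show(truncate=False)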
* A number of other higher-order functions are also supported, including, but not limited to, filter and aggregate (a short sketch follows the list of links below). See for example:
- Querying Spark SQL DataFrame with complex types
- How to slice and sum elements of array column?
- Filter array column content
- Spark Scala row-wise average by handling null.
- How to use transform higher-order function?.
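For instance, a minimal sketch of filter and aggregate, assuming the same df as above (the column names long_names and initials are just illustrative):
from pyspark.sql.functions import expr

# filter keeps only the elements matching the predicate; aggregate folds the
# array into a single value (here, concatenating the first letter of each name).
df.withColumn('long_names', expr("filter(names, x -> length(x) > 4)")) \
  .withColumn('initials', expr("aggregate(names, '', (acc, x) -> concat(acc, substring(x, 1, 1)))")) \
  .show(truncate=False)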
Yes, you can do it by converting the DataFrame to an RDD and then back to a DataFrame.
>>> df.show(truncate=False)
+------+-----------------------+
|type |names |
+------+-----------------------+
|person|[john, sam, jane] |
|pet |[whiskers, rover, fido]|
+------+-----------------------+
>>> df.rdd.mapValues(lambda x: [y.upper() for y in x]).toDF(["type","names"]).show(truncate=False)
+------+-----------------------+
|type |names |
+------+-----------------------+
|person|[JOHN, SAM, JANE] |
|pet |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+
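Note that mapValues works here only because the DataFrame has exactly two columns, so each Row unpacks cleanly into a (key, value) pair. A slightly more explicit variant (a sketch under the same assumptions) that spells out which columns are used:
# Explicit about the columns involved; easier to adapt when the DataFrame has
# more than the two columns shown above.
(df.rdd
   .map(lambda row: (row['type'], [x.upper() for x in row['names']]))
   .toDF(['type', 'names'])
   .show(truncate=False))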