Consider the following DataFrame:
+------+-----------------------+
|type |names |
+------+-----------------------+
|person|[john, sam, jane] |
|pet |[whiskers, rover, fido]|
+------+-----------------------+
Which can be created with the following code:
import pyspark.sql.functions as f
data = [
    ('person', ['john', 'sam', 'jane']),
    ('pet', ['whiskers', 'rover', 'fido'])
]
df = sqlCtx.createDataFrame(data, ["type", "names"])
df.show(truncate=False)
Is there a way to directly modify the ArrayType() column "names" by applying a function to each element, without using a udf?
For example, suppose I wanted to apply the function foo to the "names" column. (I will use the example where foo is str.upper just for illustrative purposes, but my question is regarding any valid function that can be applied to the elements of an iterable.)
foo = lambda x: x.upper() # defining it as str.upper as an example
df.withColumn('X', [foo(x) for x in f.col("names")]).show()
TypeError: Column is not iterable
I could do this using a udf:
from pyspark.sql.types import ArrayType, StringType

foo_udf = f.udf(lambda row: [foo(x) for x in row], ArrayType(StringType()))
df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)
#+------+-----------------------+
#|type |names |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE] |
#|pet |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+
In this specific example, I could avoid the udf by exploding the column, calling pyspark.sql.functions.upper(), and then using groupBy and collect_list:
df.select('type', f.explode('names').alias('name'))\
.withColumn('name', f.upper(f.col('name')))\
.groupBy('type')\
.agg(f.collect_list('name').alias('names'))\
.show(truncate=False)
#+------+-----------------------+
#|type |names |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE] |
#|pet |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+
But this is a lot of code to do something simple. Is there a more direct way to iterate over the elements of an ArrayType() using spark-dataframe functions?
Yes, you can do it by converting the DataFrame to an RDD and then back to a DataFrame.
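A minimal sketch of that approach, reusing the df from the question (the column and field names are just the ones defined there):

# Map over the underlying RDD of Rows, upper-casing each element of the array,
# then convert the tuples back into a DataFrame.
df.rdd \
    .map(lambda row: (row.type, [name.upper() for name in row.names])) \
    .toDF(['type', 'names']) \
    .show(truncate=False)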
In Spark < 2.4 you can use a user defined function:
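For example, a small factory that wraps an arbitrary element-wise Python function in a UDF (a sketch; transform_array is just an illustrative name, and it reuses the df and the f alias from the question):

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

def transform_array(fn, t=StringType()):
    # Return a UDF that applies fn to every element of an array column.
    return udf(
        lambda xs: [fn(x) for x in xs] if xs is not None else None,
        ArrayType(t)
    )

df.withColumn('names', transform_array(str.upper)(f.col('names'))).show(truncate=False)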
Considering the high cost of the explode + collect_list idiom, this approach is almost exclusively preferred, despite its intrinsic cost.

In Spark 2.4 or later you can use transform* with upper (see SPARK-23909):
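A minimal sketch using expr with the transform higher-order function (the lambda is evaluated entirely inside the JVM, so no Python serialization is involved):

from pyspark.sql.functions import expr

# Apply upper() to every element of the names array.
df.withColumn('names', expr('transform(names, x -> upper(x))')).show(truncate=False)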
It is also possible to use pandas_udf, although only the latest Arrow / PySpark combinations support handling ArrayType columns (SPARK-24259, SPARK-21187). Nonetheless this option should be more efficient than a standard UDF (especially with a lower serde overhead) while supporting arbitrary Python functions.
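A sketch of that variant, assuming an Arrow-enabled setup (transform_pandas is just an illustrative name):

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, StringType

def transform_pandas(fn, t=StringType()):
    # Return a scalar pandas_udf that applies fn to every element of each array.
    @pandas_udf(ArrayType(t), PandasUDFType.SCALAR)
    def _(xs):
        return xs.apply(lambda arr: [fn(x) for x in arr] if arr is not None else arr)
    return _

df.withColumn('names', transform_pandas(str.upper)(f.col('names'))).show(truncate=False)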
* A number of other higher order functions are also supported, including, but not limited to, filter and aggregate.