Question:
Let's say I have a DataFrame with a column for users and another column for words they've written:
Row(user='Bob', word='hello')
Row(user='Bob', word='world')
Row(user='Mary', word='Have')
Row(user='Mary', word='a')
Row(user='Mary', word='nice')
Row(user='Mary', word='day')
I would like to aggregate the word column into a vector:
Row(user='Bob', words=['hello','world'])
Row(user='Mary', words=['Have','a','nice','day'])
It seems I can't use any of Spark's grouping functions because they expect a subsequent aggregation step. My use case is that I want to feed these data into Word2Vec, not use other Spark aggregations.
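For reference, here is roughly how I plan to feed the aggregated result into pyspark.ml.feature.Word2Vec (a minimal sketch; grouped_df and the parameter values are placeholders, and it assumes the aggregated column is an array<string> named words):
from pyspark.ml.feature import Word2Vec
# grouped_df is the aggregated DataFrame, e.g. Row(user='Bob', words=['hello', 'world'])
word2vec = Word2Vec(vectorSize=50, minCount=0, inputCol='words', outputCol='vectors')
model = word2vec.fit(grouped_df)
result = model.transform(grouped_df)  # adds a 'vectors' column with one vector per user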
Answer 1:
As of the Spark 2.3 release we now have Pandas UDFs (a.k.a. vectorized UDFs). The function below will accomplish the OP's task. A benefit of using this function is that the order is guaranteed to be preserved. Order is essential in many cases, such as time series analysis.
import pandas as pd
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, ArrayType
spark = SparkSession.builder.appName('test_collect_array_grouped').getOrCreate()
def collect_array_grouped(df, groupbyCols, aggregateCol, outputCol):
    """
    Aggregate function: returns a new :class:`DataFrame` such that for a given column, aggregateCol,
    in a DataFrame, df, collect into an array the elements for each grouping defined by the groupbyCols list.
    The new DataFrame will have, for each row, the grouping columns and an array of the grouped
    values from aggregateCol in the outputCol.

    :param groupbyCols: list of columns to group by.
        Each element should be a column name (string) or an expression (:class:`Column`).
    :param aggregateCol: the column name of the column of values to aggregate into an array
        for each grouping.
    :param outputCol: the column name of the column to output the aggregated array to.
    """
    groupbyCols = [] if groupbyCols is None else groupbyCols
    df = df.select(groupbyCols + [aggregateCol])
    schema = df.select(groupbyCols).schema
    aggSchema = df.select(aggregateCol).schema
    arrayField = StructField(name=outputCol, dataType=ArrayType(aggSchema[0].dataType, False))
    schema = schema.add(arrayField)

    @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
    def _get_array(pd_df):
        vals = pd_df[groupbyCols].iloc[0].tolist()
        vals.append(pd_df[aggregateCol].values)
        return pd.DataFrame([vals])

    return df.groupby(groupbyCols).apply(_get_array)
rdd = spark.sparkContext.parallelize([
    Row(user='Bob', word='hello'),
    Row(user='Bob', word='world'),
    Row(user='Mary', word='Have'),
    Row(user='Mary', word='a'),
    Row(user='Mary', word='nice'),
    Row(user='Mary', word='day')])
df = spark.createDataFrame(rdd)
collect_array_grouped(df, ['user'], 'word', 'users_words').show()
+----+--------------------+
|user| users_words|
+----+--------------------+
|Mary|[Have, a, nice, day]|
| Bob| [hello, world]|
+----+--------------------+
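A follow-up note: on Spark 3.0 or later, groupby().apply() with PandasUDFType.GROUPED_MAP is deprecated, and the same grouped-map logic can be written with GroupedData.applyInPandas. A minimal sketch using the column names from this example:
import pandas as pd

def to_word_array(pdf):
    # one pandas DataFrame per 'user' group; collect the words into a single list
    return pd.DataFrame([[pdf['user'].iloc[0], pdf['word'].tolist()]],
                        columns=['user', 'words'])

df.groupBy('user').applyInPandas(to_word_array, schema='user string, words array<string>').show(truncate=False)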
Answer 2:
Thanks to @titipat for giving the RDD solution. I did realize shortly after my post that there is actually a DataFrame solution using collect_set (or collect_list):
from pyspark.sql import Row
from pyspark.sql.functions import collect_set
rdd = spark.sparkContext.parallelize([
    Row(user='Bob', word='hello'),
    Row(user='Bob', word='world'),
    Row(user='Mary', word='Have'),
    Row(user='Mary', word='a'),
    Row(user='Mary', word='nice'),
    Row(user='Mary', word='day')])
df = spark.createDataFrame(rdd)
group_user = df.groupBy('user').agg(collect_set('word').alias('words'))
print(group_user.collect())
>[Row(user='Mary', words=['Have', 'nice', 'day', 'a']), Row(user='Bob', words=['world', 'hello'])]
Answer 3:
from pyspark.sql import functions as F
df.groupby("user").agg(F.collect_list("word"))
Answer 4:
Here is a solution using rdd.
from pyspark.sql import Row
rdd = spark.sparkContext.parallelize([
    Row(user='Bob', word='hello'),
    Row(user='Bob', word='world'),
    Row(user='Mary', word='Have'),
    Row(user='Mary', word='a'),
    Row(user='Mary', word='nice'),
    Row(user='Mary', word='day')])
group_user = rdd.groupBy(lambda x: x.user)
group_agg = group_user.map(lambda x: Row(**{'user': x[0], 'word': [t.word for t in x[1]]}))
Output from group_agg.collect():
[Row(user='Bob', word=['hello', 'world']),
Row(user='Mary', word=['Have', 'a', 'nice', 'day'])]
Answer 5:
You have a native aggregate function for that, collect_set (see the Spark documentation).
Then, you could use:
from pyspark.sql import functions as F
df.groupby("user").agg(F.collect_set("word"))