How to write a PySpark UDAF on multiple columns?

Posted 2019-02-14 15:03

I have the following data in a pyspark dataframe called end_stats_df:

values   start   end   cat1   cat2
10       1       2     A      B
11       1       2     C      B
12       1       2     D      B
510      1       2     D      C
550      1       2     C      B
500      1       2     A      B
80       1       3     A      B

And I want to aggregate it in the following way:

  • I want to use the "start" and "end" columns as the aggregate keys
  • For each group of rows, I need to do the following:
    • compute the number of distinct values across cat1 and cat2 combined for that group. E.g., for the group start=1, end=2, this number is 4 because the distinct values are A, B, C, D. This number is stored as n (n=4 in this example).
    • for the values field, sort the values within each group and then select every (n-1)th value, where n is the number stored in the first step. With n=4, that means the 3rd, 6th, ... sorted values.
    • at the end of the aggregation, I don't really care what is in cat1 and cat2 after the operations above.

An example output from the example above is:

values   start   end   cat1   cat2
12       1       2     D      B
550      1       2     C      B
80       1       3     A      B
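
To pin down the intended rule, here is a minimal plain-Python sketch (no Spark involved) that reproduces the output above from the sample rows. The names rows and aggregate are only illustrative, and keeping every value when a group has n = 1 is an assumption, since a step of n-1 = 0 is not defined by the description.

```python
from collections import defaultdict

# Sample rows from the question: (values, start, end, cat1, cat2)
rows = [
    (10, 1, 2, "A", "B"),
    (11, 1, 2, "C", "B"),
    (12, 1, 2, "D", "B"),
    (510, 1, 2, "D", "C"),
    (550, 1, 2, "C", "B"),
    (500, 1, 2, "A", "B"),
    (80, 1, 3, "A", "B"),
]


def aggregate(rows):
    """Group by (start, end) and keep every (n-1)th sorted value per group."""
    groups = defaultdict(list)
    for value, start, end, cat1, cat2 in rows:
        groups[(start, end)].append((value, cat1, cat2))

    result = []
    for (start, end), members in sorted(groups.items()):
        # n = number of distinct labels seen in cat1 or cat2 within the group
        n = len({c for _, cat1, cat2 in members for c in (cat1, cat2)})
        # Assumption: with n == 1 the step would be 0, so fall back to a step of 1
        step = max(n - 1, 1)
        ordered = sorted(members)  # tuples sort by value first
        # 1-based positions step, 2*step, ... are 0-based indices step-1, 2*step-1, ...
        for value, cat1, cat2 in ordered[step - 1::step]:
            result.append((value, start, end, cat1, cat2))
    return result


for row in aggregate(rows):
    print(row)
```

For the group start=1, end=2 the sorted values are 10, 11, 12, 500, 510, 550, so a step of 3 picks 12 and 550.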

How do I accomplish this using PySpark DataFrames? I assume I need to use a custom UDAF, right?

1 answer
我想做一个坏孩纸
#2 · 2019-02-14 15:31

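PySpark does not support UDAFs directly, so the aggregation has to be done manually: collect each group's columns into lists with collect_list / collect_set, then apply an ordinary UDF to those lists.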

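from pyspark.sql import functions as f
from pyspark.sql.types import ArrayType, IntegerType

def func(values, cat1, cat2):
    # n = number of distinct categories across cat1 and cat2 in this group
    n = len(set(cat1 + cat2))
    # every (n-1)th sorted value; assumption: a step of 1 is used when n == 1
    step = max(n - 1, 1)
    return sorted(values)[step - 1::step]

pick = f.udf(func, ArrayType(IntegerType()))

df = spark.read.load('file:///home/zht/PycharmProjects/test/text_file.txt', format='csv', sep='\t', header=True)
# CSV columns are read as strings here, so cast values to int before collecting
df = df.groupBy(df['start'], df['end']).agg(f.collect_list(df['values'].cast('int')).alias('values'),
                                            f.collect_set(df['cat1']).alias('cat1'),
                                            f.collect_set(df['cat2']).alias('cat2'))
# explode puts each selected value on its own row, as in the expected output
df = df.select(df['start'], df['end'], f.explode(pick(df['values'], df['cat1'], df['cat2'])).alias('values'))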
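
One thing to watch: the CSV is loaded with header=True but without a schema or inferSchema, so Spark reads every column as a string. The values therefore need to be numeric (for example via a cast to int) before they are sorted, otherwise the order is lexicographic. A small plain-Python illustration of the difference, using made-up sample values:

```python
# Hypothetical sample values, first as strings (how an untyped CSV column arrives)
as_strings = ["80", "500", "10"]
# The same values after converting to integers
as_ints = [int(v) for v in as_strings]

# String sorting compares character by character, so "500" comes before "80"
print(sorted(as_strings))
# Integer sorting gives the numeric order the selection step relies on
print(sorted(as_ints))
```

The sample data in the question happens to sort the same way within the start=1, end=2 group either way, which can hide the problem until values of different lengths end up in the same group.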