Median / quantiles within PySpark groupBy

Posted 2019-03-18 06:06

Question:

I would like to calculate group quantiles on a Spark dataframe (using PySpark). Either an approximate or exact result would be fine. I prefer a solution that I can use within the context of groupBy / agg, so that I can mix it with other PySpark aggregate functions. If this is not possible for some reason, a different approach would be fine as well.

This question is related but does not indicate how to use approxQuantile as an aggregate function.

I also have access to the percentile_approx Hive UDF but I don't know how to use it as an aggregate function.

For the sake of specificity, suppose I have the following dataframe:

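from pyspark.sql import SparkSession
import pyspark.sql.functions as f

# A SparkSession is needed for RDD.toDF to be available
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext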

df = sc.parallelize([
    ['A', 1],
    ['A', 2],
    ['A', 3],
    ['B', 4],
    ['B', 5],
    ['B', 6],
]).toDF(('grp', 'val'))

df_grp = df.groupBy('grp').agg(f.magic_percentile('val', 0.5).alias('med_val'))
df_grp.show()

Expected result is:

+----+-------+
| grp|med_val|
+----+-------+
|   A|      2|
|   B|      5|
+----+-------+

Answer 1:

I guess you don't need it anymore, but I'll leave it here for future generations (i.e. me next week when I forget).

from pyspark.sql import Window
import pyspark.sql.functions as F

grp_window = Window.partitionBy('grp')
magic_percentile = F.expr('percentile_approx(val, 0.5)')

df.withColumn('med_val', magic_percentile.over(grp_window))

Or to address exactly your question, this also works:

df.groupBy('grp').agg(magic_percentile.alias('med_val'))

And as a bonus, you can pass an array of percentiles:

quantiles = F.expr('percentile_approx(val, array(0.25, 0.5, 0.75))')

And you'll get a list in return.



Answer 2:

Since you have access to percentile_approx, one simple solution would be to use it in a SQL command:

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

df.createOrReplaceTempView("df")  # registerTempTable is deprecated since Spark 2.0
df2 = sqlContext.sql("select grp, percentile_approx(val, 0.5) as med_val from df group by grp")


Answer 3:

Unfortunately, and to the best of my knowledge, this does not seem possible with "pure" PySpark commands (the solution by Shaido provides a workaround with SQL). The reason is very elementary: in contrast with other aggregate functions, such as mean, approxQuantile does not return a Column type, but a list.

Let's see a quick example with your sample data:

spark.version
# u'2.2.0'

import pyspark.sql.functions as func
from pyspark.sql import DataFrameStatFunctions as statFunc

# aggregate with mean works OK:
df_grp_mean = df.groupBy('grp').agg(func.mean(df['val']).alias('mean_val'))
df_grp_mean.show()
# +---+--------+ 
# |grp|mean_val|
# +---+--------+
# |  B|     5.0|
# |  A|     2.0|
# +---+--------+

# try aggregating by median:
df_grp_med = df.groupBy('grp').agg(statFunc(df).approxQuantile('val', [0.5], 0.1))
# AssertionError: all exprs should be Column

# mean aggregation is a Column, but median is a list:

type(func.mean(df['val']))
# pyspark.sql.column.Column

type(statFunc(df).approxQuantile('val', [0.5], 0.1))
# list

I doubt that a window-based approach will make any difference since, as I said, the underlying reason is a very elementary one.

See also my answer here for some more details.