How to calculate the max value across several columns per row

Posted 2020-02-14 03:43

Question:

I have a DataFrame read with the sqlContext.sql function in PySpark. It contains four numeric columns with information per client (ClientId is the key). I need to calculate the max value across these four columns for each client and add it to the data frame as a new column:

+--------+-------+-------+-------+-------+
|ClientId|m_ant21|m_ant22|m_ant23|m_ant24|
+--------+-------+-------+-------+-------+
|       0|   null|   null|   null|   null|
|       1|   null|   null|   null|   null|
|       2|   null|   null|   null|   null|
|       3|   null|   null|   null|   null|
|       4|   null|   null|   null|   null|
|       5|   null|   null|   null|   null|
|       6|     23|     13|     17|      8|
|       7|   null|   null|   null|   null|
|       8|   null|   null|   null|   null|
|       9|   null|   null|   null|   null|
|      10|     34|      2|      4|      0|
|      11|      0|      0|      0|      0|
|      12|      0|      0|      0|      0|
|      13|      0|      0|     30|      0|
|      14|   null|   null|   null|   null|
|      15|   null|   null|   null|   null|
|      16|     37|     29|     29|     29|
|      17|      0|      0|     16|      0|
|      18|      0|      0|      0|      0|
|      19|   null|   null|   null|   null|
+--------+-------+-------+-------+-------+

In this case, the max value for client 6 is 23 and for client 10 it is 34. Where all four values are null, the new column should naturally be null as well.

How can I do this?

Answer 1:

I think combining the values into a list and then finding the max of that list is the simplest approach.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

schema = StructType([
    StructField("ClientId", IntegerType(), True),
    StructField("m_ant21", IntegerType(), True),
    StructField("m_ant22", IntegerType(), True),
    StructField("m_ant23", IntegerType(), True),
    StructField("m_ant24", IntegerType(), True)
])

df = spark\
    .createDataFrame(
        data=[(0, None, None, None, None),
              (1, 23, 13, 17, 99),
              (2, 0, 0, 0, 1),
              (3, 0, None, 1, 0)],
        schema=schema)

# Python UDF that packs the four values (nulls included) into one array
def agg_to_list(m21, m22, m23, m24):
    return [m21, m22, m23, m24]

u_agg_to_list = F.udf(agg_to_list, ArrayType(IntegerType()))

# sort_array(..., False) sorts descending with nulls placed last,
# so element [0] is the max (or null when every value is null)
df2 = df.withColumn('all_values', u_agg_to_list('m_ant21', 'm_ant22', 'm_ant23', 'm_ant24'))\
        .withColumn('max', F.sort_array("all_values", False)[0])\
        .select('ClientId', 'max')

df2.show(truncate=False)

Output:

+--------+----+
|ClientId|max |
+--------+----+
|0       |null|
|1       |99  |
|2       |1   |
|3       |1   |
+--------+----+


Answer 2:

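There is a built-in function for that: pyspark.sql.functions.greatest. It skips null values and returns null only if all of its inputs are null, which is exactly the behavior asked for.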

>>> from pyspark.sql.functions import greatest
>>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
>>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
[Row(greatest=4)]

The example was taken directly from the docs.

(pyspark.sql.functions.least does the opposite and returns the smallest value.)