aggregate function Count usage with groupBy in Spark

Published 2020-02-17 01:32

Question:

I'm trying to perform multiple operations in one line of code in PySpark, and I'm not sure whether that's possible in my case.

My intention is to avoid having to save the output as a new dataframe.

My current code is rather simple:

from pyspark.sql.functions import udf, col, mean, stddev
from pyspark.sql.types import StringType

# encode_time maps a START_TIME value to a time-period label
encodeUDF = udf(encode_time, StringType())

new_log_df.cache().withColumn('timePeriod', encodeUDF(col('START_TIME'))) \
  .groupBy('timePeriod') \
  .agg(
    mean('DOWNSTREAM_SIZE').alias("Mean"),
    stddev('DOWNSTREAM_SIZE').alias("Stddev")
  ) \
  .show(20, False)

And my intention is to add count() after using groupBy, to get the count of records matching each value of the timePeriod column, printed/shown as output.

When trying to use groupBy(..).count().agg(..) I get exceptions.

Is there any way to get both the count() and agg().show() prints, without splitting the code into two commands, e.g.:

new_log_df.withColumn(..).groupBy(..).count()
new_log_df.withColumn(..).groupBy(..).agg(..).show()

Or better yet, a merged output in the agg().show() output: an extra column which states the number of records matching the row's value, e.g.:

timePeriod | Mean | Stddev | Num Of Records
    X      | 10   |   20   |    315

Answer 1:

count() can be used inside agg(), since the groupBy expression is the same.

With Python

import pyspark.sql.functions as func

new_log_df.cache().withColumn("timePeriod", encodeUDF(new_log_df["START_TIME"])) \
  .groupBy("timePeriod") \
  .agg(
     func.mean("DOWNSTREAM_SIZE").alias("Mean"),
     func.stddev("DOWNSTREAM_SIZE").alias("Stddev"),
     func.count(func.lit(1)).alias("Num Of Records")
   ) \
  .show(20, False)

pySpark SQL functions doc

With Scala

import org.apache.spark.sql.functions._  // for count(), lit(), mean(), stddev()

new_log_df.cache().withColumn("timePeriod", encodeUDF(col("START_TIME"))) 
  .groupBy("timePeriod")
  .agg(
     mean("DOWNSTREAM_SIZE").alias("Mean"), 
     stddev("DOWNSTREAM_SIZE").alias("Stddev"),
     count(lit(1)).alias("Num Of Records")
   )
  .show(20, false)

count(lit(1)) counts every row in each group; since timePeriod (the grouping column) is not null here, the result is the same as count("timePeriod"), which counts only the non-null values of that column.
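
To illustrate the difference, here is a minimal, self-contained PySpark sketch (the toy data and column names are hypothetical, not from the question) showing that count(lit(1)) counts all rows of a group while count(<column>) skips nulls:

from pyspark.sql import SparkSession
import pyspark.sql.functions as func

spark = SparkSession.builder.master("local[1]").appName("count-demo").getOrCreate()

# Toy data: group 'X' has one null in the 'size' column
df = spark.createDataFrame(
    [("X", 10), ("X", None), ("Y", 5)],
    ["timePeriod", "size"])

df.groupBy("timePeriod").agg(
    func.count(func.lit(1)).alias("all_rows"),    # counts every row in the group
    func.count("size").alias("non_null_sizes")    # skips rows where 'size' is null
).show()
# Group 'X' gives all_rows = 2 but non_null_sizes = 1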

With Java

import static org.apache.spark.sql.functions.*;

new_log_df.cache().withColumn("timePeriod", encodeUDF(col("START_TIME")))
  .groupBy("timePeriod")
  .agg(
     mean("DOWNSTREAM_SIZE").alias("Mean"),
     stddev("DOWNSTREAM_SIZE").alias("Stddev"),
     count(lit(1)).alias("Num Of Records")
   )
  .show(20, false);
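
As a side note (not part of the original answer), pyspark.sql.functions.count also accepts the string "*", so assuming the same dataframe and column names as above, the Python aggregation could be written as:

new_log_df.cache().withColumn("timePeriod", encodeUDF(new_log_df["START_TIME"])) \
  .groupBy("timePeriod") \
  .agg(
     func.mean("DOWNSTREAM_SIZE").alias("Mean"),
     func.stddev("DOWNSTREAM_SIZE").alias("Stddev"),
     func.count("*").alias("Num Of Records")   # count("*") counts all rows, like count(lit(1))
   ) \
  .show(20, False)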