Normally, all rows in a group are passed to an aggregate function. I would like to filter rows using a condition so that only some rows within a group are passed to an aggregate function. Such an operation is possible in PostgreSQL (with the FILTER clause on aggregates). I would like to do the same thing with a Spark SQL DataFrame (Spark 2.0.0).
The code could look something like this:
val df = ... // some data frame
df.groupBy("A").agg(
max("B").where("B").less(10), // there is no such method as `where` :(
max("C").where("C").less(5)
)
So for a data frame like this:
| A |  B |  C |
| 1 | 14 |  4 |
| 1 |  9 |  3 |
| 2 |  5 |  6 |
The result would be:
| A | max(B) | max(C) |
| 1 |      9 |      4 |
| 2 |      5 |   null |
Is it possible with Spark SQL?
Note that, in general, any aggregate function other than max could be used, and there could be multiple aggregates over the same column with arbitrary filtering conditions.
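The following plain-Python sketch (not Spark code) pins down the requested semantics independently of any library. The helper `filtered_agg` and the tuple row layout are hypothetical. Each aggregate keeps only the rows that pass its own predicate. An aggregate over no remaining rows yields None, mirroring SQL null. The last two specs show several aggregates over the same column B, each with a different condition.

```python
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple

Row = Tuple[int, int, int]  # (A, B, C), matching the example table

# Each spec: (output name, column index, row predicate, aggregate over the kept values)
AggSpec = Tuple[str, int, Callable[[Row], bool], Callable[[List[int]], int]]


def filtered_agg(rows: Iterable[Row], key_idx: int,
                 specs: Sequence[AggSpec]) -> Dict[int, Dict[str, Optional[int]]]:
    """Group rows by row[key_idx]; each spec only sees the rows its predicate keeps."""
    groups: Dict[int, List[Row]] = {}
    for row in rows:
        groups.setdefault(row[key_idx], []).append(row)
    result: Dict[int, Dict[str, Optional[int]]] = {}
    for key, members in groups.items():
        out: Dict[str, Optional[int]] = {}
        for name, col, pred, agg in specs:
            kept = [r[col] for r in members if pred(r)]
            # Like SQL max/min/sum, an aggregate over zero kept rows is null (None here)
            out[name] = agg(kept) if kept else None
        result[key] = out
    return result


rows = [(1, 14, 4), (1, 9, 3), (2, 5, 6)]
specs = [
    ("max(B)", 1, lambda r: r[1] < 10, max),
    ("max(C)", 2, lambda r: r[2] < 5, max),
    # Several aggregates over the same column, each with its own condition
    ("min(B) where B > 5", 1, lambda r: r[1] > 5, min),
    ("sum(B) where B >= 9", 1, lambda r: r[1] >= 9, sum),
]
print(filtered_agg(rows, 0, specs))
```

On the example data, the first two outputs reproduce the expected table above: 9 and 4 for A = 1, and 5 and None for A = 2.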
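import org.apache.spark.sql.functions.{max, when}
import spark.implicits._ // enables toDF and $"col"; spark is the active SparkSession

val df = Seq(
  (1, 14, 4),
  (1, 9, 3),
  (2, 5, 6)
).toDF("a", "b", "c")

// rows failing a condition become null inside when(...), and max skips nulls
val aggregatedDF = df.groupBy("a")
  .agg(
    max(when($"b" < 10, $"b")).as("MaxB"),
    max(when($"c" < 5, $"c")).as("MaxC")
  )
aggregatedDF.show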
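This works for two reasons. First, `when` without `otherwise` evaluates to null for rows that fail the condition. Second, `max` (like `min`, `sum`, `avg` and `count(col)`) skips nulls and returns null only when every input is null. The plain-Python model below illustrates those two rules and is not Spark code. `when`, `null_ignoring` and `count_non_null` are hypothetical stand-ins for the Spark functions.

```python
from typing import Callable, List, Optional


def when(cond: bool, value: int) -> Optional[int]:
    """Mimics when(cond, value) without otherwise(): None (null) when cond is false."""
    return value if cond else None


def null_ignoring(agg: Callable[[List[int]], int]) -> Callable[[List[Optional[int]]], Optional[int]]:
    """Wraps an aggregate so it skips nulls and returns null if no value is left,
    as SQL max, min and sum do."""
    def wrapped(values: List[Optional[int]]) -> Optional[int]:
        present = [v for v in values if v is not None]
        return agg(present) if present else None
    return wrapped


def count_non_null(values: List[Optional[int]]) -> int:
    """Like SQL count(col): counts non-null values and returns 0 (not null) for none."""
    return sum(1 for v in values if v is not None)


sql_max = null_ignoring(max)

# Columns b and c of each group in the example data (a = 1, then a = 2)
group1 = {"b": [14, 9], "c": [4, 3]}
group2 = {"b": [5], "c": [6]}

for group in (group1, group2):
    max_b = sql_max([when(b < 10, b) for b in group["b"]])
    max_c = sql_max([when(c < 5, c) for c in group["c"]])
    n_b = count_non_null([when(b < 10, b) for b in group["b"]])
    print(max_b, max_c, n_b)  # prints "9 4 1", then "5 None 1"
```

The same pattern extends to any aggregate that ignores nulls. For example, counting the non-null results of `when` counts only the matching rows.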
>>> from pyspark.sql import Row
>>> df = sc.parallelize([[1,14,1],[1,9,3],[2,5,6]]).map(lambda t: Row(a=int(t[0]),b=int(t[1]),c=int(t[2]))).toDF()
>>> df.registerTempTable('t')
>>> res = sqlContext.sql("select a, max(case when b<10 then b else null end) mb, max(case when c<5 then c else null end) mc from t group by a")
>>> res.show()
+---+---+----+
|  a| mb|  mc|
+---+---+----+
|  1|  9|   3|
|  2|  5|null|
+---+---+----+
You can use SQL (I believe this is the same thing you would do in Postgres).
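The CASE expression is standard SQL, so this query is not tied to Spark. As a Spark-free check, here it runs against an in-memory SQLite database through Python's built-in `sqlite3` module. SQLite is only a stand-in for Spark SQL here. The table name `t` and the data are copied from the answer above, and `ORDER BY a` is added only to make the row order deterministic.

```python
import sqlite3

# In-memory SQLite database standing in for Spark SQL (illustration only)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER)")
conn.executemany("INSERT INTO t VALUES (?, ?, ?)", [(1, 14, 1), (1, 9, 3), (2, 5, 6)])

# CASE without a match yields NULL, and the aggregate max() skips NULLs
rows = conn.execute(
    "SELECT a,"
    " max(CASE WHEN b < 10 THEN b ELSE NULL END) AS mb,"
    " max(CASE WHEN c < 5 THEN c ELSE NULL END) AS mc"
    " FROM t GROUP BY a ORDER BY a"
).fetchall()
conn.close()
print(rows)  # [(1, 9, 3), (2, 5, None)]
```

For comparison, PostgreSQL 9.4 and later also accept the aggregate FILTER clause, for example `max(b) FILTER (WHERE b < 10)`.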
import org.apache.spark.sql.functions;

df.groupBy("name","age","id").agg(functions.max("age").lt(20), functions.max("id").lt("30")).show();
Sample Data:
name age id
abc 23 1001
cde 24 1002
efg 22 1003
ghi 21 1004
ijk 20 1005
klm 19 1006
mno 18 1007
pqr 18 1008
rst 26 1009
tuv 27 1010
pqr 18 1012
rst 28 1013
tuv 29 1011
abc 24 1015
Output:
+----+---+----+---------------+--------------+
|name|age|  id|(max(age) < 20)|(max(id) < 30)|
+----+---+----+---------------+--------------+
| rst| 26|1009|          false|          true|
| abc| 23|1001|          false|          true|
| ijk| 20|1005|          false|          true|
| tuv| 29|1011|          false|          true|
| efg| 22|1003|          false|          true|
| mno| 18|1007|           true|          true|
| tuv| 27|1010|          false|          true|
| klm| 19|1006|           true|          true|
| cde| 24|1002|          false|          true|
| pqr| 18|1008|           true|          true|
| abc| 24|1015|          false|          true|
| ghi| 21|1004|          false|          true|
| rst| 28|1013|          false|          true|
| pqr| 18|1012|           true|          true|
+----+---+----+---------------+--------------+
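This code appears to compute something different from what the question asks for. It groups by all three columns, so every row forms its own group (14 rows in, 14 rows out). The expression `max(age) < 20` then compares each group's aggregate with 20 instead of excluding rows before aggregating. The plain-Python sketch below contrasts the two behaviors. It is not Spark code, and the row subset and the helpers `group_by` and `filtered_max_age` are hypothetical, chosen for illustration.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Optional, Tuple

Row = Tuple[str, int, int]  # (name, age, id)

# A subset of the sample rows above
data: List[Row] = [
    ("mno", 18, 1007),
    ("pqr", 18, 1008),
    ("pqr", 18, 1012),
    ("rst", 26, 1009),
    ("rst", 28, 1013),
]


def group_by(rows: List[Row], key: Callable[[Row], Hashable]) -> Dict[Hashable, List[Row]]:
    """Collects rows into lists keyed by key(row), in first-seen order."""
    groups: Dict[Hashable, List[Row]] = defaultdict(list)
    for r in rows:
        groups[key(r)].append(r)
    return groups


# Aggregate-then-compare: grouping by every column puts each row in its own group,
# and "< 20" is tested on max(age), producing one boolean per group
agg_then_compare = {k: max(r[1] for r in rs) < 20
                    for k, rs in group_by(data, lambda r: r).items()}


def filtered_max_age(rs: List[Row]) -> Optional[int]:
    """Filter-then-aggregate: only rows with age < 20 reach max(); none left -> None (null)."""
    ages = [r[1] for r in rs if r[1] < 20]
    return max(ages) if ages else None


filter_then_agg = {k: filtered_max_age(rs)
                   for k, rs in group_by(data, lambda r: r[0]).items()}

print(len(agg_then_compare))  # 5: one boolean per input row
print(filter_then_agg)        # {'mno': 18, 'pqr': 18, 'rst': None}
```

Separately, the code compares `max(id)` with the string "30". Every id in the sample data is above 1000, so a numeric comparison would give false in every row, yet the output shows true. Check how Spark coerces that string literal before relying on the last column.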