I need to iterate over a data frame in a specific order and apply some complex logic to calculate a new column.
In the example below I use a simple expression where the current value of s is the running product of y, so it may seem like this could be done with a UDF or even an analytic function. However, in reality the logic is much more complex.
The code below does what is needed:
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
val q = """
select 10 x, 1 y
union all select 10, 2
union all select 10, 3
union all select 20, 6
union all select 20, 4
union all select 20, 5
"""
val df = spark.sql(q)
def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  // seed the scan with a dummy (x, y, s = 1) row; it is dropped at the end
  iter.scanLeft(Row(0, 0, 1)) {
    case (r1, r2) => {
      val (x1, y1, s1) = r1 match { case Row(x: Int, y: Int, s: Int) => (x, y, s) }
      val (x2, y2)     = r2 match { case Row(x: Int, y: Int) => (x, y) }
      // s of the current row = s accumulated so far * current y
      Row(x2, y2, s1 * y2)
    }
  }.drop(1)
}
val schema = new StructType().
  add(StructField("x", IntegerType, true)).
  add(StructField("y", IntegerType, true)).
  add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
Output
scala> df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+---+
| x| y| s|
+---+---+---+
| 20| 4| 4|
| 20| 5| 20|
| 20| 6|120|
| 10| 1| 1|
| 10| 2| 2|
| 10| 3| 6|
+---+---+---+
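Just for context on the remark above about analytic functions: the simplified running product on its own could presumably be expressed with a window, for example via the exp(sum(log(y))) trick (assuming y stays positive). This is only to show why the toy example looks window-friendly; it does not extend to the real logic.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
// running product of y per x, ordered by y (default frame: unbounded preceding to current row)
val w = Window.partitionBy($"x").orderBy($"y")
df.withColumn("s", round(exp(sum(log($"y")).over(w))).cast("int")).show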
What I do not like about this approach:
1) I explicitly define the schema even though Spark can infer the names and types for the data frame (a possible way to reuse df.schema is sketched after the second example's output below):
scala> df
res1: org.apache.spark.sql.DataFrame = [x: int, y: int]
2) If I add any new column to the data frame then I have to declare the schema again and, what is more annoying, re-define the function! Assume there is a new column z in the data frame. In that case I have to change almost every line in f_row:
def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  // the seed row and both pattern matches now have to mention z as well
  iter.scanLeft(Row(0, 0, "", 1)) {
    case (r1, r2) => {
      val (x1, y1, z1, s1) = r1 match { case Row(x: Int, y: Int, z: String, s: Int) => (x, y, z, s) }
      val (x2, y2, z2)     = r2 match { case Row(x: Int, y: Int, z: String) => (x, y, z) }
      Row(x2, y2, z2, s1 * y2)
    }
  }.drop(1)
}
val schema = new StructType().
  add(StructField("x", IntegerType, true)).
  add(StructField("y", IntegerType, true)).
  add(StructField("z", StringType, true)).
  add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
Output
scala> df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+-----+---+
| x| y| z| s|
+---+---+-----+---+
| 20| 4|dummy| 4|
| 20| 5|dummy| 20|
| 20| 6|dummy|120|
| 10| 1|dummy| 1|
| 10| 2|dummy| 2|
| 10| 3|dummy| 6|
+---+---+-----+---+
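Regarding point 1: the input schema is already available as df.schema, so the output schema and encoder could presumably be derived from it instead of being spelled out by hand, e.g. for the first example (untested sketch; it does not help with having to re-define f_row):
val schema = df.schema.add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)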
Is there a way to implement the logic in a more generic way, so that I do not need to create a separate function for every specific data frame? Or at least to avoid code changes when new columns that are not used in the calculation logic are added to the data frame?
Please see the updated question below.
Update
Below are two options that iterate in a more generic way, but both still have drawbacks.
// option 1
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  // seed row: the input columns (still hard-coded as two ints here) plus the initial s = 1
  val r = Row.fromSeq(Row(0, 0).toSeq :+ 1)
  iter.scanLeft(r)((r1, r2) =>
    // append s = previous s (last field of the accumulator) * current y (looked up by name)
    Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y")))
  ).drop(1)
}
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
// option 2
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  iter.map {
    // mutable running product, updated while traversing the sorted partition
    var s = 1
    r => {
      s = s * r.getInt(r.fieldIndex("y"))
      Row.fromSeq(r.toSeq :+ s)
    }
  }
}
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
In option 1, if a new column is added to the data frame then the initial value for iter.scanLeft has to be changed, for example as shown below. And I do not really like option 2 because it uses a mutable var.
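For instance, with the z column added as in the example above, the option 1 seed would presumably have to become something like:
val r = Row.fromSeq(Row(0, 0, "").toSeq :+ 1)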
Is there a way to improve the code so that it is purely functional and no changes are needed when a new column is added to the data frame?