Generic iterator over dataframe (Spark/scala)

Question:

I need to iterate over a data frame in a specific order and apply some complex logic to calculate a new column.

In the example below I'll be using a simple expression where the current value of s is the product of all y values so far, so it may seem like this could be done with a UDF or even an analytic function. However, in reality the logic is much more complex.
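
To make the intended computation concrete, here is the same running product in plain Scala (an illustration only; the values are taken from the x = 10 group below):

// running product over the ordered y values of one group (x = 10)
val ys = Seq(1, 2, 3)
ys.scanLeft(1)(_ * _).drop(1) // Seq(1, 2, 6): each element is the product of all y values so far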

The code below does what is needed:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.encoders.RowEncoder

val q = """
select 10 x, 1 y
union all select 10, 2
union all select 10, 3
union all select 20, 6
union all select 20, 4
union all select 20, 5
"""
val df = spark.sql(q)
def f_row(iter: Iterator[Row]) : Iterator[Row] = {
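  // seed with a dummy row whose last field (the running product s) starts at 1; dropped below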
  iter.scanLeft(Row(0,0,1)) {
    case (r1, r2) => {
      val (x1, y1, s1) = r1 match {case Row(x: Int, y: Int, s: Int) => (x, y, s)}
      val (x2, y2)     = r2 match {case Row(x: Int, y: Int) => (x, y)}
      Row(x2, y2, s1 * y2)
    }
  }.drop(1)
}
val schema = new StructType().
             add(StructField("x", IntegerType, true)).
             add(StructField("y", IntegerType, true)).
             add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

Output

scala> df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+---+
|  x|  y|  s|
+---+---+---+
| 20|  4|  4|
| 20|  5| 20|
| 20|  6|120|
| 10|  1|  1|
| 10|  2|  2|
| 10|  3|  6|
+---+---+---+

What I do not like about it:

1) I explicitly define the schema, even though Spark can infer the names and types for the data frame:

scala> df
res1: org.apache.spark.sql.DataFrame = [x: int, y: int]
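
For instance, the schema could be derived from the one Spark already inferred, instead of being spelled out field by field (a sketch; derivedSchema and derivedEncoder are my names):

// extend the inferred schema with the new column instead of re-declaring everything
val derivedSchema = StructType(df.schema.fields :+ StructField("s", IntegerType, true))
val derivedEncoder = RowEncoder(derivedSchema)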

2) If I add any new column to the data frame then I have to declare the schema again and, even more annoyingly, re-define the function!

Assume there is a new column z in the data frame. In this case I have to change almost every line in f_row:

def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  iter.scanLeft(Row(0,0,"",1)) {
    case (r1, r2) => {
      val (x1, y1, z1, s1) = r1 match {case Row(x: Int, y: Int, z: String, s: Int) => (x, y, z, s)}
      val (x2, y2, z2)     = r2 match {case Row(x: Int, y: Int, z: String) => (x, y, z)}
      Row(x2, y2, z2, s1 * y2)
    }
  }.drop(1)
}
val schema = new StructType().
             add(StructField("x", IntegerType, true)).
             add(StructField("y", IntegerType, true)).
             add(StructField("z", StringType, true)).
             add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

Output

scala> df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+-----+---+
|  x|  y|    z|  s|
+---+---+-----+---+
| 20|  4|dummy|  4|
| 20|  5|dummy| 20|
| 20|  6|dummy|120|
| 10|  1|dummy|  1|
| 10|  2|dummy|  2|
| 10|  3|dummy|  6|
+---+---+-----+---+

Is there a way to implement the logic in a more generic way, so that I do not need to create a new function for every specific data frame? Or at least to avoid code changes after adding new columns that are not used in the calculation logic?

Please see the updated question below.

Update

Below are two options to iterate in a more generic way, but both still have drawbacks.

// option 1
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  val r = Row.fromSeq(Row(0, 0).toSeq :+ 1)
  iter.scanLeft(r)((r1, r2) => 
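    // the last field of the accumulator r1 holds the running product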
    Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y")))
  ).drop(1)
}
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

// option 2
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  iter.map{
    var s = 1
    r => {
      s = s * r.getInt(r.fieldIndex("y"))
      Row.fromSeq(r.toSeq :+ s)
    }
  }
}
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

If a new column is added to the data frame then the initial value for iter.scanLeft has to be changed in option 1. Also, I do not really like option 2 because it uses a mutable var.

Is there a way to improve the code so that it's purely functional and no changes are needed when a new column is added to the data frame?

Answer 1:

Well, a sufficient solution is below:

def f_row(iter: Iterator[Row]): Iterator[Row] = {
  if (iter.hasNext) {
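    // the first row seeds the scan: its s value is simply its own y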
    val head = iter.next
    val r = Row.fromSeq(head.toSeq :+ head.getInt(head.fieldIndex("y")))
    iter.scanLeft(r)((r1, r2) => 
      Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y"))))
  } else iter
}
val encoder = 
  RowEncoder(StructType(df.schema.fields :+ StructField("s", IntegerType, false)))
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
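
As a quick check against drawback 2) above, the same f_row works unchanged after a column is added; only the encoder is rebuilt from the new schema (a sketch, with df2 and encoder2 being my names):

import org.apache.spark.sql.functions.lit

val df2 = df.withColumn("z", lit("dummy"))
val encoder2 =
  RowEncoder(StructType(df2.schema.fields :+ StructField("s", IntegerType, false)))
df2.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder2).show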

Update

Functions like getInt can be avoided in favor of the more generic getAs.
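
For illustration (a toy row built by hand, not from the original post), the two accessors return the same value for an Int column named y:

import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

val row = new GenericRowWithSchema(Array[Any](10, 2),
  StructType(Seq(StructField("x", IntegerType), StructField("y", IntegerType))))
row.getInt(row.fieldIndex("y")) // 2
row.getAs[Int]("y")             // 2, no explicit index lookup needed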

Also, in order to be able to access the fields of r1 by name, we can construct a GenericRowWithSchema, which is a subclass of Row.

An implicit parameter has been added to f_row so that the function can use the current schema of the data frame while still being usable as a parameter of mapPartitions.

import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.encoders.RowEncoder

implicit val schema = StructType(df.schema.fields :+ StructField("result", IntegerType))
implicit val encoder = RowEncoder(schema)

def mul(x1: Int, x2: Int) = x1 * x2

def f_row(iter: Iterator[Row])(implicit currentSchema : StructType) : Iterator[Row] = {
  if (iter.hasNext) {
    val head = iter.next
    val r =
      new GenericRowWithSchema((head.toSeq :+ (head.getAs("y"))).toArray, currentSchema)

    iter.scanLeft(r)((r1, r2) =>
      new GenericRowWithSchema((r2.toSeq :+ mul(r1.getAs("result"), r2.getAs("y"))).toArray, currentSchema))
  } else iter
}

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row).show
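
Note that because both the schema and the encoder are declared implicit, mapPartitions(f_row) compiles without passing either one explicitly: the compiler supplies currentSchema to f_row and the encoder to mapPartitions.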

Finally, the logic can be implemented in a tail-recursive manner:

import scala.annotation.tailrec

def f_row(iter: Iterator[Row]) = {
  @tailrec
  def f_row_(iter: Iterator[Row], tmp: Int, result: Iterator[Row]): Iterator[Row] = {
    if (iter.hasNext) {
      val r = iter.next
      f_row_(iter, mul(tmp, r.getAs("y")),
        result ++ Iterator(Row.fromSeq(r.toSeq :+ mul(tmp, r.getAs("y")))))
    } else result
  }
  f_row_(iter, 1, Iterator[Row]())
}

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row).show
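
One small variation (my sketch, not from the original answer): the product mul(tmp, r.getAs("y")) is computed twice per row above; binding it once avoids the repetition:

import scala.annotation.tailrec

def f_row(iter: Iterator[Row]) = {
  @tailrec
  def f_row_(iter: Iterator[Row], tmp: Int, result: Iterator[Row]): Iterator[Row] = {
    if (iter.hasNext) {
      val r = iter.next
      val s = mul(tmp, r.getAs[Int]("y")) // computed once: the new accumulator and the new column value
      f_row_(iter, s, result ++ Iterator(Row.fromSeq(r.toSeq :+ s)))
    } else result
  }
  f_row_(iter, 1, Iterator[Row]())
}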