Spark UDAF - using generics as input type?

Posted 2019-02-15 15:40

I want to write a Spark UDAF where the type of the column could be anything that has a Scala Numeric defined on it. I've searched the Internet, but I've only found examples with concrete types such as DoubleType or LongType. Isn't this possible? And if not, how can such UDAFs be used with other numeric values?

1 Answer
兄弟一词,经得起流年.
Answered 2019-02-15 16:08

For simplicity let's assume you want to define a custom sum. You'll have to provide a TypeTag for the input type and use Scala reflection to define the schemas:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor

case class MySum[T : TypeTag](implicit n: Numeric[T])
    extends UserDefinedAggregateFunction {

  // Derive the Spark SQL DataType for T via Scala reflection
  val dt = schemaFor[T].dataType
  def inputSchema = new StructType().add("x", dt)
  def bufferSchema = new StructType().add("x", dt)

  def dataType = dt
  def deterministic = true

  // Start the running sum at Numeric#zero
  def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, n.zero)

  // Add the incoming value to the buffer, ignoring nulls
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, n.plus(buffer.getAs[T](0), input.getAs[T](0)))
  }

  // Combine two partial sums
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, n.plus(buffer1.getAs[T](0), buffer2.getAs[T](0)))
  }

  def evaluate(buffer: Row) = buffer.getAs[T](0)
}

With a function defined as above we can create instances handling specific types:

import spark.implicits._  // for the $"col" column syntax

val sumOfLong = MySum[Long]
spark.range(10).select(sumOfLong($"id")).show
+---------+
|mysum(id)|
+---------+
|       45|
+---------+
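Since the class is generic, the same definition can be instantiated for another numeric type. As a small sketch of my own (not part of the original answer), here is a Double instance; the id column is a LongType, so it is cast first:

val sumOfDouble = MySum[Double]
spark.range(10).select(sumOfDouble($"id".cast("double"))).show

This should print 45.0, the same total as before but computed and returned as a Double.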

Note:

To get the same flexibility as the built-in aggregate functions you'd have to define your own AggregateFunction, such as an ImperativeAggregate or a DeclarativeAggregate. That is possible, but it relies on an internal API.
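Short of that, each input type still needs its own concrete instance of the UDAF. For example, to call it from SQL you'd register one instance per type (a sketch of my own, not part of the original answer; the function and view names are arbitrary):

// Register one concrete instance per input type
spark.udf.register("mysum_long", MySum[Long])
spark.udf.register("mysum_double", MySum[Double])

spark.range(10).createOrReplaceTempView("ids")
spark.sql("SELECT mysum_long(id) FROM ids").show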
