Defining a function over numeric types in Spark SQL (Scala)

Published 2019-09-10 23:17

Question:

I have defined the following function to register as a Spark SQL UDF:

import scala.collection.mutable.WrappedArray

def array_sum(x: WrappedArray[Long]): Long = {
  x.sum
}

I would like this function to work with any numeric type it receives as an argument. I tried the following:

import Numeric.Implicits._ 
import scala.reflect.ClassTag

def array_sum(x: WrappedArray[NumericType]) = {
   x.sum
}

But it does not work. Any ideas? Thank you!

Answer 1:

NumericType is Spark SQL specific and is never exposed to UDFs, which receive standard Scala objects. So most likely you want something like this:

def array_sum[T : Numeric : ClassTag](x: Seq[T]) = x.sum
udf[Double, Seq[Double]](array_sum _)

although it doesn't look like there is much to gain here. To build something like this the right way, you should probably implement a custom expression.
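To see why the `Numeric` context bound does the work here, the generic definition can be exercised in plain Scala, with no Spark involved (a minimal sketch; the values below are illustrative, not from the original post):

```scala
import scala.reflect.ClassTag

// One definition sums any element type that has a Numeric instance;
// ClassTag mirrors the answer's signature, Seq.sum supplies the addition.
def array_sum[T : Numeric : ClassTag](x: Seq[T]): T = x.sum

val longSum   = array_sum(Seq(1L, 2L, 3L))   // resolves Numeric[Long]
val doubleSum = array_sum(Seq(1.0, 2.5))     // resolves Numeric[Double]
val floatSum  = array_sum(Seq(1.0f, 2.0f))   // resolves Numeric[Float]
```

The `udf[Double, Seq[Double]](array_sum _)` call then just pins one concrete instantiation of this generic function, which is why a separate `udf` wrapper is needed per element type.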

Example usage:

val rddDouble: RDD[(Long, Array[Double])] = sc.parallelize(Seq(
  (1L, Array(1.0, 2.0))
))
val double_array_sum = udf[Double, Seq[Double]](array_sum _)
rddDouble.toDF("k", "v").select(double_array_sum($"v")).show

// +------+
// |UDF(v)|
// +------+
// |   3.0|
// +------+

val rddFloat: RDD[(Long, Array[Float])] = sc.parallelize(Seq(
  (1L, Array(1.0f, 2.0f))
))
val float_array_sum = udf[Float, Seq[Float]](array_sum _)
rddFloat.toDF("k", "v").select(float_array_sum($"v")).show

// +------+
// |UDF(v)|
// +------+
// |   3.0|
// +------+