How to create a Spark UDF in Java / Kotlin which returns a complex type?

Posted 2019-01-18 17:03

Question:

I'm trying to write a UDF which returns a complex type:

private val toPrice = UDF1<String, Map<String, String>> { s ->
    val elements = s.split(" ")
    mapOf("value" to elements[0], "currency" to elements[1])
}


val type = DataTypes.createStructType(listOf(
        DataTypes.createStructField("value", DataTypes.StringType, false),
        DataTypes.createStructField("currency", DataTypes.StringType, false)))
df.sqlContext().udf().register("toPrice", toPrice, type)

but any time I use this:

df = df.withColumn("price", callUDF("toPrice", col("price")))

I get a cryptic error:

Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$28: (string) => struct<value:string,currency:string>)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
Caused by: scala.MatchError: {value=138.0, currency=USD} (of class java.util.LinkedHashMap)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:236)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:231)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:379)
    ... 19 more

I tried to use a custom data type:

class Price(val value: Double, val currency: String) : Serializable

with a UDF which returns that type:

private val toPrice = UDF1<String, Price> { s ->
    val elements = s.split(" ")
    Price(elements[0].toDouble(), elements[1])
}

but then I get another MatchError, this time complaining about the Price type.

How do I properly write a UDF which can return a complex type?

Answer 1:

TL;DR The function should return an object of class org.apache.spark.sql.Row.

Spark provides two main variants of UDF definitions.

  1. udf variants using Scala reflection:

    • def udf[RT](f: () ⇒ RT)(implicit arg0: TypeTag[RT]): UserDefinedFunction
    • def udf[RT, A1](f: (A1) ⇒ RT)(implicit arg0: TypeTag[RT], arg1: TypeTag[A1]): UserDefinedFunction
    • ...
    • def udf[RT, A1, A2, ..., A10](f: (A1, A2, ..., A10) ⇒ RT)(implicit arg0: TypeTag[RT], arg1: TypeTag[A1], arg2: TypeTag[A2], ..., arg10: TypeTag[A10])

    which define a Scala closure of ... arguments as a user-defined function (UDF). The data types are automatically inferred based on the Scala closure's signature.

    These variants are used without a schema, with atomic or algebraic data types. For example, the function in question would be defined in Scala as:

    import org.apache.spark.sql.functions.udf
    import spark.implicits._  // for toDF and $"..." (spark-shell provides `spark`)

    case class Price(value: Double, currency: String)
    
    val df = Seq("1 USD").toDF("price")
    
    val toPrice = udf((s: String) => scala.util.Try { 
      s split(" ") match {
        case Array(price, currency) => Price(price.toDouble, currency)
      }
    }.toOption)
    
    df.select(toPrice($"price")).show
    // +----------+
    // |UDF(price)|
    // +----------+
    // |[1.0, USD]|
    // +----------+
    

    In this variant the return type is encoded automatically.

    Due to its dependence on reflection, this variant is intended primarily for Scala users.

  2. udf variants providing a schema definition (the one you use here). The return type for this variant should be the same as for Dataset[Row]:

    • As pointed out in the other answer, you can use only the types listed in the SQL types mapping table (atomic types, either boxed or unboxed, java.sql.Timestamp / java.sql.Date, and high-level collections).

    • Complex structures (structs / StructTypes) are expressed using org.apache.spark.sql.Row. No mixing with algebraic data types or equivalents is allowed. For example (Scala code),

      struct<_1:int,_2:struct<_1:string,_2:struct<_1:double,_2:int>>>
      

      should be expressed as

      Row(1, Row("foo", Row(-1.0, 42)))
      

      not

      (1, ("foo", (-1.0, 42)))
      

      or any mixed variant, like

      Row(1, Row("foo", (-1.0, 42)))
      

    This variant is provided primarily to ensure Java interoperability.

    In this case (equivalent to the one in the question), the definition should be similar to the following:

    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.Row
    import spark.implicits._  // for toDF and $"..." (spark-shell provides `spark`)

    // "foo" is a malformed value, included so the null row below has a source
    val df = Seq("1 USD", "foo").toDF("price")

    val schema = StructType(Seq(
      StructField("value", DoubleType, false),
      StructField("currency", StringType, false)
    ))
    
    val toPrice = udf((s: String) => scala.util.Try { 
      s split(" ") match {
        case Array(price, currency) => Row(price.toDouble, currency)
      }
    }.getOrElse(null), schema)
    
    df.select(toPrice($"price")).show
    // +----------+
    // |UDF(price)|
    // +----------+
    // |[1.0, USD]|
    // |      null|
    // +----------+
    

    Excluding all the nuances of exception handling (in general, UDFs should control for null input and, by convention, gracefully handle malformed data), the Java equivalent should look more or less like this:

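    import static org.apache.spark.sql.functions.udf;

    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.expressions.UserDefinedFunction;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;

    UserDefinedFunction price = udf((String s) -> {
        String[] split = s.split(" ");
        // one positional Row value per struct field, in schema order
        return RowFactory.create(Double.parseDouble(split[0]), split[1]);
    }, DataTypes.createStructType(new StructField[]{
        DataTypes.createStructField("value", DataTypes.DoubleType, true),
        DataTypes.createStructField("currency", DataTypes.StringType, true)
    }));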
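A sketch closer to the question's own code: the same Row-returning approach written as a named UDF1, registered with spark.udf().register and an explicit StructType, then invoked with callUDF exactly as in the question. It assumes Spark 2.3 or later; the class and member names (PriceUdf, parse, TO_PRICE, PRICE_TYPE, priceColumn) are hypothetical, and returning a null struct for null or malformed input is one possible convention, not a Spark requirement.

```java
import java.util.Arrays;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;

// Hypothetical helper class; all names here are illustrative only.
class PriceUdf {
    // Declared return type: struct<value:double,currency:string>.
    // Field order must match the order of values passed to RowFactory.create.
    static final StructType PRICE_TYPE = DataTypes.createStructType(Arrays.asList(
            DataTypes.createStructField("value", DataTypes.DoubleType, false),
            DataTypes.createStructField("currency", DataTypes.StringType, false)));

    // Parses "138.0 USD" into a Row; returns null (a null struct) for null or malformed input.
    static Row parse(String s) {
        if (s == null) {
            return null;
        }
        String[] parts = s.trim().split("\\s+");
        if (parts.length != 2) {
            return null;
        }
        try {
            return RowFactory.create(Double.parseDouble(parts[0]), parts[1]);
        } catch (NumberFormatException e) {
            return null;
        }
    }

    // UDF1 extends Serializable, so this method reference can be shipped to executors.
    static final UDF1<String, Row> TO_PRICE = PriceUdf::parse;

    // Registers the UDF by name, declaring the StructType as its return type.
    static void register(SparkSession spark) {
        spark.udf().register("toPrice", TO_PRICE, PRICE_TYPE);
    }

    // Same call shape as in the question: callUDF("toPrice", col("price")).
    static Column priceColumn() {
        return callUDF("toPrice", col("price"));
    }
}
```

After PriceUdf.register(spark), df.withColumn("price", PriceUdf.priceColumn()) should replace the string column with a struct<value:double,currency:string> column. Keeping the parsing in a plain static method also lets it be unit-tested without a SparkSession.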
    

Context:

To give you some context, this distinction is reflected in other parts of the API as well. For example, you can create a DataFrame from a schema and a sequence of Rows:

def createDataFrame(rows: List[Row], schema: StructType): DataFrame 

or, using reflection, from a sequence of Products:

def createDataFrame[A <: Product](data: Seq[A])(implicit arg0: TypeTag[A]): DataFrame 

but no mixed variants are supported.

In other words, you should provide input that can be encoded using RowEncoder.
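The Row-only rule from the second variant, applied to the nested struct<_1:int,_2:struct<_1:string,_2:struct<_1:double,_2:int>>> example, might look like this in Java: every struct level is built with RowFactory.create and paired with a matching StructType, which is the same pairing createDataFrame(rows, schema) takes. This is a sketch; the class name NestedRows is hypothetical, and the nullability flags are assumptions.

```java
import java.util.Collections;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Hypothetical class; nullability flags below are assumptions.
class NestedRows {
    // struct<_1:double,_2:int>
    static final StructType INNER = DataTypes.createStructType(new StructField[]{
            DataTypes.createStructField("_1", DataTypes.DoubleType, false),
            DataTypes.createStructField("_2", DataTypes.IntegerType, false)});

    // struct<_1:string,_2:struct<_1:double,_2:int>>
    static final StructType MIDDLE = DataTypes.createStructType(new StructField[]{
            DataTypes.createStructField("_1", DataTypes.StringType, false),
            DataTypes.createStructField("_2", INNER, false)});

    // struct<_1:int,_2:struct<_1:string,_2:struct<_1:double,_2:int>>>
    static final StructType OUTER = DataTypes.createStructType(new StructField[]{
            DataTypes.createStructField("_1", DataTypes.IntegerType, false),
            DataTypes.createStructField("_2", MIDDLE, false)});

    // Java counterpart of Row(1, Row("foo", Row(-1.0, 42))): one Row per struct level.
    static Row sample() {
        return RowFactory.create(1, RowFactory.create("foo", RowFactory.create(-1.0, 42)));
    }

    // Rows plus an explicit schema, as accepted by createDataFrame(List<Row>, StructType).
    static Dataset<Row> toFrame(SparkSession spark) {
        List<Row> rows = Collections.singletonList(sample());
        return spark.createDataFrame(rows, OUTER);
    }
}
```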

Of course, you wouldn't normally use a UDF for a task like this one:

import org.apache.spark.sql.functions._
import spark.implicits._  // for $"..."

df.withColumn("price", struct(
  split($"price", " ")(0).cast("double").alias("value"),
  split($"price", " ")(1).alias("currency")
))
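For Java users, the same UDF-free approach might be written as below. It is a sketch that assumes the input column is named price and holds values like "1 USD"; the class and method names (PriceColumns, parsePrice, withParsedPrice) are hypothetical. Because the struct is built from built-in expressions instead of an opaque UDF, the whole expression remains visible to Spark's optimizer.

```java
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.split;
import static org.apache.spark.sql.functions.struct;

// Hypothetical helper class; names are illustrative only.
class PriceColumns {
    // Builds struct<value:double,currency:string> from a "<amount> <currency>" string column.
    static Column parsePrice(Column raw) {
        Column parts = split(raw, " ");
        return struct(
                parts.getItem(0).cast("double").alias("value"),
                parts.getItem(1).alias("currency"));
    }

    // Replaces the string "price" column with the parsed struct.
    static Dataset<Row> withParsedPrice(Dataset<Row> df) {
        return df.withColumn("price", parsePrice(col("price")));
    }
}
```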

Related:

  • Creating a SparkSQL UDF in Java outside of SQLContext


Answer 2:

It's simple. Go to the Data Types reference and find the corresponding type.

In Spark 2.3:

  • If you declare the return type as StructType, the function has to return org.apache.spark.sql.Row.
  • If you return Map<String, String>, the declared return type should be MapType, which is clearly not what you want.