I'm trying to write a UDF that returns a complex type:
private val toPrice = UDF1<String, Map<String, String>> { s ->
    val elements = s.split(" ")
    mapOf("value" to elements[0], "currency" to elements[1])
}
val type = DataTypes.createStructType(listOf(
    DataTypes.createStructField("value", DataTypes.StringType, false),
    DataTypes.createStructField("currency", DataTypes.StringType, false)))
df.sqlContext().udf().register("toPrice", toPrice, type)
but any time I use this:
df = df.withColumn("price", callUDF("toPrice", col("price")))
I get a cryptic error:
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$28: (string) => struct<value:string,currency:string>)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:109)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: scala.MatchError: {value=138.0, currency=USD} (of class java.util.LinkedHashMap)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:236)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:231)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:379)
... 19 more
I tried to use a custom data type:
class Price(val value: Double, val currency: String) : Serializable
with a UDF that returns that type:
private val toPrice = UDF1<String, Price> { s ->
    val elements = s.split(" ")
    Price(elements[0].toDouble(), elements[1])
}
but then I get another MatchError, this time complaining about the Price type.
How do I properly write a UDF that can return a complex type?
TL;DR The function should return an object of class org.apache.spark.sql.Row.
Spark provides two main variants of UDF definitions.
udf variants using Scala reflection:
def udf[RT](f: () ⇒ RT)(implicit arg0: TypeTag[RT]): UserDefinedFunction
def udf[RT, A1](f: (A1) ⇒ RT)(implicit arg0: TypeTag[RT], arg1: TypeTag[A1]): UserDefinedFunction
- ...
def udf[RT, A1, A2, ..., A10](f: (A1, A2, ..., A10) ⇒ RT)(implicit arg0: TypeTag[RT], arg1: TypeTag[A1], arg2: TypeTag[A2], ..., arg10: TypeTag[A10]): UserDefinedFunction
which define a Scala closure of ... arguments as a user-defined function (UDF); the data types are automatically inferred based on the Scala closure's signature.
These variants are used without a schema, with atomics or algebraic data types. For example, the function in question would be defined in Scala as:
case class Price(value: Double, currency: String)

val df = Seq("1 USD").toDF("price")

val toPrice = udf((s: String) => scala.util.Try {
  s.split(" ") match {
    case Array(price, currency) => Price(price.toDouble, currency)
  }
}.toOption)
df.select(toPrice($"price")).show
// +----------+
// |UDF(price)|
// +----------+
// |[1.0, USD]|
// +----------+
In this variant the return type is automatically encoded.
Due to its dependence on reflection, this variant is intended primarily for Scala users.
udf variants providing a schema definition (the one you use here). The return type for this variant should be the same as for Dataset[Row]:
As pointed out in the other answer, you can use only the types listed in the SQL types mapping table (atomic types, either boxed or unboxed, java.sql.Timestamp / java.sql.Date, as well as high-level collections).
Complex structures (structs / StructTypes) are expressed using org.apache.spark.sql.Row. No mixing with algebraic data types or their equivalents is allowed. For example (Scala code)
struct<_1:int,_2:struct<_1:string,_2:struct<_1:double,_2:int>>>
should be expressed as
Row(1, Row("foo", Row(-1.0, 42)))
not
(1, ("foo", (-1.0, 42)))
or any mixed variant, like
Row(1, Row("foo", (-1.0, 42)))
This variant is provided primarily to ensure Java interoperability.
In this case (equivalent to the one in question) the definition should be similar to the following one:
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Row

val schema = StructType(Seq(
  StructField("value", DoubleType, false),
  StructField("currency", StringType, false)
))

val toPrice = udf((s: String) => scala.util.Try {
  s.split(" ") match {
    case Array(price, currency) => Row(price.toDouble, currency)
  }
}.getOrElse(null), schema)
df.select(toPrice($"price")).show
// +----------+
// |UDF(price)|
// +----------+
// |[1.0, USD]|
// | null|
// +----------+
Excluding all the nuances of exception handling (in general, UDFs should control for null input and, by convention, gracefully handle malformed data), the Java equivalent should look more or less like this:
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import static org.apache.spark.sql.functions.udf;

UserDefinedFunction price = udf((String s) -> {
    String[] split = s.split(" ");
    return RowFactory.create(Double.parseDouble(split[0]), split[1]);
}, DataTypes.createStructType(new StructField[]{
    DataTypes.createStructField("value", DataTypes.DoubleType, true),
    DataTypes.createStructField("currency", DataTypes.StringType, true)
}));
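Since the question itself is written in Kotlin, here is a minimal sketch of the same fix in the asker's language. It assumes the same df and registration style as in the question; the only substantive changes are that the lambda returns a Row built with RowFactory.create, and that value is declared as DoubleType to match the parsed double:
import org.apache.spark.sql.Row
import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.functions.callUDF
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DataTypes

// Return a Row that matches the declared StructType instead of a Map.
private val toPrice = UDF1<String, Row> { s ->
    val elements = s.split(" ")
    RowFactory.create(elements[0].toDouble(), elements[1])
}

val type = DataTypes.createStructType(listOf(
    DataTypes.createStructField("value", DataTypes.DoubleType, false),
    DataTypes.createStructField("currency", DataTypes.StringType, false)))

df.sqlContext().udf().register("toPrice", toPrice, type)
df = df.withColumn("price", callUDF("toPrice", col("price")))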
Context:
To give you some context, this distinction is reflected in other parts of the API as well. For example, you can create a DataFrame from a schema and a sequence of Rows:
def createDataFrame(rows: List[Row], schema: StructType): DataFrame
or, using reflection, from a sequence of Products:
def createDataFrame[A <: Product](data: Seq[A])(implicit arg0: TypeTag[A]): DataFrame
but no mixed variants are supported.
In other words, you should provide input that can be encoded using RowEncoder.
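For illustration, here is a minimal Kotlin sketch of the Row-plus-schema path (spark is an assumed SparkSession, and priceSchema / priceRows are names made up for this example); each Row is matched positionally against the schema fields, exactly as with the schema-based udf variant:
import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.types.DataTypes

// Explicit schema plus Rows, analogous to createDataFrame(rows, schema) quoted above.
val priceSchema = DataTypes.createStructType(listOf(
    DataTypes.createStructField("value", DataTypes.DoubleType, false),
    DataTypes.createStructField("currency", DataTypes.StringType, false)))

val priceRows = listOf(RowFactory.create(1.0, "USD"), RowFactory.create(2.0, "EUR"))

// spark is an assumed SparkSession; the Java-friendly overload takes a List<Row> and a StructType.
val pricesDf = spark.createDataFrame(priceRows, priceSchema)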
Of course, you wouldn't normally use udf for a task like this one:
import org.apache.spark.sql.functions._

df.withColumn("price", struct(
  split($"price", " ")(0).cast("double").alias("price"),
  split($"price", " ")(1).alias("currency")
))
Related:
- Creating a SparkSQL UDF in Java outside of SQLContext
It's simple. Go to the Data Types reference and find the corresponding type.
In Spark 2.3:
- If you declare the return type as StructType, the function has to return org.apache.spark.sql.Row.
- If you return a Map<String, String>, the function's return type should be MapType - clearly not what you want (see the sketch after this list).
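If a map really were what you wanted, a minimal Kotlin sketch (with a made-up registration name) would keep the original lambda but declare MapType as the return type:
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.DataTypes

// Only if a Map is genuinely the goal: declare MapType, not StructType.
private val toPriceMap = UDF1<String, Map<String, String>> { s ->
    val elements = s.split(" ")
    mapOf("value" to elements[0], "currency" to elements[1])
}

df.sqlContext().udf().register(
    "toPriceMap", toPriceMap,
    DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType))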