Register UDF to SqlContext from Scala to use in PySpark

Posted on 2019-01-25 21:35

Question:

Is it possible to register a UDF (or function) written in Scala to use in PySpark? E.g.:

val mytable = sc.parallelize(1 to 2).toDF("spam")
mytable.registerTempTable("mytable")
def addOne(m: Integer): Integer = m + 1
// Spam: 1, 2

In Scala, the following is now possible:

val UDFaddOne = sqlContext.udf.register("UDFaddOne", addOne _)
val mybiggertable = mytable.withColumn("moreSpam", UDFaddOne(mytable("spam")))
// Spam: 1, 2
// moreSpam: 2, 3

I would like to use "UDFaddOne" in PySpark like

%pyspark

mytable = sqlContext.table("mytable")
UDFaddOne = sqlContext.udf("UDFaddOne") # does not work
mybiggertable = mytable.withColumn("+1", UDFaddOne(mytable("spam"))) # does not work

Background: We are a team of developers, some coding in Scala and some in Python, and we would like to share functions that have already been written. It would also be possible to save them in a library and import that.

Answer 1:

As far as I know, PySpark doesn't provide any equivalent of the callUDF function, and because of that it is not possible to access a registered UDF directly.

The simplest solution here is to use a raw SQL expression:

from pyspark.sql.functions import expr

mytable.withColumn("moreSpam", expr("UDFaddOne({})".format("spam")))

## OR
sqlContext.sql("SELECT *, UDFaddOne(spam) AS moreSpam FROM mytable")

## OR
mytable.selectExpr("*", "UDFaddOne(spam) AS moreSpam")

This approach is rather limited, so if you need to support more complex workflows you should build a package and provide complete Python wrappers. You'll find an example UDAF wrapper in my answer to Spark: How to map Python with Scala or Java User Defined Functions?
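For illustration only, here is a minimal sketch of such a wrapper built on the raw SQL approach above (this is not the py4j-based wrapper from the linked answer; the helper name add_one is made up):

from pyspark.sql.functions import expr

def add_one(col_name):
    # Builds an expression that calls the Scala UDF registered as "UDFaddOne"
    return expr("UDFaddOne({})".format(col_name))

mybiggertable = mytable.withColumn("moreSpam", add_one("spam"))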



Answer 2:

The following worked for me (basically a summary of multiple sources, including the link provided by zero323):

In Scala:

package com.example
import org.apache.spark.sql.functions.udf

object udfObj extends Serializable {
  // Factory method returning a Spark UDF that adds one to its input
  def createUDF = {
    udf((x: Int) => x + 1)
  }
}

In Python (assume sc is the SparkContext; if you are using Spark 2.0 you can get it from the SparkSession, e.g. sc = spark.sparkContext):

from py4j.java_gateway import java_import
from pyspark.sql.column import Column, _to_java_column, _to_seq

jvm = sc._gateway.jvm
java_import(jvm, "com.example.*")

def udf_f(col):
    # createUDF() returns a Scala UserDefinedFunction; its apply expects a Seq[Column],
    # so the PySpark Column is converted to a Java column and wrapped in a Seq.
    scala_udf = jvm.com.example.udfObj.createUDF()
    return Column(scala_udf.apply(_to_seq(sc, [col], _to_java_column)))

And of course, make sure the JAR built from the Scala code is added using --jars and --driver-class-path.
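For example, when launching the shell (the JAR path here is just a placeholder):

pyspark --jars /path/to/example-udf.jar --driver-class-path /path/to/example-udf.jar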

So what happens here:

We create a function inside a serializable object which returns the UDF in Scala (I am not 100% sure Serializable is required; it was required for me for more complex UDFs, possibly because they needed to pass Java objects).

In Python we access the internal JVM (this is a private member, so it could change in the future, but I see no way around it) and import our package using java_import. We access the createUDF function and call it. This returns an object which has an apply method (functions in Scala are actually Java objects with an apply method). The input to the apply method is a column (converted to its Java representation and wrapped in a Seq, as shown above). The result of applying it to the column is a new Java column, so we need to wrap it in the Column class to make it available to withColumn.
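For completeness, a usage sketch with the table from the question (assuming the wrapper above is defined and the Scala JAR is on the classpath):

mytable = sqlContext.table("mytable")
mybiggertable = mytable.withColumn("moreSpam", udf_f(mytable["spam"]))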