Pass a struct to a UDAF in Spark

Posted 2020-07-28 10:37

Question:

I have the following schema:

root
 |-- id: string (nullable = false)
 |-- age: long (nullable = true)
 |-- cars: struct (nullable = true)
 |    |-- car1: string (nullable = true)
 |    |-- car2: string (nullable = true)
 |    |-- car3: string (nullable = true)
 |-- name: string (nullable = true)

How can I pass the struct 'cars' to a UDAF? What should the inputSchema be if I only want to pass the cars sub-struct?

Answer 1:

You can, but the logic of the UDAF changes. For example, suppose you have two rows:

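// Case classes that produce the nested "cars" struct
case class cars_schema(car1: String, car2: String, car3: String)
case class cars(cars: cars_schema)
val seq = Seq(cars(cars_schema("car1", "car2", "car3")), cars(cars_schema("car1", "car2", "car3")))

val rdd = spark.sparkContext.parallelize(seq)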

The schema of this data is:

root
 |-- cars: struct (nullable = true)
 |    |-- car1: string (nullable = true)
 |    |-- car2: string (nullable = true)
 |    |-- car3: string (nullable = true)

If you then call the aggregation on the cars column:

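import org.apache.spark.sql.functions.col
import spark.implicits._ // needed for toDF
val df = seq.toDF
df.agg(agg0(col("cars"))) // agg0 is your UDAF instance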

then your UDAF's input schema must change accordingly. First define the struct type:

import org.apache.spark.sql.types._

val carsSchema = StructType(List(
  StructField("car1", StringType, true), StructField("car2", StringType, true), StructField("car3", StringType, true)))

Then, in the body of your UDAF, use this struct as the type of the single input column in inputSchema:

override def inputSchema: StructType = StructType(StructField("input", carsSchema) :: Nil)
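If you would rather not spell out the struct by hand, one possible alternative is to read the nested type from an existing DataFrame's schema. This is a sketch assuming a DataFrame with the question's cars column is available when the UDAF is built; `SchemaFromDf` and `carsSchemaOf` are hypothetical names. `StructType.apply(name)` returns the named `StructField`, and its `dataType` is the nested struct.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

// Case classes producing a "cars" struct column with fields car1..car3
case class cars_schema(car1: String, car2: String, car3: String)
case class cars(cars: cars_schema)

// Hypothetical helper: extracts the nested "cars" struct type from a DataFrame
object SchemaFromDf {
  def carsSchemaOf(spark: SparkSession): StructType = {
    import spark.implicits._
    val df = Seq(cars(cars_schema("car1", "car2", "car3"))).toDF
    // df.schema("cars") is the StructField named "cars"; its dataType is the struct
    df.schema("cars").dataType.asInstanceOf[StructType]
  }
}
```

The result can be plugged in as `StructType(StructField("input", SchemaFromDf.carsSchemaOf(spark)) :: Nil)`. The trade-off is that the UDAF then needs a SparkSession (or a DataFrame) at construction time, whereas the hand-written `StructType` does not.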

In your update method, handle the format of the input Row:

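override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  // The struct column is nullable, so skip null inputs
  if (!input.isNullAt(0)) {
    // A struct arrives as a nested Row, not an array: i is [car1, car2, car3]
    val i = input.getStruct(0)
    // Read the fields with i.getString(0), i.getString(1), i.getString(2)
    buffer(0) = ???
  }
}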

From here, you can use the fields of i to update your buffer, then implement the merge and evaluate functions.
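Putting the pieces together, here is a minimal end-to-end sketch, assuming Spark 2.x or 3.x with the `UserDefinedAggregateFunction` API (deprecated since Spark 3.0 in favor of `Aggregator`, but still available). The aggregation itself is an illustrative choice, not part of the question: the hypothetical `DistinctCars` UDAF collects the distinct non-null car names found in the `cars` struct across all rows. The sample car names are made up.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

// Case classes producing a "cars" struct column with fields car1..car3
case class cars_schema(car1: String, car2: String, car3: String)
case class cars(cars: cars_schema)

// Hypothetical UDAF: collects distinct, non-null car names from the struct
object DistinctCars extends UserDefinedAggregateFunction {
  private val carsSchema = StructType(List(
    StructField("car1", StringType, true),
    StructField("car2", StringType, true),
    StructField("car3", StringType, true)))

  // A single input column whose type is the whole struct
  override def inputSchema: StructType = StructType(StructField("input", carsSchema) :: Nil)
  // The buffer holds the names collected so far
  override def bufferSchema: StructType = StructType(StructField("seen", ArrayType(StringType)) :: Nil)
  override def dataType: DataType = ArrayType(StringType)
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = Seq.empty[String]

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val i = input.getStruct(0) // nested Row: [car1, car2, car3]
      val names = (0 until i.length).filterNot(i.isNullAt).map(i.getString)
      buffer(0) = (buffer.getSeq[String](0) ++ names).distinct
    }
  }

  // Combine partial results computed on different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = (buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0)).distinct

  override def evaluate(buffer: Row): Any = buffer.getSeq[String](0)
}

object DistinctCarsExample {
  def run(spark: SparkSession): Seq[String] = {
    import spark.implicits._
    val df = Seq(
      cars(cars_schema("audi", "bmw", null)),
      cars(cars_schema("bmw", "fiat", "kia"))).toDF
    // Pass the whole struct column to the UDAF
    df.agg(DistinctCars(col("cars"))).head().getSeq[String](0)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("distinct-cars").getOrCreate()
    println(run(spark).sorted.mkString(", "))
    spark.stop()
  }
}
```

Running `main` prints `audi, bmw, fiat, kia`: the null `car3` of the first row is skipped and the duplicate `bmw` is removed. Keeping the buffer as an array of already-deduplicated names makes `merge` a simple concatenate-and-deduplicate step.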