Spark SQL - Generate array of arrays from the sql

Posted 2020-07-18 08:12

Question:

I want to create an array of arrays. This is my data table:

// A case class for our sample table
case class Testing(name: String, age: Int, salary: Int)

// Create an RDD with some data
val x = sc.parallelize(Array(
    Testing(null, 21, 905),
    Testing("Noelia", 26, 1130),
    Testing("Pilar", 52,  1890),
    Testing("Roberto", 31, 1450)
 ))

// Convert RDD to a DataFrame 
val df = sqlContext.createDataFrame(x) 

// For SQL usage we need to register the table
df.registerTempTable("df")

I want to create an array from the integer column "age". For that I use collect_list:

sqlContext.sql("SELECT collect_list(age) as age from df").show

But now I want to generate an array containing multiple arrays as created above:

 sqlContext.sql("SELECT collect_list(collect_list(age), collect_list(salary)) as arrayInt from df").show

But this does not work. I also looked at the function org.apache.spark.sql.functions.array. Any ideas?

Answer 1:

OK, it doesn't get much simpler than this. Let's take the same data you are working with and go step by step from there.

// A case class for our sample table
case class Testing(name: String, age: Int, salary: Int)

// Create an RDD with some data
val x = sc.parallelize(Array(
  Testing(null, 21, 905),
  Testing("Noelia", 26, 1130),
  Testing("Pilar", 52, 1890),
  Testing("Roberto", 31, 1450)
))

// Convert RDD to a DataFrame
val df = sqlContext.createDataFrame(x)

// For SQL usage we need to register the table
df.registerTempTable("df")
sqlContext.sql("select collect_list(age) as age from df").show

// +----------------+
// |             age|
// +----------------+
// |[21, 26, 52, 31]|
// +----------------+

sqlContext.sql("select collect_list(collect_list(age), collect_list(salary)) as arrayInt from df").show

As the error message says:

org.apache.spark.sql.AnalysisException: No handler for Hive udf class
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList because: Exactly one argument is expected..; line 1 pos 52 [...]

collect_list takes just one argument. Let's check the documentation.

It does indeed take exactly one argument! But let's go further in the documentation of the functions object. You seem to have noticed that the array function lets you create a new array column out of one Column or several Column parameters. So let's use that:

sqlContext.sql("select array(collect_list(age), collect_list(salary)) as arrayInt from df").show(false)

The array function does indeed create a single column from the two arrays that collect_list built beforehand from age and salary:

// +-------------------------------------------------------------------+
// |arrayInt                                                           |
// +-------------------------------------------------------------------+
// |[WrappedArray(21, 26, 52, 31), WrappedArray(905, 1130, 1890, 1450)]|
// +-------------------------------------------------------------------+

Where do we go from here?

Remember that a Row from a DataFrame is just a wrapper around another collection.

The first thing I'll do is work on that collection. So how do we flatten a WrappedArray[WrappedArray[Int]]?

Scala is kind of magical: you just need to call .flatten.

import scala.collection.mutable.WrappedArray

// Take the single result row and cast its first field to the nested array type
val firstRow: WrappedArray[WrappedArray[Int]] =
  sqlContext.sql("select array(collect_list(age), collect_list(salary)) as arrayInt from df")
    .first.get(0).asInstanceOf[WrappedArray[WrappedArray[Int]]]
// firstRow: scala.collection.mutable.WrappedArray[scala.collection.mutable.WrappedArray[Int]] =
// WrappedArray(WrappedArray(21, 26, 52, 31), WrappedArray(905, 1130, 1890, 1450))

firstRow.flatten
// res27: scala.collection.mutable.IndexedSeq[Int] = ArrayBuffer(21, 26, 52, 31, 905, 1130, 1890, 1450)
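To see what .flatten does independently of Spark, here is a minimal plain-Scala sketch. The names ages, salaries, nested, flat and viaFlatMap are only illustrative, and an ordinary Seq stands in for the WrappedArray that Spark returns.

```scala
// Plain Scala, no Spark required. The values mirror the two collected columns.
val ages: Seq[Int] = Seq(21, 26, 52, 31)
val salaries: Seq[Int] = Seq(905, 1130, 1890, 1450)

// A sequence of sequences, analogous to WrappedArray[WrappedArray[Int]].
val nested: Seq[Seq[Int]] = Seq(ages, salaries)

// flatten concatenates the inner sequences in order.
val flat: Seq[Int] = nested.flatten

// flatMap(identity) is an equivalent spelling of the same operation.
val viaFlatMap: Seq[Int] = nested.flatMap(identity)

println(flat)
```

The inner sequences are appended in their original order, so all ages come first and all salaries follow.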

Now let's wrap it in a UDF so we can use it on the DataFrame:

def flatten(array: WrappedArray[WrappedArray[Int]]) = array.flatten
sqlContext.udf.register("flatten", flatten(_: WrappedArray[WrappedArray[Int]]))

Since we registered the UDF, we can now call it from SQL through the sqlContext:

sqlContext.sql("select flatten(array(collect_list(age), collect_list(salary))) as arrayInt from df").show(false)

// +---------------------------------------+
// |arrayInt                               |
// +---------------------------------------+
// |[21, 26, 52, 31, 905, 1130, 1890, 1450]|
// +---------------------------------------+
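On Spark 2.4 or later, a built-in SQL function named flatten does the same job, so the UDF may not be needed. The following is a sketch under assumptions that are not part of the answer above: a locally created SparkSession called spark, a view named people, and tuples instead of the Testing case class.

```scala
import org.apache.spark.sql.SparkSession

// Assumption: Spark 2.4+ is on the classpath; a local session is created just for this demo.
val spark = SparkSession.builder().master("local[*]").appName("flatten-sketch").getOrCreate()
import spark.implicits._

// Same sample data as above, built from tuples (hypothetical view name: people).
val people = Seq[(String, Int, Int)](
  (null, 21, 905),
  ("Noelia", 26, 1130),
  ("Pilar", 52, 1890),
  ("Roberto", 31, 1450)
).toDF("name", "age", "salary")
people.createOrReplaceTempView("people")

// The built-in flatten turns an array of arrays into a single array.
val flattened = spark.sql(
  "SELECT flatten(array(collect_list(age), collect_list(salary))) AS arrayInt FROM people")

// Read the single result row back as a Scala sequence.
val result: Seq[Int] = flattened.first.getSeq[Int](0)
println(result)
```

Note that collect_list does not guarantee element order on distributed data, so the order inside each collected array should not be relied upon in general.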

I hope this helps!



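Answer 2:

Let's create the DataFrame the same way as above.

import org.apache.spark.sql.functions._
// Needed for the $"col" syntax used below (already in scope in spark-shell)
import spark.implicits._

// A case class for our sample table
case class Testing(name: String, age: Int, salary: Int)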

// Create an RDD with some data
val x = sc.parallelize(Array(
    Testing(null, 21, 905),
    Testing("Noelia", 26, 1130),
    Testing("Pilar", 52,  1890),
    Testing("Roberto", 31, 1450)
 ))

// Convert RDD to a DataFrame 
val df = spark.createDataFrame(x)

Here we can use the array_union function to achieve the desired result. array_union returns the union of the elements of the input arrays, without duplicates. This function is available since Spark 2.4.0.

// Scala Ref : https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.functions$

// Pyspark Ref : https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.array_union

df.select(collect_list("age").as("age"), collect_list("salary").as("salary"))
  .withColumn("new_col", array_union($"age", $"salary")).show(truncate=false)

// Output

+----------------+-----------------------+---------------------------------------+
|age             |salary                 |new_col                                |
+----------------+-----------------------+---------------------------------------+
|[21, 26, 52, 31]|[905, 1130, 1890, 1450]|[21, 26, 52, 31, 905, 1130, 1890, 1450]|
+----------------+-----------------------+---------------------------------------+
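One caveat: because array_union removes duplicates, a value that appears in both arrays (or twice in one of them) shows up only once in the result. The sample data has no such overlap, which is why the output above looks like a plain concatenation. The sketch below contrasts array_union with concat, which also accepts array columns since Spark 2.4. The local session and the tiny overlapping dataset are assumptions made for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_union, concat}

// Assumption: Spark 2.4+ is on the classpath; a local session is created just for this demo.
val spark = SparkSession.builder().master("local[*]").appName("union-vs-concat").getOrCreate()
import spark.implicits._

// Hypothetical one-row dataset in which the value 31 appears in both arrays.
val arrays = Seq((Seq(21, 31), Seq(31, 905))).toDF("age", "salary")

val row = arrays.select(
  array_union($"age", $"salary").as("unioned"),   // set-like union: duplicates dropped
  concat($"age", $"salary").as("concatenated")    // plain concatenation: duplicates kept
).first

val unionResult: Seq[Int] = row.getSeq[Int](0)
val concatResult: Seq[Int] = row.getSeq[Int](1)
println(s"array_union: $unionResult, concat: $concatResult")
```

If every collected value must be kept, concat (or flatten over an array of arrays) is the safer choice.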

I hope this helps.