The following Spark code correctly demonstrates what I want to do and generates the correct output with a tiny demo data set.
When I run the same kind of code on a large volume of production data, I run into runtime problems: the Spark job runs on my cluster for ~12 hours and then fails.
Just glancing at the code below, it seems inefficient to explode every row only to merge it back down. In the test data set, the fourth row has three values in array_value_1 and three values in array_value_2, so it explodes to 3*3 = 9 rows.
So, in a larger data set, would a row with five such array columns and ten values in each column explode to 10^5 rows?
Looking at the provided Spark functions, I don't see any out-of-the-box function that does what I want. I could supply a user-defined function. Are there any speed drawbacks to that?
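To make the growth concrete, here is a small plain-Scala sketch (no Spark needed) of the arithmetic behind chained explodes: each input row yields the product of its array lengths. The names `ExplodeGrowth` and `explodedRowCount` are made up for illustration.

```scala
object ExplodeGrowth {
  // Rows produced for ONE input row when every array column is exploded in turn:
  // the product of that row's array lengths.
  def explodedRowCount(arrayLengths: Seq[Int]): Long =
    arrayLengths.map(_.toLong).product

  def main(args: Array[String]): Unit = {
    // Array lengths per row of the demo data: (array_value_1, array_value_2)
    val demo = Seq(Seq(2, 2), Seq(2, 2), Seq(2, 1), Seq(3, 3))
    println(demo.map(explodedRowCount).sum)    // 19, matching dfExploded.count()
    println(explodedRowCount(Seq.fill(5)(10))) // 100000, i.e. 10^5 for a single row
  }
}
```

So the answer to the 10^5 question appears to be yes: the count is per input row, before any grouping happens.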
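// Imports needed by the snippet (asJava comes from JavaConverters)
import scala.collection.JavaConverters._
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, collect_set, explode}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}
val sparkSession = SparkSession.builder
.master("local")
.appName("merge list test")
.getOrCreate()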
val schema = StructType(
StructField("category", IntegerType) ::
StructField("array_value_1", ArrayType(StringType)) ::
StructField("array_value_2", ArrayType(StringType)) ::
Nil)
val rows = List(
Row(1, List("a", "b"), List("u", "v")),
Row(1, List("b", "c"), List("v", "w")),
Row(2, List("c", "d"), List("w")),
Row(2, List("c", "d", "e"), List("x", "y", "z"))
)
val df = sparkSession.createDataFrame(rows.asJava, schema)
val dfExploded = df.
withColumn("scalar_1", explode(col("array_value_1"))).
withColumn("scalar_2", explode(col("array_value_2")))
// This will output 19. 2*2 + 2*2 + 2*1 + 3*3 = 19
logger.info(s"dfExploded.count()=${dfExploded.count()}")
val dfOutput = dfExploded.groupBy("category").agg(
collect_set("scalar_1").alias("combined_values_1"),
collect_set("scalar_2").alias("combined_values_2"))
dfOutput.show()
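
For comparison, one way the cross product might be avoided is sketched below. It assumes Spark 2.4 or later, where `flatten` and `array_distinct` are available in `org.apache.spark.sql.functions`. The object name `MergeWithoutExplode` and the helper `mergeArrays` are hypothetical, and I have not measured this against production-sized data. The idea is to collect the arrays per category, flatten the resulting array of arrays, and de-duplicate, so each column is handled independently and no row is multiplied.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array_distinct, col, collect_list, flatten}

object MergeWithoutExplode {
  // Assumes Spark 2.4+ for flatten and array_distinct.
  // collect_list gathers each group's arrays into an array of arrays,
  // flatten concatenates them, and array_distinct removes duplicates.
  def mergeArrays(df: DataFrame): DataFrame =
    df.groupBy("category").agg(
      array_distinct(flatten(collect_list(col("array_value_1")))).alias("combined_values_1"),
      array_distinct(flatten(collect_list(col("array_value_2")))).alias("combined_values_2"))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .master("local")
      .appName("merge list sketch")
      .getOrCreate()
    import spark.implicits._

    // Same demo data as in the question
    val df = Seq(
      (1, Seq("a", "b"), Seq("u", "v")),
      (1, Seq("b", "c"), Seq("v", "w")),
      (2, Seq("c", "d"), Seq("w")),
      (2, Seq("c", "d", "e"), Seq("x", "y", "z"))
    ).toDF("category", "array_value_1", "array_value_2")

    mergeArrays(df).show(false)
    spark.stop()
  }
}
```

As with `collect_set`, the order of elements in the combined arrays is not guaranteed, so compare results as sets rather than as ordered lists.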