I have a DataFrame with two levels of nested fields:
root
|-- request: struct (nullable = true)
| |-- dummyID: string (nullable = true)
| |-- data: struct (nullable = true)
| | |-- fooID: string (nullable = true)
| | |-- barID: string (nullable = true)
I want to update the value of the fooID column. I was able to update a first-level field (for example, dummyID) using the question "How to add a nested column to a DataFrame" as a reference.
Input data:
{
"request": {
"dummyID": "test_id",
"data": {
"fooID": "abc",
"barID": "1485351"
}
}
}
Output data:
{
"request": {
"dummyID": "test_id",
"data": {
"fooID": "def",
"barID": "1485351"
}
}
}
How can I do it using Scala?
Here is a generic solution to this problem that makes it possible to update any number of nested values, at any level, based on an arbitrary function applied in a recursive traversal:
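import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types.StructType
def mutate(df: DataFrame, fn: Column => Column): DataFrame = {
  // Get a projection with fields mutated by `fn` and select it
  // out of the original frame with the schema reassigned to the original
  // frame (explained below)
  df.sqlContext.createDataFrame(df.select(traverse(df.schema, fn): _*).rdd, df.schema)
}
// Recursively rebuild every column: structs are re-created from their traversed
// children, and `fn` is applied to each leaf, addressed by its dotted path.
def traverse(schema: StructType, fn: Column => Column, path: String = ""): Array[Column] = {
  schema.fields.map(f => {
    f.dataType match {
      case s: StructType => struct(traverse(s, fn, path + f.name + "."): _*)
      case _ => fn(col(path + f.name))
    }
  })
}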
This is effectively equivalent to the usual "just redefine the whole struct as a projection" solutions, but it automates re-nesting fields with the original structure AND preserves nullability/metadata (which are lost when you redefine the structs manually). Annoyingly, preserving those properties isn't possible while creating the projection (as far as I can tell), so the code above reassigns the original schema when it rebuilds the DataFrame from the projected rows.
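Applied to the question's data, `mutate` only needs a function that swaps the `request.data.fooID` leaf for a literal and leaves every other leaf alone. The sketch below assumes a local SparkSession and hypothetical case classes `Data` and `Request` that mirror the question's schema; like the answer's own example further down, it relies on `Column.toString` rendering the dotted path of a column created with `col(...)`.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, lit, struct}
import org.apache.spark.sql.types.StructType

// Hypothetical case classes mirroring the question's schema (top-level so Spark can derive encoders).
case class Data(fooID: String, barID: String)
case class Request(dummyID: String, data: Data)

object MutateFooID {
  // mutate/traverse as defined in the answer above.
  def mutate(df: DataFrame, fn: Column => Column): DataFrame =
    df.sqlContext.createDataFrame(df.select(traverse(df.schema, fn): _*).rdd, df.schema)

  def traverse(schema: StructType, fn: Column => Column, path: String = ""): Array[Column] =
    schema.fields.map { f =>
      f.dataType match {
        case s: StructType => struct(traverse(s, fn, path + f.name + "."): _*)
        case _             => fn(col(path + f.name))
      }
    }

  // Replace request.data.fooID with "def"; every other leaf is passed through unchanged.
  def updateFooID(df: DataFrame): DataFrame =
    mutate(df, c => if (c.toString == "request.data.fooID") lit("def") else c)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("mutate-fooID").getOrCreate()
    import spark.implicits._

    val df = Seq(Tuple1(Request("test_id", Data("abc", "1485351")))).toDF("request")
    val out = updateFooID(df)
    out.toJSON.show(false)
    // mutate reassigns df.schema, so field order and nullability match the input.
    assert(out.schema == df.schema)
    spark.stop()
  }
}
```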
An example application:
case class Organ(name: String, count: Int)
case class Disease(id: Int, name: String, organ: Organ)
case class Drug(id: Int, name: String, alt: Array[String])
import spark.implicits._ // for toDF; assumes a SparkSession named `spark`, as in spark-shell
val df = Seq(
(1, Drug(1, "drug1", Array("x", "y")), Disease(1, "disease1", Organ("heart", 2))),
(2, Drug(2, "drug2", Array("a")), Disease(2, "disease2", Organ("eye", 3)))
).toDF("id", "drug", "disease")
df.show(false)
+---+------------------+-------------------------+
|id |drug |disease |
+---+------------------+-------------------------+
|1 |[1, drug1, [x, y]]|[1, disease1, [heart, 2]]|
|2 |[2, drug2, [a]] |[2, disease2, [eye, 3]] |
+---+------------------+-------------------------+
// Update the integer field ("count") at the lowest level:
val df2 = mutate(df, c => if (c.toString == "disease.organ.count") c - 1 else c)
df2.show(false)
+---+------------------+-------------------------+
|id |drug |disease |
+---+------------------+-------------------------+
|1 |[1, drug1, [x, y]]|[1, disease1, [heart, 1]]|
|2 |[2, drug2, [a]] |[2, disease2, [eye, 2]] |
+---+------------------+-------------------------+
// This will NOT necessarily be equal unless the metadata and nullability
// of all fields are preserved (as the code above does); assertResult is from ScalaTest
assertResult(df.schema.toString)(df2.schema.toString)
A limitation of this approach is that it can only update existing fields, not add new ones. If you don't need to preserve nullability/metadata, you could change the map into a flatMap and have the function return Array[Column] so that it can emit additional fields.
Additionally, here is a more generic version for Dataset[T]:
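import org.apache.spark.sql.{Column, Dataset, Encoder}
case class Record(id: Int, drug: Drug, disease: Disease)
def mutateDS[T](df: Dataset[T], fn: Column => Column)(implicit enc: Encoder[T]): Dataset[T] = {
  // Same projection as `mutate`, but reassign the encoder's schema and convert back to Dataset[T]
  df.sqlContext.createDataFrame(df.select(traverse(df.schema, fn): _*).rdd, enc.schema).as[T]
}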
// To call as typed dataset:
val fn: Column => Column = c => if (c.toString == "disease.organ.count") c - 1 else c
mutateDS(df.as[Record], fn).show(false)
// To call as untyped dataset:
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
implicit val encoder: ExpressionEncoder[Row] = RowEncoder(df.schema) // Needed even with spark.implicits._ in scope
mutateDS(df, fn).show(false)
One way, although cumbersome, is to fully unpack and recreate the column by explicitly referencing each element of the original struct:
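import org.apache.spark.sql.functions.{col, lit, struct}
// Rebuild `person` (age, name.first, name.last), replacing only name.last
dataFrame.withColumn("person",
  struct(
    col("person.age").alias("age"),
    struct(
      col("person.name.first").alias("first"),
      lit("some new value").alias("last")).alias("name")))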
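For the question's schema, the explicit rebuild has to list `dummyID`, `data.fooID` and `data.barID` in their original order. On Spark 3.1 and later, `Column.withField` can replace a nested field by its dotted path without listing its siblings. The sketch below shows both; the case classes `ReqData` and `Req` are hypothetical stand-ins for the question's schema, and a local SparkSession is assumed.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, lit, struct}

// Hypothetical case classes mirroring the question's schema.
case class ReqData(fooID: String, barID: String)
case class Req(dummyID: String, data: ReqData)

object RebuildFooID {
  // Explicit rebuild: every field of `request` is listed again, in its original order.
  def rebuild(df: DataFrame): DataFrame =
    df.withColumn("request",
      struct(
        col("request.dummyID").alias("dummyID"),
        struct(
          lit("def").alias("fooID"),
          col("request.data.barID").alias("barID")
        ).alias("data")))

  // Spark 3.1+ only: withField takes a dotted path into nested structs.
  def viaWithField(df: DataFrame): DataFrame =
    df.withColumn("request", col("request").withField("data.fooID", lit("def")))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rebuild-fooID").getOrCreate()
    import spark.implicits._

    val df = Seq(Tuple1(Req("test_id", ReqData("abc", "1485351")))).toDF("request")
    rebuild(df).printSchema()          // compare nullability with df.printSchema()
    viaWithField(df).toJSON.show(false)
    spark.stop()
  }
}
```

In the explicit rebuild, `struct(...)` and `lit("def")` produce non-nullable columns, so `request`, `data` and `fooID` no longer carry `nullable = true`; this is the nullability drift the first answer avoids by reassigning the original schema.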