I have a dataframe with the following structure:
|-- data: struct (nullable = true)
| |-- id: long (nullable = true)
| |-- keyNote: struct (nullable = true)
| | |-- key: string (nullable = true)
| | |-- note: string (nullable = true)
| |-- details: map (nullable = true)
| | |-- key: string
| | |-- value: string (valueContainsNull = true)
How is it possible to flatten the structure and create a new dataframe like this:
|-- id: long (nullable = true)
|-- keyNote: struct (nullable = true)
| |-- key: string (nullable = true)
| |-- note: string (nullable = true)
|-- details: map (nullable = true)
| |-- key: string
| |-- value: string (valueContainsNull = true)
Is there something like explode, but for structs?
This should work in Spark 1.6 or later:
df.select(df.col("data.*"))
or
df.select(df.col("data.id"), df.col("data.keyNote"), df.col("data.details"))
Here is a function that does what you want and that can deal with multiple nested columns containing columns with the same name:
def flatten_df(nested_df):
    """Flatten one layer of struct columns into top-level columns.

    Columns whose dtype string does not start with ``struct`` are kept as
    they are. Each field of a struct column ``parent`` is selected as
    ``parent.field`` and aliased to ``parent_field``, so structs that share
    field names do not produce clashing column names.

    NOTE(review): relies on ``F`` being in scope, presumably
    ``pyspark.sql.functions`` -- confirm against the file's imports.
    """
    plain_names = []
    struct_names = []
    for name, dtype in nested_df.dtypes:
        if dtype.startswith('struct'):
            struct_names.append(name)
        else:
            plain_names.append(name)

    expanded = []
    for parent in struct_names:
        # Selecting "parent.*" is only used to list the struct's field names.
        for field in nested_df.select(parent + '.*').columns:
            expanded.append(
                F.col(parent + '.' + field).alias(parent + '_' + field)
            )

    return nested_df.select(plain_names + expanded)
Before:
root
|-- x: string (nullable = true)
|-- y: string (nullable = true)
|-- foo: struct (nullable = true)
| |-- a: float (nullable = true)
| |-- b: float (nullable = true)
| |-- c: integer (nullable = true)
|-- bar: struct (nullable = true)
| |-- a: float (nullable = true)
| |-- b: float (nullable = true)
| |-- c: integer (nullable = true)
After:
root
|-- x: string (nullable = true)
|-- y: string (nullable = true)
|-- foo_a: float (nullable = true)
|-- foo_b: float (nullable = true)
|-- foo_c: integer (nullable = true)
|-- bar_a: float (nullable = true)
|-- bar_b: float (nullable = true)
|-- bar_c: integer (nullable = true)
I generalized the solution from stecos a bit further, so that the flattening can be applied to structs nested more than two layers deep:
def flatten_df(nested_df, layers=None):
    """Flatten struct columns of ``nested_df``, one nesting layer per pass.

    On each pass every non-struct column is kept unchanged, and every field
    of a struct column ``parent`` is selected as ``parent.field`` and renamed
    to ``parent_field``. Repeating the pass unwraps deeper layers, producing
    names such as ``a_b_c``.

    Args:
        nested_df: DataFrame whose ``dtypes`` may report struct columns.
        layers: Maximum number of passes to run. Values below 1 are treated
            as 1, because this function has always performed at least one
            pass. ``None`` (the default) keeps going until no struct column
            is left, so the caller does not need to know the nesting depth.

    Returns:
        The flattened DataFrame. When ``nested_df`` has no struct columns it
        is returned as is.

    NOTE(review): relies on ``col`` being in scope, presumably imported from
    ``pyspark.sql.functions`` -- confirm against the file's imports.
    """
    remaining = None if layers is None else max(layers, 1)
    current = nested_df

    while remaining is None or remaining > 0:
        struct_names = [
            name for name, dtype in current.dtypes
            if dtype.startswith('struct')
        ]
        if not struct_names:
            # Nothing left to unwrap; a further pass would select the same
            # columns again.
            break

        plain_names = [
            name for name, dtype in current.dtypes
            if not dtype.startswith('struct')
        ]
        expanded = [
            col(parent + '.' + field).alias(parent + '_' + field)
            for parent in struct_names
            # Selecting "parent.*" is only used to list the field names.
            for field in current.select(parent + '.*').columns
        ]
        current = current.select(plain_names + expanded)

        if remaining is not None:
            remaining -= 1

    return current
Just call it with:
my_flattened_df = flatten_df(my_df_having_nested_structs, 3)
(the second parameter is the number of struct layers to flatten; in my case it's 3)
An easy way is to use SQL: you can build a SQL query string that aliases the nested columns as flat ones.
Here is an example in Java.
(I prefer the SQL way, because you can easily test it in spark-shell and it is cross-language.)
This version of flatten_df flattens the dataframe at every nesting level, using a stack to avoid recursive calls:
from pyspark.sql.functions import col
def flatten_df(nested_df):
    """Flatten every layer of struct columns without recursion.

    A stack holds ``(path, frame)`` pairs, where ``path`` is the tuple of
    parent column names leading to ``frame``. Each non-struct column becomes
    one output column, selected by its dotted path and aliased to the same
    path joined with underscores. Each struct column is pushed back onto the
    stack to be expanded later.

    Because the stack is popped from the end, sibling structs are expanded
    in reverse order of appearance, which determines the output column order.
    """
    pending = [((), nested_df)]
    selected = []

    while pending:
        path, frame = pending.pop()
        for name, dtype in frame.dtypes:
            full_path = path + (name,)
            if dtype[:6] == "struct":
                # Project the struct's fields so their dtypes can be
                # inspected on a later iteration.
                pending.append((full_path, frame.select(name + ".*")))
            else:
                selected.append(
                    col(".".join(full_path)).alias("_".join(full_path))
                )

    return nested_df.select(selected)
Example:
from pyspark.sql.types import StringType, StructField, StructType
def _child_pair_type():
    """Build a struct type with two string fields, nestedchild1 and nestedchild2."""
    return StructType([
        StructField(field_name, StringType())
        for field_name in ("nestedchild1", "nestedchild2")
    ])


# One plain column, one struct, and one struct wrapping another struct.
schema = StructType([
    StructField("some", StringType()),
    StructField("nested", _child_pair_type()),
    StructField("renested", StructType([
        StructField("nested", _child_pair_type()),
    ])),
])

# A single row matching the schema above.
data = [
    dict(
        some="value1",
        nested=dict(nestedchild1="value2", nestedchild2="value3"),
        renested=dict(
            nested=dict(nestedchild1="value4", nestedchild2="value5"),
        ),
    ),
]

df = spark.createDataFrame(data, schema)
flat_df = flatten_df(df)
print(flat_df.collect())
Prints:
[Row(some=u'value1', renested_nested_nestedchild1=u'value4', renested_nested_nestedchild2=u'value5', nested_nestedchild1=u'value2', nested_nestedchild2=u'value3')]