I have some data in the following format (either RDD or Spark DataFrame):
from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

sqlContext = SQLContext(sc)

rdd = sc.parallelize([('X01', 41, 'US', 3),
                      ('X01', 41, 'UK', 1),
                      ('X01', 41, 'CA', 2),
                      ('X02', 72, 'US', 4),
                      ('X02', 72, 'UK', 6),
                      ('X02', 72, 'CA', 7),
                      ('X02', 72, 'XX', 8)])

# convert to a Spark DataFrame
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)
What I would like to do is to 'reshape' the data, converting certain values in Country (specifically US, UK and CA) into columns:
ID     Age   US   UK   CA
'X01'  41    3    1    2
'X02'  72    4    6    7
Essentially, I need something along the lines of pandas' pivot workflow:
categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index='ID',
                                                  columns='Country',
                                                  values='Score')
My dataset is rather large, so I can't really collect() and ingest the data into memory to do the reshaping in Python itself. Is there a way to convert pandas' .pivot() into an invokable function while mapping either an RDD or a Spark DataFrame? Any help would be appreciated!
There is a JIRA in Hive for PIVOT to do this natively, without a huge CASE statement for each value:
https://issues.apache.org/jira/browse/HIVE-3776
Please vote that JIRA up so it gets implemented sooner. Once it is in Hive SQL, Spark usually doesn't lag too far behind, and eventually it will be implemented in Spark as well.
First up, this is probably not a good idea, because you are not getting any extra information, but you are binding yourself to a fixed schema (i.e. you need to know how many countries you are expecting, and of course, an additional country means a change in the code).
Having said that, this is a SQL problem, which is shown below. But in case you think it is not too "software like" (seriously, I have heard this!!), then you can refer to the first solution.
Solution 1:
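The code from the original answer isn't preserved here, so the following is only a minimal sketch of an RDD-based reshape under the fixed-schema assumption above (all names are illustrative): key each record by (ID, Age), fold the (Country, Score) pairs into a small dict, and emit one row per key with the hardwired country columns.

categories = ['US', 'UK', 'CA']          # fixed, known-in-advance columns

def merge_dicts(a, b):
    # combine partial {Country: Score} dicts for the same (ID, Age) key
    merged = dict(a)
    merged.update(b)
    return merged

def to_row(kv):
    (id_, age), scores = kv
    return [id_, age] + [scores.get(c) for c in categories]

reshaped = (rdd.filter(lambda r: r[2] in categories)
               .map(lambda r: ((r[0], r[1]), {r[2]: r[3]}))
               .reduceByKey(merge_dicts)
               .map(to_row))

result = sqlContext.createDataFrame(reshaped, ['ID', 'Age'] + categories)
result.show()

With the sample data this gives one row per ID with US, UK and CA columns.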
Now, Solution 2: of course better, as SQL is the right tool for this.
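Again, the original query isn't preserved here; a sketch of the SQL route (table name 'scores' is illustrative, columns as in the question) is to register the DataFrame as a temp table and build one aggregate CASE expression per country:

df.registerTempTable('scores')   # illustrative table name

pivoted = sqlContext.sql("""
    SELECT ID,
           Age,
           SUM(CASE WHEN Country = 'US' THEN Score END) AS US,
           SUM(CASE WHEN Country = 'UK' THEN Score END) AS UK,
           SUM(CASE WHEN Country = 'CA' THEN Score END) AS CA
    FROM scores
    GROUP BY ID, Age
""")
pivoted.show()

This is exactly the "one CASE expression per value" pattern that the HIVE-3776 PIVOT proposal mentioned above would make unnecessary.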
Data set up: the same RDD/DataFrame as built in the question.
Result: from both the 1st and the 2nd solution you should get the reshaped table shown in the question, i.e. one row per ID with US, UK and CA as columns.
Kindly let me know if this works, or not :)
Best Ayan
Here's a native Spark approach that doesn't hardwire the column names. It's based on aggregateByKey, and uses a dictionary to collect the columns that appear for each key. Then we gather all the column names to create the final dataframe. [A prior version used jsonRDD after emitting a dictionary for each record, but this is more efficient.] Restricting to a specific list of columns, or excluding ones like XX, would be an easy modification.
The performance seems good even on quite large tables. I'm using a variation which counts the number of times that each of a variable number of events occurs for each ID, generating one column per event type. The code is basically the same, except it uses a collections.Counter instead of a dict in the seqFn to count the occurrences.
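The answer's code isn't reproduced here, but a minimal sketch of the aggregateByKey idea it describes (function names like seqFn/combFn are illustrative) looks roughly like this:

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

def seqFn(u, kv):
    # fold one (Country, Score) pair into the dict accumulated for this key
    country, score = kv
    u[country] = score
    return u

def combFn(u1, u2):
    # merge partial dicts coming from different partitions
    u1.update(u2)
    return u1

# key on (ID, Age) and build one {Country: Score} dict per key
pivot = (rdd.map(lambda r: ((r[0], r[1]), (r[2], r[3])))
            .aggregateByKey({}, seqFn, combFn))

# gather every column name that appears for any key
columns = sorted(pivot.values()
                      .map(lambda d: set(d.keys()))
                      .reduce(lambda a, b: a | b))

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True)] +
                    [StructField(c, IntegerType(), True) for c in columns])

result = sqlContext.createDataFrame(
    pivot.map(lambda kv: list(kv[0]) + [kv[1].get(c) for c in columns]),
    schema)
result.show()

For the event-count variation described above, seqFn would increment a collections.Counter instead of overwriting a dict entry.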
Since Spark 1.6 you can use the pivot function on GroupedData and provide an aggregate expression. Levels can be omitted, but if provided they can both boost performance and serve as an internal filter.
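For the example data that would be something along these lines, with and without an explicit level list:

# explicit levels: acts as a filter and avoids an extra pass to determine them
pivoted = df.groupBy('ID', 'Age').pivot('Country', ['US', 'UK', 'CA']).sum('Score')

# or let Spark determine the levels itself (slower, and keeps XX as a column)
pivoted_all = df.groupBy('ID', 'Age').pivot('Country').sum('Score')

pivoted.show()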
This method is still relatively slow, but it certainly beats passing data manually between the JVM and Python.
Just a few comments on patricksurry's very helpful answer:
Here is the slightly modified code:
Finally, the output should be:
So first off, I had to make this correction to your RDD (which matches your actual output):
Once I made that correction, this did the trick:
Not nearly as elegant as your pivot, for sure.