In Scala/Spark, given a DataFrame:
val dfIn = sqlContext.createDataFrame(Seq(
  ("r0", 0, 2, 3),
  ("r1", 1, 0, 0),
  ("r2", 0, 2, 2)
)).toDF("id", "c0", "c1", "c2")
I would like to compute a new column, maxCol, holding the name of the column with the max value for each row. For this example, the output should be:
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c1|
+---+---+---+---+------+
The actual DataFrame has more than 60 columns, so a generic solution is required.
The equivalent in Python/pandas (yes, I know, I should compare with PySpark...) could be:
dfOut = pd.concat([dfIn, dfIn.idxmax(axis=1).rename('maxCol')], axis=1)
With a small trick you can use the greatest function. Required imports:
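import org.apache.spark.sql.functions._  // brings col, lit, struct, greatest, when, coalesce into scope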
First, let's create a list of structs where the first element is the value and the second one the column name. A sketch (the field names v and k are arbitrary, and columns.tail simply skips the non-numeric id column):
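// One struct per value column: (value, columnName)
val structs = dfIn.columns.tail.map(
  c => struct(col(c).as("v"), lit(c).as("k"))
)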
A structure like this can be passed to greatest as follows:
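// greatest compares the structs field by field, so the largest value wins;
// getItem("k") then extracts the matching column name (dfOut is an illustrative name)
val dfOut = dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k"))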
(x, "c2") > (x, "c1")
). If for some reason this is not acceptable you can explicitly reduce withwhen
:In case of
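// A sketch (names are illustrative): with >=, the earlier column wins ties,
// unlike greatest, which keeps the later one
val maxColExpr = structs.reduce(
  (c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2)
)
val dfOutFirstWins = dfIn.withColumn("maxCol", maxColExpr.getItem("k"))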
In case of nullable columns you have to adjust this, for example by coalescing the values to -Inf.
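For example, a sketch of that adjustment, assuming the value columns are numeric (Double.NegativeInfinity plays the role of -Inf):

// Nulls sort below every real value once coalesced to -Inf
val structsNullSafe = dfIn.columns.tail.map(
  c => struct(coalesce(col(c), lit(Double.NegativeInfinity)).as("v"), lit(c).as("k"))
)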