In Scala/Spark, given a dataframe:
val dfIn = sqlContext.createDataFrame(Seq(
  ("r0", 0, 2, 3),
  ("r1", 1, 0, 0),
  ("r2", 0, 2, 2)
)).toDF("id", "c0", "c1", "c2")
I would like to compute a new column maxCol
holding the name of the column corresponding to the max value (for each row). With this example, the output should be:
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0| 0| 2| 3| c2|
| r1| 1| 0| 0| c0|
| r2| 0| 2| 2| c1|
+---+---+---+---+------+
Actually, the dataframe has more than 60 columns, so a generic solution is required.
The equivalent in Python Pandas (yes, I know, I should compare with pyspark...) could be:
dfOut = pd.concat([dfIn, dfIn.idxmax(axis=1).rename('maxCol')], axis=1)
With a small trick you can use the greatest function. Required imports:
import org.apache.spark.sql.functions.{col, greatest, lit, struct}
First, let's create a list of structs, where the first element is the value and the second one the column name:
val structs = dfIn.columns.tail.map(
c => struct(col(c).as("v"), lit(c).as("k"))
)
A structure like this can be passed to greatest as follows:
dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k"))
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0| 0| 2| 3| c2|
| r1| 1| 0| 0| c0|
| r2| 0| 2| 2| c2|
+---+---+---+---+------+
Please note that in case of ties it will take the element which occurs later in the sequence, because structs are compared lexicographically and (x, "c2") > (x, "c1"). If for some reason this is not acceptable, you can explicitly reduce with when:
import org.apache.spark.sql.functions.when
val max_col = structs.reduce(
(c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2)
).getItem("k")
dfIn.withColumn("maxCol", max_col)
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0| 0| 2| 3| c2|
| r1| 1| 0| 0| c0|
| r2| 0| 2| 2| c1|
+---+---+---+---+------+
In case of nullable columns you have to adjust this, for example by coalescing the values to -Inf.
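For example, a minimal sketch of such an adjustment (assuming the value columns are numeric and can safely be cast to double so that -Infinity is a valid floor; structsNullSafe is just an illustrative name):
import org.apache.spark.sql.functions.{coalesce, col, greatest, lit, struct}

// Replace nulls with -Infinity so a null value never wins the comparison.
val structsNullSafe = dfIn.columns.tail.map(c =>
  struct(
    coalesce(col(c).cast("double"), lit(Double.NegativeInfinity)).as("v"),
    lit(c).as("k")
  )
)

dfIn.withColumn("maxCol", greatest(structsNullSafe: _*).getItem("k"))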