I have a dataset like this:
uid group_a group_b
1   3       unknown
1   unknown 4
2   unknown 3
2   2       unknown
I want to get this result:
uid group_a group_b
1   3       4
2   2       3
I tried to group the data by "uid", iterate over each group, and select the non-unknown value as the final value, but I don't know how to do it.
Once you format the dataset as a pair RDD, you can use the reduceByKey operation to find the single known value per column. The following example assumes that there is only one known value per column per uid, and otherwise returns the first known value it encounters:
val input = List(
  ("1", "3", "unknown"),
  ("1", "unknown", "4"),
  ("2", "unknown", "3"),
  ("2", "2", "unknown")
)

// Key by uid, with (group_a, group_b) as the value
val pairRdd = sc.parallelize(input).map(l => (l._1, (l._2, l._3)))

// For each uid, keep the known value of each column
val result = pairRdd.reduceByKey { (a, b) =>
  val groupA = if (a._1 != "unknown") a._1 else b._1
  val groupB = if (a._2 != "unknown") a._2 else b._2
  (groupA, groupB)
}
The result is a pair RDD that looks like this:
(uid, (group_a, group_b))
(1,(3,4))
(2,(2,3))
You can return to the plain line format with a simple map operation.
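For example, a minimal sketch that flattens each (uid, (group_a, group_b)) pair back into a plain tuple (the output lines shown are for the sample input above):

// Flatten the pair structure back into (uid, group_a, group_b) rows
val flat = result.map { case (uid, (groupA, groupB)) => (uid, groupA, groupB) }
flat.collect().foreach(println)
// prints:
// (1,3,4)
// (2,2,3)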
I would suggest you define a User Defined Aggregate Function (UDAF). Built-in functions are great, but they are hard to customize; if you write your own UDAF, you can adapt it to your exact needs. For your problem, the following can be a solution. The first task is to define the UDAF:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class PingJiang extends UserDefinedAggregateFunction {

  def inputSchema = new StructType().add("group_a", StringType).add("group_b", StringType)
  def bufferSchema = new StructType().add("buff0", StringType).add("buff1", StringType)
  def dataType = StringType
  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, "")
    buffer.update(1, "")
  }

  // Keep the latest non-"unknown" value seen for each column
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0)) {
      val groupa = input.getString(0)
      val groupb = input.getString(1)
      if (!groupa.equalsIgnoreCase("unknown")) {
        buffer.update(0, groupa)
      }
      if (!groupb.equalsIgnoreCase("unknown")) {
        buffer.update(1, groupb)
      }
    }
  }

  // When merging partial buffers, take whichever buffer holds a value,
  // so the result is correct no matter how often merge is called
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    if (buffer2.getString(0).nonEmpty) buffer1.update(0, buffer2.getString(0))
    if (buffer2.getString(1).nonEmpty) buffer1.update(1, buffer2.getString(1))
  }

  // Return both values as a single comma-separated string
  def evaluate(buffer: Row): String = buffer.getString(0) + "," + buffer.getString(1)
}
Then you call it from your main program and do some manipulation to get the result you need:
import org.apache.spark.sql.functions.split
import spark.implicits._

val data = Seq(
  (1, "3", "unknown"),
  (1, "unknown", "4"),
  (2, "unknown", "3"),
  (2, "2", "unknown"))
  .toDF("uid", "group_a", "group_b")

val udaf = new PingJiang()
val result = data.groupBy("uid").agg(udaf($"group_a", $"group_b").as("ping"))
  .withColumn("group_a", split($"ping", ",")(0))
  .withColumn("group_b", split($"ping", ",")(1))
  .drop("ping")

result.show(false)
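With the sample data above, result.show(false) should print something like this (row order may vary):

+---+-------+-------+
|uid|group_a|group_b|
+---+-------+-------+
|1  |3      |4      |
|2  |2      |3      |
+---+-------+-------+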
Visit databricks and augmentiq for a better understanding of UDAFs.
Note: The above solution gets you the latest known value for each group, if present. (You can always adapt it to your needs.)
You could replace all "unknown" values with null, and then use the function first() with ignoreNulls set to true, inside a map over the columns (as shown here), to get the first non-null value in each column per group:
import org.apache.spark.sql.functions.{col, first, when}

// We only apply the aggregation to the last two columns
val cols = df.columns.drop(1)

// Create one first(..., ignoreNulls = true) expression per column
val exprs = cols.map(first(_, true))

// Replace "unknown" with null, then take the first non-null value per group
df.select(df.columns
    .map(c => when(col(c) === "unknown", null)
    .otherwise(col(c)).as(c)): _*)
  .groupBy("uid")
  .agg(exprs.head, exprs.tail: _*).show()
+---+--------------------+--------------------+
|uid|first(group_a, true)|first(group_b, true)|
+---+--------------------+--------------------+
|  1|                   3|                   4|
|  2|                   2|                   3|
+---+--------------------+--------------------+
Data:
val df = sc.parallelize(Array(
  ("1", "3", "unknown"), ("1", "unknown", "4"),
  ("2", "unknown", "3"), ("2", "2", "unknown"))).toDF("uid", "group_a", "group_b")