How to update a column based on a condition

Published 2019-03-15 15:41

Question:

I have the following df:

+---+----+-----+
|sno|dept|color|
+---+----+-----+
|  1|  fn|  red|
|  2|  fn| blue|
|  3|  fn|green|
+---+----+-----+

If any value in the color column is red, then all values of the color column should be updated to red, as below:

+---+----+-----+
|sno|dept|color|
+---+----+-----+
|  1|  fn|  red|
|  2|  fn|  red|
|  3|  fn|  red|
+---+----+-----+

I could not figure it out. Please help; I have tried the following code:

val gp = jdbcDF.filter($"dept".contains("fn"))
  //.withColumn("newone", when($"dept" === "fn", "RED").otherwise("NULL"))
gp.show()

gp.map(row => {
  val row1 = row.getAs[String](1)
  var row2 = row.getAs[String](2)
  val make = if (row1 == "fn") row2 = "red"
  Row(row(0), row(1), make)
}).collect().foreach(println)
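In other words, the intended logic is: within each dept, if any row's color is red, every row in that dept should end up red. Before looking at the Spark answers, that logic can be sketched on plain Scala collections (the Rec case class and names here are illustrative, not Spark API):

```scala
// A plain-Scala model of the desired transformation.
case class Rec(sno: Int, dept: String, color: String)

val rows = Seq(Rec(1, "fn", "red"), Rec(2, "fn", "blue"), Rec(3, "fn", "green"))

// Depts that contain at least one red row.
val redDepts = rows.filter(_.color == "red").map(_.dept).toSet

// Within such a dept, every row's color becomes red.
val updated = rows.map(r => if (redDepts(r.dept)) r.copy(color = "red") else r)
```

The Spark answers below implement this same "find the red groups, then rewrite their rows" idea in different ways.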

Answer 1:

Given:

val df = Seq(
  (1, "fn", "red"),
  (2, "fn", "blue"),
  (3, "fn", "green"),
  (4, "aa", "blue"),
  (5, "aa", "green"),
  (6, "bb", "red"),
  (7, "bb", "red"),
  (8, "aa", "blue")
).toDF("id", "fn", "color")

Do the calculation:

val redOrNot = df.groupBy("fn")
  .agg(collect_set('color) as "values")
  .withColumn("hasRed", array_contains('values, "red"))

// `when` without `otherwise` yields null when the condition is false
val colorPicker = when('hasRed, "red")
val result = df.join(redOrNot, "fn")
  .withColumn("resultColor", colorPicker)
  .withColumn("color", coalesce('resultColor, 'color)) // falls back to the original color when resultColor is null
  .select('id, 'fn, 'color)

The result looks as follows, which appears to be the expected answer:

scala> result.show
+---+---+-----+
| id| fn|color|
+---+---+-----+
|  1| fn|  red|
|  2| fn|  red|
|  3| fn|  red|
|  4| aa| blue|
|  5| aa|green|
|  6| bb|  red|
|  7| bb|  red|
|  8| aa| blue|
+---+---+-----+

You can chain when operators and supply a default value with otherwise; consult the scaladoc of the when operator.
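The null-then-coalesce pattern used above can be mimicked with Option in plain Scala. This is only a sketch of the semantics, not Spark API:

```scala
// `when(cond, value)` with no `otherwise` evaluates to null when cond is
// false; Option models that null as None.
def colorPicker(hasRed: Boolean): Option[String] =
  if (hasRed) Some("red") else None

// `coalesce(a, b)` returns the first non-null argument, like getOrElse here.
def resolveColor(hasRed: Boolean, original: String): String =
  colorPicker(hasRed).getOrElse(original)
```

So a group with red resolves to "red", and any other group keeps its original color.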

I think you could do something very similar (and perhaps more efficient) using window functions or user-defined aggregate functions (UDAFs), but... well... I don't currently know how to do it. Leaving the comment here to inspire others ;-)

p.s. Learnt a lot! Thanks for the idea!



Answer 2:

An efficient solution that doesn't require an expensive aggregation:

// All groups containing `red`
df.where($"color" === "red").select($"fn".alias("fn_")).distinct
  // Join with the input
  .join(df.as("df"), $"fn" === $"fn_", "rightouter")
  // Replace `color`
  .withColumn("color", when($"fn_".isNull, $"color").otherwise(lit("red")))
  .drop("fn_")


Answer 3:

You are conditionally updating the DataFrame if it satisfies a certain property. In this case the property is "the color column contains 'red'". The idiomatic way to express this is to filter with the desired predicate and then determine whether any rows satisfy it. There is no need for a join.

import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.DataFrame

def makeAllRedIfAnyAreRed(df: DataFrame): DataFrame = {
  val containsRed = df.filter(df("color") === "red").count() > 0
  if (containsRed) df.withColumn("color", lit("red")) else df
}
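The filter-then-check shape of this answer has a direct analogue on plain collections, where exists expresses "does any row satisfy the predicate?" (the Item case class and names are illustrative, not Spark API):

```scala
case class Item(id: Int, dept: String, color: String)

val items = Seq(Item(1, "fn", "red"), Item(2, "fn", "blue"), Item(3, "fn", "green"))

// The collection analogue of df.filter(...).count() > 0;
// exists short-circuits on the first match, count does not.
val anyRed = items.exists(_.color == "red")
val fixed  = if (anyRed) items.map(_.copy(color = "red")) else items
```

Note this answer (like the sketch) repaints the entire DataFrame, so it fits the original question's single-group df; the join-based answers are the ones that scope the update per group.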


Answer 4:

Since the filtered DataFrame is likely to contain only a few rows, I'm adding a solution that combines isin() with withColumn().

Sample DataFrame

val df = Seq(
  (1, "fn", "red"),
  (2, "fn", "blue"),
  (3, "fn", "green"),
  (4, "aa", "blue"),
  (5, "aa", "green"),
  (6, "bb", "red"),
  (7, "bb", "red"),
  (8, "aa", "blue")
).toDF("id", "dept", "color")

Now let's collect only the depts that have at least one red row and place them in a broadcast variable, like below.

val depts = sc.broadcast(df.filter($"color" === "red").select(collect_set("dept")).first.getSeq[String](0))

Then update the color to red for records in those depts.

isin() takes varargs, so expand the list into varargs with depts.value: _*.

// create the result in a new column (clr) to show the difference
val result = df.withColumn("clr", when($"dept".isin(depts.value: _*), lit("red"))
  .otherwise($"color"))

result.show()

+---+----+-----+-----+
| id|dept|color|  clr|
+---+----+-----+-----+
|  1|  fn|  red|  red|
|  2|  fn| blue|  red|
|  3|  fn|green|  red|
|  4|  aa| blue| blue|
|  5|  aa|green|green|
|  6|  bb|  red|  red|
|  7|  bb|  red|  red|
|  8|  aa| blue| blue|
+---+----+-----+-----+
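The `: _*` splat used with isin() is ordinary Scala varargs syntax. A minimal standalone illustration (isinDemo is a hypothetical helper, not the Spark Column method, though it has the same varargs shape):

```scala
// A varargs function, analogous in shape to Column.isin(values: Any*).
def isinDemo(value: String, candidates: String*): Boolean =
  candidates.contains(value)

val redDeptList = Seq("fn", "bb")

// A Seq must be expanded into varargs with `: _*`.
val hit  = isinDemo("fn", redDeptList: _*)
val miss = isinDemo("aa", redDeptList: _*)
```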


Answer 5:

Spark 2.2.0: sample DataFrame (taken from the examples above):

val df = Seq(
  (1, "fn", "red"),
  (2, "fn", "blue"),
  (3, "fn", "green"),
  (4, "aa", "blue"),
  (5, "aa", "green"),
  (6, "bb", "red"),
  (7, "bb", "red"),
  (8, "aa", "blue")
).toDF("id", "dept", "color")

Create a UDF that performs the update by checking the condition:

val replace_val = udf((x: String, y: String) =>
  if (Option(x).getOrElse("").equalsIgnoreCase("fn") && !y.equalsIgnoreCase("red")) "red" else y)

val final_df = df.withColumn("color", replace_val($"dept",$"color"))
final_df.show()
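Because the UDF body is an ordinary Scala function, its logic can be exercised outside Spark. Below is the same logic as replace_val, extracted as a plain function; note two limitations of this answer: it hardcodes the dept "fn", and it assumes the color argument is never null:

```scala
// Same logic as the replace_val UDF: force "fn" rows to red, leave others.
// Assumes y (color) is non-null; x (dept) may be null.
def replaceVal(x: String, y: String): String =
  if (Option(x).getOrElse("").equalsIgnoreCase("fn") && !y.equalsIgnoreCase("red")) "red"
  else y
```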

output (every "fn" row becomes red, all other rows are unchanged):

+---+----+-----+
| id|dept|color|
+---+----+-----+
|  1|  fn|  red|
|  2|  fn|  red|
|  3|  fn|  red|
|  4|  aa| blue|
|  5|  aa|green|
|  6|  bb|  red|
|  7|  bb|  red|
|  8|  aa| blue|
+---+----+-----+

Spark 1.6 (same approach, with explicit context setup):

val conf = new SparkConf().setMaster("local").setAppName("My app")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)

import sqlContext.implicits._
// For implicit conversions like converting RDDs to DataFrames
val df = sc.parallelize(Seq(
  (1, "fn", "red"),
  (2, "fn", "blue"),
  (3, "fn", "green"),
  (4, "aa", "blue"),
  (5, "aa", "green"),
  (6, "bb", "red"),
  (7, "bb", "red"),
  (8, "aa", "blue")
)).toDF("id", "dept", "color")


val replace_val = udf((x: String, y: String) =>
  if (Option(x).getOrElse("").equalsIgnoreCase("fn") && !y.equalsIgnoreCase("red")) "red" else y)
val final_df = df.withColumn("color", replace_val($"dept",$"color"))

final_df.show()