Spark dataframes: Extract a column based on the va

Posted 2019-05-18 01:00

Question:

I have a dataframe with transactions with a joined price list:

+----------+----------+------+-------+-------+
|   paid   | currency | EUR  |  USD  |  GBP  |
+----------+----------+------+-------+-------+
|   49.5   |   EUR    | 99   |  79   |  69   |
+----------+----------+------+-------+-------+

A customer has paid 49.5 in EUR, as shown in the "currency" column. I now want to compare that paid price with the price from the price list.

Therefore I need to access the correct column based on the value of "currency", like so:

df.withColumn("saved", df.col(df.col($"currency")) - df.col("paid"))

which I hoped would become

df.withColumn("saved", df.col("EUR") - df.col("paid"))

This fails, however. I tried everything I could imagine, including a UDF, and got nowhere.

I guess there is some elegant solution for this? Can somebody help out here?

Answer 1:

Assuming that the column names match values in the currency column:

import org.apache.spark.sql.functions.{lit, col, coalesce, when}
import org.apache.spark.sql.Column

// Dummy data
val df = sc.parallelize(Seq(
  (49.5, "EUR", 99, 79, 69), (100.0, "GBP", 80, 120, 50)
)).toDF("paid", "currency", "EUR", "USD", "GBP")

// A list of available currencies 
val currencies: List[String] = List("EUR", "USD", "GBP")

// Select listed value
val listedPrice: Column = coalesce(
  currencies.map(c => when($"currency" === c, col(c)).otherwise(lit(null))): _*)

df.select($"*", (listedPrice - $"paid").alias("difference")).show

// +-----+--------+---+---+---+----------+
// | paid|currency|EUR|USD|GBP|difference|
// +-----+--------+---+---+---+----------+
// | 49.5|     EUR| 99| 79| 69|      49.5|
// |100.0|     GBP| 80|120| 50|     -50.0|
// +-----+--------+---+---+---+----------+

with the SQL equivalent of the listedPrice expression being something like this:

COALESCE(
  CASE WHEN (currency = 'EUR') THEN EUR ELSE null END,
  CASE WHEN (currency = 'USD') THEN USD ELSE null END,
  CASE WHEN (currency = 'GBP') THEN GBP ELSE null END
)

Alternative using foldLeft:

import org.apache.spark.sql.functions.when

val listedPriceViaFold = currencies.foldLeft(
  lit(null))((acc, c) => when($"currency" === c, col(c)).otherwise(acc))

df.select($"*", (listedPriceViaFold - $"paid").alias("difference")).show

// +-----+--------+---+---+---+----------+
// | paid|currency|EUR|USD|GBP|difference|
// +-----+--------+---+---+---+----------+
// | 49.5|     EUR| 99| 79| 69|      49.5|
// |100.0|     GBP| 80|120| 50|     -50.0|
// +-----+--------+---+---+---+----------+

where listedPriceViaFold translates to the following SQL:

CASE
  WHEN (currency = 'GBP') THEN GBP
  ELSE CASE
    WHEN (currency = 'USD') THEN USD
    ELSE CASE
      WHEN (currency = 'EUR') THEN EUR
      ELSE null
    END
  END
END

Unfortunately, I am not aware of any built-in function that can directly express SQL like this:

CASE currency
    WHEN 'EUR' THEN EUR
    WHEN 'USD' THEN USD
    WHEN 'GBP' THEN GBP
    ELSE null
END

but you can use this construct in raw SQL.
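
As a rough sketch of the raw SQL route, the simple CASE form can be passed as a string to the expr function from org.apache.spark.sql.functions. This assumes Spark 2.x or later (for SparkSession); the local session and the name listedPriceSql are only there to make the snippet self-contained, and in spark-shell the existing spark value can be used instead.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

// Local session for illustration only (assumption: Spark 2.x or later)
val spark = SparkSession.builder()
  .master("local[*]").appName("case-expr").getOrCreate()
import spark.implicits._

// Same dummy data as above
val df = Seq(
  (49.5, "EUR", 99, 79, 69), (100.0, "GBP", 80, 120, 50)
).toDF("paid", "currency", "EUR", "USD", "GBP")

// Simple CASE expression, parsed by Spark SQL from a string
val listedPriceSql = expr("""
  CASE currency
    WHEN 'EUR' THEN EUR
    WHEN 'USD' THEN USD
    WHEN 'GBP' THEN GBP
    ELSE null
  END""")

val result = df.withColumn("difference", listedPriceSql - $"paid")
result.show()
```

The list of currencies is hard-coded in the string here; if it has to be dynamic, the string could be generated from the currencies list instead.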

If my assumption is not true, you can simply add a mapping between each value in the currency column and the corresponding column name.
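
A minimal sketch of such a mapping, assuming (hypothetically) that the currency column holds full names such as "Euro" rather than the column names. The data, the currencyToColumn map, and its entries are made up for illustration; the local session assumes Spark 2.x or later.

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{coalesce, col, when}

// Local session for illustration only (assumption: Spark 2.x or later)
val spark = SparkSession.builder()
  .master("local[*]").appName("currency-mapping").getOrCreate()
import spark.implicits._

// Hypothetical data: currency values do not match the price column names
val df = Seq(
  (49.5, "Euro", 99, 79, 69), (100.0, "Pound", 80, 120, 50)
).toDF("paid", "currency", "EUR", "USD", "GBP")

// Hypothetical mapping: value in the currency column -> price column name
val currencyToColumn: Map[String, String] =
  Map("Euro" -> "EUR", "Dollar" -> "USD", "Pound" -> "GBP")

// when(...) without otherwise(...) evaluates to null if the condition is false,
// so coalesce picks the single matching price
val listedPrice: Column = coalesce(
  currencyToColumn.toSeq.map { case (value, name) =>
    when($"currency" === value, col(name))
  }: _*)

val result = df.withColumn("difference", listedPrice - $"paid")
result.show()
```

Rows whose currency value is missing from the map end up with a null difference.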

Edit:

Another option, which could be efficient if the source supports predicate pushdown and efficient column pruning, is to subset by currency and union:

currencies.map(
  // for each currency, filter and add the difference (listed price minus paid)
  c => df.where($"currency" === c).withColumn("difference", col(c) - $"paid")
).reduce((df1, df2) => df1.unionAll(df2)) // Union; on Spark 2.0+ union is the equivalent method

It is equivalent to SQL like this:

SELECT *,  EUR - paid AS difference FROM df WHERE currency = 'EUR'
UNION ALL
SELECT *,  USD - paid AS difference FROM df WHERE currency = 'USD'
UNION ALL
SELECT *,  GBP - paid AS difference FROM df WHERE currency = 'GBP'


Answer 2:

I can't think of a way to do this with DataFrames, and I doubt that there is a simple one, but you can convert the table to an RDD:

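import org.apache.spark.sql.Row

// Off the top of my head; let me know if it's wrong.
// Assumes paid is a Double and the three price columns are Ints.
// Pick the listed price for the row's currency and subtract the paid amount.
def d(r: Row): Double = r.getString(1) match {
  case "EUR" => r.getInt(2) - r.getDouble(0)
  case "USD" => r.getInt(3) - r.getDouble(0)
  case _     => r.getInt(4) - r.getDouble(0) // anything else is treated as GBP
}
val rdd = df.rdd
val diff = rdd.map(r => (r, d(r))) // pair each row with its difference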

This may well raise type errors; I hope you can navigate around those.