Spark extracting values from a Row

Posted 2019-01-02 20:58

Question:

I have the following dataframe

val transactions_with_counts = sqlContext.sql(
  """SELECT user_id AS user_id, category_id AS category_id,
  COUNT(category_id) FROM transactions GROUP BY user_id, category_id""")

I'm trying to convert the rows to Rating objects, but since x(0) returns Any rather than a typed value, this fails:

val ratings = transactions_with_counts
  .map(x => Rating(x(0).toInt, x(1).toInt, x(2).toInt))

error: value toInt is not a member of Any

Answer 1:

Let's start with some dummy data:

val transactions = Seq((1, 2), (1, 4), (2, 3)).toDF("user_id", "category_id")

val transactions_with_counts = transactions
  .groupBy($"user_id", $"category_id")
  .count

transactions_with_counts.printSchema

// root
// |-- user_id: integer (nullable = false)
// |-- category_id: integer (nullable = false)
// |-- count: long (nullable = false)

There are a few ways to access Row values and keep expected types:

  1. Pattern matching

    import org.apache.spark.sql.Row
    
    transactions_with_counts.map{
      case Row(user_id: Int, category_id: Int, rating: Long) =>
        Rating(user_id, category_id, rating)
    } 
    
  2. Typed get* methods like getInt, getLong:

    transactions_with_counts.map(
      r => Rating(r.getInt(0), r.getInt(1), r.getLong(2))
    )
    
  3. The getAs method, which accepts both names and indices:

    transactions_with_counts.map(r => Rating(
      r.getAs[Int]("user_id"), r.getAs[Int]("category_id"), r.getAs[Long](2)
    ))
    

    It can also be used to properly extract user-defined types, including mllib.linalg.Vector (see the sketch after this list). Note that accessing by name requires a schema.

  4. Converting to a statically typed Dataset (Spark 1.6+):

    transactions_with_counts.as[(Int, Int, Long)]
    

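As a sketch of point 3: getAs can recover user-defined types that plain apply() returns only as Any. The DataFrame below is made up for illustration and assumes the same implicits used above are in scope:

import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Hypothetical DataFrame with a user-defined Vector column
val featuresDF = Seq(
  (1, Vectors.dense(0.5, 1.5)),
  (2, Vectors.dense(2.0, 0.1))
).toDF("id", "features")

// getAs returns the typed Vector; featuresDF.first()(1) would only give Any
val v: Vector = featuresDF.first.getAs[Vector]("features")
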

Answer 2:

Using Datasets, you can define Rating as follows:

case class Rating(user_id: Int, category_id: Int, count: Long)

Note that the Rating class here names its third field count instead of rating (which zero323's answer used): it must match the count column that Spark generates for the aggregation. The typed Dataset is then created as follows:

val transactions_with_counts = transactions.groupBy($"user_id", $"category_id").count

val rating = transactions_with_counts.as[Rating]

This way you will not run into runtime errors, because the Rating class's field names are identical to the column names that Spark generates at runtime.
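
With the typed Dataset in hand, fields are accessed directly, with no casting or getAs calls. A minimal sketch, reusing the Dataset built above:

// Each element is a Rating, not a Row, so fields are statically typed
val first: Rating = rating.first
val frequent = rating.filter(_.count > 1)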



Answer 3:

To access row values on the driver, you can call rdd.collect on the DataFrame and iterate over the result with a for loop.

Suppose your DataFrame looks like this:

val df = Seq(
  (1, "James"),
  (2, "Albert"),
  (3, "Pete")).toDF("user_id", "name")

Call rdd.collect on the DataFrame; the loop variable row then holds each collected Row. row.mkString(",") renders a row as comma-separated values, and the built-in split method then gives you access to each column value by index.

for (row <- df.rdd.collect) {
  // mkString renders the row as comma-separated values;
  // split then recovers the individual column values by index
  val user_id = row.mkString(",").split(",")(0)
  val name = row.mkString(",").split(",")(1)
  println(s"$user_id, $name")
}

The code above is a little more verbose than a DataFrame foreach loop, but it gives you more control over your logic.
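
Note, however, that mkString/split breaks if a value itself contains a comma. A sketch of the same loop using the typed getters from the first answer avoids the round-trip through a string:

for (row <- df.rdd.collect) {
  val user_id = row.getInt(0)          // typed access by index
  val name = row.getAs[String]("name") // typed access by name
  println(s"$user_id -> $name")
}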


