I have the following DataFrame:
val transactions_with_counts = sqlContext.sql(
"""SELECT user_id AS user_id, category_id AS category_id,
COUNT(category_id) FROM transactions GROUP BY user_id, category_id""")
I'm trying to convert the rows to Rating objects, but since x(0) returns Any this fails:
val ratings = transactions_with_counts
.map(x => Rating(x(0).toInt, x(1).toInt, x(2).toInt))
error: value toInt is not a member of Any
Let's start with some dummy data:
import sqlContext.implicits._  // needed for toDF and the $"..." column syntax (pre-imported in spark-shell)

val transactions = Seq((1, 2), (1, 4), (2, 3)).toDF("user_id", "category_id")

val transactions_with_counts = transactions
  .groupBy($"user_id", $"category_id")
  .count
transactions_with_counts.printSchema
// root
// |-- user_id: integer (nullable = false)
// |-- category_id: integer (nullable = false)
// |-- count: long (nullable = false)
There are a few ways to access Row values and keep the expected types:
Pattern matching
import org.apache.spark.sql.Row
transactions_with_counts.map {
  case Row(user_id: Int, category_id: Int, rating: Long) =>
    Rating(user_id, category_id, rating)
}
Typed get* methods like getInt, getLong:

transactions_with_counts.map(
  r => Rating(r.getInt(0), r.getInt(1), r.getLong(2))
)
getAs method, which can use both names and indices:

transactions_with_counts.map(r => Rating(
  r.getAs[Int]("user_id"), r.getAs[Int]("category_id"), r.getAs[Long](2)
))

It can be used to properly extract user-defined types, including mllib.linalg.Vector. Obviously, accessing by name requires a schema.
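For instance, assuming a hypothetical DataFrame named scaled that has an mllib.linalg.Vector column called features, a sketch could look like this:

import org.apache.spark.mllib.linalg.Vector

// Hypothetical example: "scaled" has a Vector-typed "features" column;
// getAs returns it with its proper user-defined type.
val features = scaled.rdd.map(r => r.getAs[Vector]("features"))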
Converting to a statically typed Dataset (Spark 1.6+ / 2.0+):
transactions_with_counts.as[(Int, Int, Long)]
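For example, the fields of the resulting Dataset come back with their expected types, so there is no casting from Any (a small sketch; the typed value and the println are illustrative only):

// A small sketch: collect returns typed tuples instead of Rows.
val typed = transactions_with_counts.as[(Int, Int, Long)]

typed.collect().foreach { case (user_id, category_id, count) =>
  println(s"$user_id $category_id $count")
}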
Using Datasets, you can define Rating as follows:

case class Rating(user_id: Int, category_id: Int, count: Long)
The Rating class here uses the field name 'count' instead of the 'rating' that zero323 suggested, so it matches the column produced by count. The rating value is then assigned as follows:
val transactions_with_counts = transactions.groupBy($"user_id", $"category_id").count
val rating = transactions_with_counts.as[Rating]
This way you will not run into runtime errors, because the field names of your Rating class are identical to the 'count' column name that Spark generates at runtime.
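For example, rating is now a typed Dataset[Rating], so its fields can be used by name in typed transformations (a small usage sketch):

// A small usage sketch: fields are accessed by name, no Row casting needed.
rating.filter(r => r.count > 1).show()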
To access the values of DataFrame rows on the driver, you can use rdd.collect on the DataFrame and iterate over the result with a for loop.
Suppose your DataFrame looks like this:
val df = Seq(
  (1, "James"),
  (2, "Albert"),
  (3, "Pete")).toDF("user_id", "name")
Call rdd.collect on the DataFrame. In the loop below, the row variable holds each Row of the underlying RDD. row.mkString(",") renders the row's values as a comma-separated string, and String's split method then lets you access each column value by index.
for (row <- df.rdd.collect) {
  // Each row is rendered as a comma-separated string and split back by index.
  val user_id = row.mkString(",").split(",")(0)
  val name = row.mkString(",").split(",")(1)
}
The above code looks a little bigger than a dataframe.foreach loop, but it gives you more control over your logic.
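If you want to keep the original column types instead of going through a string, the getAs method shown in the earlier answer also works on the collected rows; a minimal sketch:

// Same loop using getAs, which avoids the mkString/split round-trip
// and keeps each column's original type.
for (row <- df.rdd.collect) {
  val user_id = row.getAs[Int]("user_id")
  val name = row.getAs[String]("name")
}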