SPARK - Use RDD.foreach to Create a Dataframe and

Posted 2019-07-04 06:22

Question:

I am new to Spark and trying to figure out a better way to handle the following scenario. There is a database table containing 3 fields: Category, Amount, Quantity. First I pull all the distinct categories from the database.

 val categories:RDD[String] = df.select(CATEGORY).distinct().rdd.map(r => r(0).toString)

Now, for each category, I want to execute the pipeline, which essentially creates a DataFrame for that category and applies some machine learning.

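 def execute(category: String): Unit = {
   // read.jdbc takes a table expression plus connection properties, so the filter goes in a subquery
   val query = s"(SELECT * FROM TABLE_NAME WHERE CATEGORY = '$category') AS t"
   val dfCategory = sqlCtxt.read.jdbc(JDBC_URL, query, new java.util.Properties())
   dfCategory.show()
 }
 categories.foreach(execute)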

Is it possible to do something like this? Or is there a better alternative?

Answer 1:

// You could get all your data with a single read and convert it to an RDD
// (read.jdbc takes the table name and connection properties, not a bare SELECT)
val data = sqlCtxt.read.jdbc(JDBC_URL, "TABLE_NAME", new java.util.Properties()).rdd

// then group the data by category
val groupedData = data.groupBy(row => row.getAs[String]("category"))

// then you get an RDD[(String, Iterable[org.apache.spark.sql.Row])]
// and you can iterate over it and execute your pipeline
// (foreach is an action, so the body actually runs; map alone would stay lazy)
groupedData.foreach { case (categoryName, items) =>
  // executePipeline(categoryName, items)
}
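
A minimal sketch of this grouping approach, assuming a local SparkSession and a small in-memory DataFrame standing in for the JDBC table. The object name GroupByCategorySketch and the helper totalAmount are hypothetical; the point is that inside the grouped RDD each category's rows arrive as a plain Scala Iterable[Row], so the per-category step has to be ordinary Scala code rather than DataFrame operations.

```scala
import org.apache.spark.sql.{Row, SparkSession}

object GroupByCategorySketch {
  // Hypothetical per-category step: sums the Amount column with plain Scala collections
  def totalAmount(items: Iterable[Row]): Double =
    items.map(_.getAs[Double]("Amount")).sum

  def run(): Map[String, Double] = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("group-by-category-sketch")
      .getOrCreate()
    import spark.implicits._

    // Assumption: in-memory rows replace the real JDBC read
    val df = Seq(("A", 10.0, 1), ("A", 5.0, 2), ("B", 7.5, 3))
      .toDF("Category", "Amount", "Quantity")

    // RDD[(String, Iterable[Row])], one entry per category
    val grouped = df.rdd.groupBy(row => row.getAs[String]("Category"))

    // Run the per-category step on the workers and bring the small result back
    val totals = grouped
      .map { case (category, items) => category -> totalAmount(items) }
      .collect()
      .toMap

    spark.stop()
    totals
  }
}
```

Note that groupBy pulls all rows of one category into a single task, so this only suits categories whose rows fit in one executor's memory.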


Answer 2:

Your code would fail, typically with a "Task not serializable" exception, because you're trying to use the SQLContext inside the execute method. The SQLContext is a driver-side object that can't be shipped to and used on workers, yet execute would have to be serialized and sent to the workers to run on each record in the categories RDD.
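
To illustrate the failure mode, here is a small sketch that assumes a local SparkSession and uses a hypothetical DriverOnlyResource class as a stand-in for a driver-side object. It is not the SQLContext itself, and the exact error you see with a real SQLContext may differ by Spark version. The closure passed to foreach captures the non-serializable object, and Spark rejects the task before running it.

```scala
import org.apache.spark.sql.SparkSession
import scala.util.Try

// Hypothetical driver-only helper; it deliberately does not extend Serializable
class DriverOnlyResource {
  def describe(category: String): String = s"category=$category"
}

object ClosureCaptureSketch {
  def run(): Try[Unit] = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("closure-capture-sketch")
      .getOrCreate()

    val resource = new DriverOnlyResource
    val categories = spark.sparkContext.parallelize(Seq("A", "B"))

    // The lambda captures `resource`, so Spark has to serialize it to ship the task;
    // the attempt is wrapped in Try so the failure can be inspected
    val result = Try(categories.foreach(c => println(resource.describe(c))))

    spark.stop()
    result
  }
}
```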

Assuming the number of categories is limited, so that the list of categories fits comfortably in your driver's memory, you should collect the categories to the driver and iterate over that local collection using foreach:

val categoriesRdd: RDD[String] = df.select(CATEGORY).distinct().rdd.map(r => r(0).toString)
val categories: Seq[String] = categoriesRdd.collect()
categories.foreach(executePipeline)

Another improvement would be reusing the dataframe that you loaded instead of performing another query, using a filter for each category:

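import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

def executePipeline(singleCategoryDf: DataFrame): Unit = { /* ... */ }

categories.foreach { cat =>
  // Filter the already-loaded DataFrame down to one category
  val filtered = df.filter(col(CATEGORY) === cat)
  executePipeline(filtered)
}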

NOTE: to make sure that reusing df doesn't reload it on every execution, call cache() on it before collecting the categories.
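
Putting these pieces together, a sketch under some assumptions: a local SparkSession, an in-memory DataFrame in place of the JDBC table, and a placeholder executePipeline that only counts rows (the real pipeline would do the machine learning). The object name PerCategoryPipelineSketch is hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object PerCategoryPipelineSketch {
  val CATEGORY = "Category"

  // Placeholder pipeline: just counts the rows of one category
  def executePipeline(singleCategoryDf: DataFrame): Long =
    singleCategoryDf.count()

  def run(): Map[String, Long] = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("per-category-pipeline-sketch")
      .getOrCreate()
    import spark.implicits._

    // Assumption: in-memory rows replace the real JDBC read.
    // cache() keeps the data around so each filter below reuses it.
    val df = Seq(("A", 10.0, 1), ("A", 5.0, 2), ("B", 7.5, 3))
      .toDF(CATEGORY, "Amount", "Quantity")
      .cache()

    // Collect the distinct categories to the driver (assumed to be a short list)
    val categories: Seq[String] =
      df.select(CATEGORY).distinct().as[String].collect().toSeq

    // Driver-side loop: each iteration launches its own Spark job on the cached data
    val counts = categories
      .map(cat => cat -> executePipeline(df.filter(col(CATEGORY) === cat)))
      .toMap

    df.unpersist()
    spark.stop()
    counts
  }
}
```

Because the loop runs on the driver, executePipeline is free to use DataFrame operations, which is not possible from inside an RDD foreach.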