Convert from DataFrame to JavaPairRDD

Question:

I'm trying to implement the LDA algorithm using Apache Spark's Java API. The method LDA().run() accepts a JavaPairRDD<Long, Vector> of documents. In Scala, I was able to create the RDD[(Long, Vector)] as follows:

val countVectors = cvModel.transform(filteredTokens)
    .select("docId", "features")
    .map { case Row(docId: Long, countVector: Vector) => (docId, countVector) }
    .cache()

And then pass it into LDA:

lda.run(countVectors)

But in the Java API, I have a CountVectorizerModel created with the following code:

CountVectorizerModel cvModel = new CountVectorizer()
        .setInputCol("filtered").setOutputCol("features")
        .setVocabSize(vocabSize).fit(filteredTokens);

The (docId, features) rows produced by cvModel.transform(filteredTokens) look like this:

(0,(22,[0,8,9,10,14,16,18],
[1.0,1.0,1.0,1.0,1.0,1.0,1.0]))
(1,(22,[0,1,2,3,4,5,6,7,11,12,13,15,17,19,20,21],
[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]))
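
For reference, each line above is a (docId, features) pair whose second element is an MLlib sparse vector in the form (size, [indices], [values]). As a minimal illustration (assuming Spark 1.5's org.apache.spark.mllib.linalg classes), the first row is equivalent to:

import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

// A sparse vector of size 22 with 1.0 at the listed positions,
// matching the docId = 0 row printed above.
Vector first = Vectors.sparse(22,
        new int[]{0, 8, 9, 10, 14, 16, 18},
        new double[]{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});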

What should I do to transform the output of cvModel into the JavaPairRDD<Long, Vector> countVectors? I have tried this:

JavaPairRDD<Long, Vector> countVectors = cvModel.transform(filteredTokens)
          .select("docId", "features").toJavaRDD()
          .mapToPair(new PairFunction<Row, Long, Vector>() {
            public Tuple2<Long, Vector> call(Row row) throws Exception {
                return new Tuple2<Long, Vector>(Long.parseLong(row.getString(0)), Vectors.dense(row.getDouble(1)));
            }
          }).cache();

But it does not work. I get an exception at:

Vectors.dense(row.getDouble(1))

So, if you have any idea how to convert the DataFrame produced by cvModel into a JavaPairRDD, please tell me.

I am using Spark and MLlib 1.5.1 with Java 8.

Any help is highly appreciated. Thanks. Here is the exception log from when I try to convert the DataFrame into a JavaPairRDD:

15/10/25 10:03:07 ERROR Executor: Exception in task 0.0 in stage 7.0 (TID 6)
java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.String
at org.apache.spark.sql.Row$class.getString(Row.scala:249)
at org.apache.spark.sql.catalyst.expressions.GenericRow.getString(rows.scala:191)
at UIT_LDA_ONLINE.LDAOnline$2.call(LDAOnline.java:88)
at UIT_LDA_ONLINE.LDAOnline$2.call(LDAOnline.java:1)
at org.apache.spark.api.java.JavaPairRDD$$anonfun$pairFunToScalaFun$1.apply(JavaPairRDD.scala:1030)
at org.apache.spark.api.java.JavaPairRDD$$anonfun$pairFunToScalaFun$1.apply(JavaPairRDD.scala:1030)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.storage.MemoryStore.unrollSafely(MemoryStore.scala:278)
at org.apache.spark.CacheManager.putInBlockManager(CacheManager.scala:171)
at org.apache.spark.CacheManager.getOrCompute(CacheManager.scala:78)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:262)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:297)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:264)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
at org.apache.spark.scheduler.Task.run(Task.scala:88)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
15/10/25 10:03:07 WARN TaskSetManager: Lost task 0.0 in stage 7.0 (TID 6, localhost): java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.String

Answer 1:

Now that we have the error stack, here is the error:

You are trying to get a String from the row, whereas your field is a Long, so you'll need to replace row.getString(0) with row.getLong(0) for starters.

Once you correct this, you'll run into other errors of the same type at other positions; I can't point them all out with the information given, but you'll be able to solve them if you apply the following:

The Row getters are specific to each field type, so you'll need to use the correct get method for each column.

If you are not sure which method to use, call the printSchema method on your DataFrame to check the type of each field, and then apply the type mappings described in the official documentation.
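
Putting this together, here is a sketch of what the corrected mapping could look like. It assumes docId was stored as a long, and that on Spark 1.5 the features column produced by CountVectorizer holds an org.apache.spark.mllib.linalg.Vector, so it can be cast directly instead of being rebuilt with Vectors.dense:

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Row;
import scala.Tuple2;

// Verify the column types first if in doubt:
// cvModel.transform(filteredTokens).printSchema();

JavaPairRDD<Long, Vector> countVectors = cvModel.transform(filteredTokens)
        .select("docId", "features")
        .toJavaRDD()
        .mapToPair(new PairFunction<Row, Long, Vector>() {
            public Tuple2<Long, Vector> call(Row row) throws Exception {
                // docId is a long in the schema, so use getLong, not getString
                Long docId = row.getLong(0);
                // "features" already holds a Vector; cast it instead of
                // calling row.getDouble(1), which fails on a vector column
                Vector features = (Vector) row.get(1);
                return new Tuple2<Long, Vector>(docId, features);
            }
        }).cache();

After that, lda.run(countVectors) should accept the pair RDD just as in the Scala version.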