I am new to Scala and I want to convert a DataFrame to an RDD — specifically, to convert the label and features columns into an RDD[LabeledPoint] as input for MLlib. But I can't find a way to deal with the WrappedArray.
scala> test.printSchema
root
|-- user_id: long (nullable = true)
|-- brand_store_sn: string (nullable = true)
|-- label: integer (nullable = true)
|-- money_score: double (nullable = true)
|-- normal_score: double (nullable = true)
|-- action_score: double (nullable = true)
|-- features: array (nullable = true)
| |-- element: string (containsNull = true)
|-- flag: string (nullable = true)
|-- dt: string (nullable = true)
scala> test.head
res21: org.apache.spark.sql.Row = [2533,10005072,1,2.0,1.0,1.0,WrappedArray(["d90_pv_1sec:1.4471580313422192", "d3_pv_1sec:0.9030899869919435", "d7_pv_1sec:0.9030899869919435", "d30_pv_1sec:1.414973347970818", "d90_pv_week_decay:1.4235871662780681", "d1_pv_1sec:0.9030899869919435", "d120_pv_1sec:1.4471580313422192"]),user_positive,20161130]
First - since LabeledPoint expects a Vector of Doubles, I'm assuming you also want to split each element in every features array on the colon (:) and treat the right-hand side as the double, e.g. "d90_pv_1sec:1.4471580313422192" becomes 1.4471580313422192. If so - here's the transformation:
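A minimal sketch of that transformation. The parsing helper (parseFeature, a name I made up) is plain Scala; the Spark-specific part is shown commented out, since it assumes the Spark 1.x MLlib API (Vectors, LabeledPoint) on the classpath:

```scala
// Hypothetical helper: take a "name:value" string and return the Double
// to the right of the colon.
def parseFeature(s: String): Double =
  s.split(":")(1).toDouble

// With Spark available (assumed, not runnable standalone), the mapping from
// the DataFrame to RDD[LabeledPoint] would look roughly like:
//
//   import org.apache.spark.mllib.linalg.Vectors
//   import org.apache.spark.mllib.regression.LabeledPoint
//
//   val labeled = test.rdd.map { row =>
//     val label    = row.getAs[Int]("label").toDouble
//     val features = row.getAs[Seq[String]]("features").map(parseFeature)
//     LabeledPoint(label, Vectors.dense(features.toArray))
//   }

println(parseFeature("d90_pv_1sec:1.4471580313422192"))
```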
EDIT: per clarification, now assuming each item in the input array contains the expected index in a resulting sparse vector:
The modified code would be:
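A sketch of the sparse-vector variant. The index is recovered by stripping the d prefix and _pv_1sec suffix (as in the NOTE); parseIndexed is my own helper name, and the Spark calls are commented out since they need Spark on the classpath:

```scala
// Hypothetical helper: "d90_pv_1sec:1.447..." -> (90, 1.447...)
def parseIndexed(s: String): (Int, Double) = {
  val Array(name, value) = s.split(":")
  (name.replaceAll("d|_pv_1sec", "").toInt, value.toDouble)
}

// With Spark available (assumed):
//
//   import org.apache.spark.mllib.linalg.Vectors
//   import org.apache.spark.mllib.regression.LabeledPoint
//
//   val labeled = test.rdd.map { row =>
//     val pairs = row.getAs[Seq[String]]("features").map(parseIndexed)
//     val size  = pairs.map(_._1).max + 1  // or a known fixed dimensionality
//     LabeledPoint(row.getAs[Int]("label").toDouble, Vectors.sparse(size, pairs))
//   }

println(parseIndexed("d90_pv_1sec:1.4471580313422192"))
```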
NOTE: Using s.replaceAll("d|_pv_1sec", "") might be a bit slow, as it compiles a regular expression for each item separately. If that's the case, it can be replaced by the faster (yet uglier) s.replace("d", "").replace("_pv_1sec", ""), which doesn't use regular expressions.
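A quick check (not a benchmark) that the two forms produce the same indices on the _pv_1sec feature names from the example row:

```scala
val samples = Seq("d90_pv_1sec:1.4471580313422192", "d7_pv_1sec:0.9030899869919435")

// Regex-based: compiles "d|_pv_1sec" for each item.
val viaRegex = samples.map(_.split(":")(0).replaceAll("d|_pv_1sec", ""))
// Plain-string version: two literal replacements, no regex.
val viaPlain = samples.map(_.split(":")(0).replace("d", "").replace("_pv_1sec", ""))

println(viaRegex)
println(viaPlain)
```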