Spark - Sort Double values in an RDD and ignore NaN

Posted 2019-06-11 03:21

Question:

I want to sort the Double values in an RDD, and I want my sort function to ignore the Double.NaN values.

The Double.NaN values should end up either at the bottom or at the top of the sorted RDD.

I was not able to achieve this with sortBy. In both sort directions the NaN lands in the middle of the output:

scala> res13.sortBy(r => r, ascending = true)
res21: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[10] at sortBy at <console>:26

scala> res21.collect.foreach(println)
0.656
0.99
0.998
1.0
NaN
5.6
7.0

scala> res13.sortBy(r => r, ascending = false)
res23: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[15] at sortBy at <console>:26

scala> res23.collect.foreach(println)
7.0
5.6
NaN
1.0
0.998
0.99
0.656

My expected result is either

scala> res23.collect.foreach(println)
7.0
5.6
1.0
0.998
0.99
0.656
NaN

or

scala> res21.collect.foreach(println)
NaN
0.656
0.99
0.998
1.0
5.6
7.0

Answer 1:

Following up on my comment, you can replace NaN with a sentinel value inside the sort key:

scala> val a = sc.parallelize(Array(0.656, 0.99, 0.998, 1.0, Double.NaN, 5.6, 7.0))
a: org.apache.spark.rdd.RDD[Double] = ParallelCollectionRDD[0] at parallelize at <console>:24

scala> a.sortBy(r => r, ascending = false).collect
res2: Array[Double] = Array(7.0, 5.6, NaN, 1.0, 0.998, 0.99, 0.656)

scala> a.sortBy(r => if (r.isNaN) Double.MinValue else r, ascending = false).collect
res3: Array[Double] = Array(7.0, 5.6, 1.0, 0.998, 0.99, 0.656, NaN)

scala> a.sortBy(r => if (r.isNaN) Double.MaxValue else r, ascending = false).collect
res4: Array[Double] = Array(NaN, 7.0, 5.6, 1.0, 0.998, 0.99, 0.656)

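All four placements (NaN first or last, ascending or descending) come down to choosing the right sentinel: NaN ends up last when it is replaced by the extreme that the chosen direction reaches last (Double.MaxValue ascending, Double.MinValue descending), and first otherwise. Below is a minimal sketch of that rule as one curried key function. `nanSortKey` is a hypothetical name, and the check runs on a local Seq, with `Ordering[Double].reverse` standing in for `ascending = false`, so it needs no Spark.

```scala
object NanSentinelKey {
  // Hypothetical helper: a sortBy key that replaces NaN with a sentinel.
  // If the sentinel is the extreme the sort direction reaches last
  // (MaxValue when ascending, MinValue when descending), NaN ends up last;
  // otherwise it ends up first.
  def nanSortKey(ascending: Boolean, nanLast: Boolean)(r: Double): Double =
    if (!r.isNaN) r
    else if (ascending == nanLast) Double.MaxValue
    else Double.MinValue

  def main(args: Array[String]): Unit = {
    val data = Seq(0.656, 0.99, 0.998, 1.0, Double.NaN, 5.6, 7.0)
    // Local stand-in for a descending RDD sortBy with NaN at the bottom.
    val desc = data.sortBy(nanSortKey(ascending = false, nanLast = true))(Ordering[Double].reverse)
    // Local stand-in for an ascending RDD sortBy with NaN at the top.
    val asc = data.sortBy(nanSortKey(ascending = true, nanLast = false))
    println(desc.mkString(", ")) // 7.0, 5.6, 1.0, 0.998, 0.99, 0.656, NaN
    println(asc.mkString(", "))  // NaN, 0.656, 0.99, 0.998, 1.0, 5.6, 7.0
  }
}
```

In the Spark shell the same key would be passed as, e.g., `a.sortBy(nanSortKey(ascending = false, nanLast = true), ascending = false)`. The direction has to be stated twice, which is the cost of keeping the sentinel choice in one place.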

Answer 2:

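To add to @user3685285's answer, here are the sentinel key functions as named methods, one per sort direction, applied to a larger RDD (res0):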

scala> def sortAscending(r: Double): Double = { if (r.isNaN) Double.MaxValue else r }
sortAscending: (r: Double)Double

scala> def sortDescending(r: Double): Double = {if (r.isNaN) Double.MinValue else r }
sortDescending: (r: Double)Double

scala> res0.sortBy(sortDescending, ascending=false)
res7: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[20] at sortBy at <console>:28

scala> res7.collect.foreach(println)
99.9
34.2
10.98
7.0
6.0
5.0
2.0
0.56
0.01
0.0
NaN
NaN

scala> res0.sortBy(sortAscending, ascending=true)
res9: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[25] at sortBy at <console>:28

scala> res9.collect.foreach(println)
0.0
0.01
0.56
2.0
5.0
6.0
7.0
10.98
34.2
99.9
NaN
NaN
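One caveat that applies to the sentinel approach in both answers: the sentinel is an ordinary Double. A genuine Double.MaxValue (or Double.MinValue) in the data ties with NaN, and Double.PositiveInfinity / Double.NegativeInfinity still sort beyond the sentinel, so, for example, an input containing -Infinity would come after NaN in the descending output. A sketch of an alternative that avoids sentinels, assuming a tuple sort key is acceptable: put a NaN flag first so every NaN compares after every number regardless of value. The function names are hypothetical, and the check again uses a local Seq in place of an RDD.

```scala
object NanFlagKeys {
  // Ascending: numbers get (false, x), NaN gets (true, NaN). Since
  // false < true, every number precedes every NaN; the Double component
  // only orders values within the same group.
  def nanLastAscending(r: Double): (Boolean, Double) = (r.isNaN, r)

  // Descending: the flag is inverted, so after the order is reversed the
  // numbers (true) come first and NaN (false) comes last.
  def nanLastDescending(r: Double): (Boolean, Double) = (!r.isNaN, r)

  def main(args: Array[String]): Unit = {
    val data = Seq(2.0, Double.NaN, Double.NegativeInfinity, 7.0, Double.PositiveInfinity, 0.01)
    val asc = data.sortBy(nanLastAscending)
    // Reversed tuple ordering stands in for ascending = false.
    val desc = data.sortBy(nanLastDescending)(Ordering[(Boolean, Double)].reverse)
    println(asc.mkString(", "))  // -Infinity, 0.01, 2.0, 7.0, Infinity, NaN
    println(desc.mkString(", ")) // Infinity, 7.0, 2.0, 0.01, -Infinity, NaN
  }
}
```

Passing these to an RDD as `res0.sortBy(nanLastAscending)` or `res0.sortBy(nanLastDescending, ascending = false)` should also work. This is an assumption based on sortBy accepting any key type with an implicit Ordering and ClassTag, which a (Boolean, Double) tuple has.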