I want to sort the Double values in an RDD, and I want my sort function to ignore the Double.NaN values when ordering: they should appear either at the bottom or at the top of the sorted RDD.
I was not able to achieve this using sortBy:
scala> res13.sortBy(r => r, ascending = true)
res21: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[10] at sortBy at <console>:26
scala> res21.collect.foreach(println)
0.656
0.99
0.998
1.0
NaN
5.6
7.0
scala> res13.sortBy(r => r, ascending = false)
res23: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[15] at sortBy at <console>:26
scala> res23.collect.foreach(println)
7.0
5.6
NaN
1.0
0.998
0.99
0.656
My expected result is either
scala> res23.collect.foreach(println)
7.0
5.6
1.0
0.998
0.99
0.656
NaN
or
scala> res21.collect.foreach(println)
NaN
0.656
0.99
0.998
1.0
5.6
7.0
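A likely contributor to the unstable NaN position is that IEEE 754 comparisons involving NaN always return false, while java.lang.Double.compare defines a total order in which NaN is the largest value. Code that mixes the two views can place NaN inconsistently. This is an assumption and has not been checked against Spark's sort internals. The snippet below only demonstrates the two behaviours; the object name NanComparisons is hypothetical.

```scala
object NanComparisons {
  def main(args: Array[String]): Unit = {
    val nan = Double.NaN
    // IEEE 754 operators: every ordering comparison with NaN is false,
    // and NaN is not even equal to itself.
    println(nan < 1.0)
    println(nan > 1.0)
    println(nan == nan)
    // java.lang.Double.compare: NaN is equal to itself and greater than
    // every other Double, including +Infinity.
    println(java.lang.Double.compare(nan, Double.PositiveInfinity) > 0)
    println(java.lang.Double.compare(nan, nan) == 0)
  }
}
```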
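Following up on what I said in the comment: replace NaN with a sentinel value inside the key function you pass to sortBy. With ascending = false, mapping NaN to Double.MinValue (the most negative finite Double in Scala) sends it to the bottom, and mapping it to Double.MaxValue sends it to the top. You can try this: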
scala> val a = sc.parallelize(Array(0.656, 0.99, 0.998, 1.0, Double.NaN, 5.6, 7.0))
a: org.apache.spark.rdd.RDD[Double] = ParallelCollectionRDD[0] at parallelize at <console>:24
scala> a.sortBy(r => r, ascending = false).collect
res2: Array[Double] = Array(7.0, 5.6, NaN, 1.0, 0.998, 0.99, 0.656)
scala> a.sortBy(r => if (r.isNaN) Double.MinValue else r, ascending = false).collect
res3: Array[Double] = Array(7.0, 5.6, 1.0, 0.998, 0.99, 0.656, NaN)
scala> a.sortBy(r => if (r.isNaN) Double.MaxValue else r, ascending = false).collect
res4: Array[Double] = Array(NaN, 7.0, 5.6, 1.0, 0.998, 0.99, 0.656)
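One caveat with sentinels: if the data can itself contain Double.MinValue, Double.MaxValue or an infinity, NaN either ties with those values or lands on the wrong side of them (for example, -Infinity is smaller than Double.MinValue). A sketch that avoids sentinels, assuming a composite key is acceptable, sorts on a (Boolean, Double) pair whose first component flags NaN. It runs on a plain Scala Seq so no Spark session is needed. The names NanSortKeys, totalDouble, ascendingNanLast and descendingNanLast are hypothetical. The explicit ordering built on java.lang.Double.compare is used so the code does not depend on the version-specific implicit Double ordering.

```scala
object NanSortKeys {
  // Total ordering on Double via java.lang.Double.compare (NaN is largest).
  // Defined explicitly so it behaves the same on Scala 2.12, 2.13 and 3.
  val totalDouble: Ordering[Double] = new Ordering[Double] {
    def compare(x: Double, y: Double): Int = java.lang.Double.compare(x, y)
  }

  // Order keys by the NaN flag first (false < true), then by the value.
  val keyOrd: Ordering[(Boolean, Double)] =
    Ordering.Tuple2[Boolean, Double](Ordering.Boolean, totalDouble)

  // Ascending with NaN last: NaN gets flag = true, so it follows every number.
  def ascendingNanLast(xs: Seq[Double]): Seq[Double] =
    xs.sortBy(r => (r.isNaN, r))(keyOrd)

  // Descending with NaN last: sort ascending on the negated value,
  // keeping the same NaN flag in front.
  def descendingNanLast(xs: Seq[Double]): Seq[Double] =
    xs.sortBy(r => (r.isNaN, -r))(keyOrd)

  def main(args: Array[String]): Unit = {
    val data = Seq(0.656, 0.99, 0.998, 1.0, Double.NaN, 5.6, 7.0, Double.NegativeInfinity)
    println(ascendingNanLast(data).mkString(", "))
    println(descendingNanLast(data).mkString(", "))
  }
}
```

With Spark, the same key functions could presumably be passed straight to sortBy, e.g. a.sortBy(r => (r.isNaN, -r)) for a descending sort with NaN last. That form relies on the implicit tuple ordering and was not run here.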
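To add on to @user3685285's answer, you can pull the key logic into named functions: for an ascending sort, map NaN to Double.MaxValue; for a descending sort, map it to Double.MinValue. Either way, NaN ends up at the bottom: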
scala> def sortAscending(r: Double): Double = { if (r.isNaN) Double.MaxValue else r }
sortAscending: (r: Double)Double
scala> def sortDescending(r: Double): Double = {if (r.isNaN) Double.MinValue else r }
sortDescending: (r: Double)Double
scala> res0.sortBy(sortDescending, ascending=false)
res7: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[20] at sortBy at <console>:28
scala> res7.collect.foreach(println)
99.9
34.2
10.98
7.0
6.0
5.0
2.0
0.56
0.01
0.0
NaN
NaN
scala> res0.sortBy(sortAscending, ascending=true)
res9: org.apache.spark.rdd.RDD[Double] = MapPartitionsRDD[25] at sortBy at <console>:28
scala> res9.collect.foreach(println)
0.0
0.01
0.56
2.0
5.0
6.0
7.0
10.98
34.2
99.9
NaN
NaN
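Another option, sketched here on a local Seq rather than an RDD, is to keep NaN out of the comparison altogether: split the values with partition, sort only the real numbers in the requested direction, and append the NaNs. SplitNanSort and sortNanLast are hypothetical names. On an RDD the analogous (untested) form would be a.filter(!_.isNaN).sortBy(r => r, ascending = false) ++ a.filter(_.isNaN), which assumes that a union keeps the first RDD's partitions ahead of the second's when collected.

```scala
object SplitNanSort {
  // Sort the non-NaN values in the requested direction, then append every NaN.
  def sortNanLast(xs: Seq[Double], ascending: Boolean): Seq[Double] = {
    // partition returns (elements matching the predicate, the rest).
    val (nans, nums) = xs.partition(_.isNaN)
    // Total ordering via java.lang.Double.compare; no NaN reaches it here.
    val ord: Ordering[Double] = new Ordering[Double] {
      def compare(x: Double, y: Double): Int = java.lang.Double.compare(x, y)
    }
    val sortedNums = nums.sorted(if (ascending) ord else ord.reverse)
    sortedNums ++ nans
  }

  def main(args: Array[String]): Unit = {
    val data = Seq(99.9, Double.NaN, 0.0, 34.2, 2.0, Double.NaN, 0.56)
    println(sortNanLast(data, ascending = false).mkString(", "))
    println(sortNanLast(data, ascending = true).mkString(", "))
  }
}
```

The design choice here is that NaN never participates in any comparison, so no sentinel can collide with real data; the cost on an RDD would be an extra filter pass over the data.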