I have created an RDD in which each element is a key-value pair, the key being a DenseVector and the value being an int, e.g.
[(DenseVector([3,4]),10), (DenseVector([3,4]),20)]
Now I want to group by the key k1: DenseVector([3,4]). I expect the values of the key k1, which are 10 and 20, to be grouped together.
. But the result I get is
[(DenseVector([3,4]), [10]), (DenseVector([3,4]), [20])]
instead of
[(DenseVector([3,4]), [10,20])]
Please let me know if I am missing something.
The code for this is:
# simplified version of the code
# rdd1 is an RDD containing [(DenseVector([3,4]),10), (DenseVector([3,4]),20)]
grouped = rdd1.groupByKey().map(lambda x: (x[0], list(x[1])))
print(grouped.collect())
Well, that's a tricky question and the short answer is you can't. To understand why, you'll have to dig deeper into the DenseVector implementation. DenseVector is simply a wrapper around a NumPy float64 ndarray:
>>> dv1 = DenseVector([3.0, 4.0])
>>> type(dv1.array)
<type 'numpy.ndarray'>
>>> dv1.array.dtype
dtype('float64')
Since NumPy ndarrays, unlike DenseVector, are mutable, they cannot be hashed in a meaningful way, although, interestingly enough, they do provide a __hash__ method. There is an interesting question which covers this issue (see: numpy ndarray hashability).
>>> dv1.array.__hash__() is None
False
>>> hash(dv1.array)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'numpy.ndarray'
DenseVector inherits its __hash__ method from object, and that hash is simply based on the id (the memory address of a given instance):
>>> id(dv1) / 16 == hash(dv1)
True
Unfortunately it means that two DenseVectors
with the same content have different hashes:
>>> dv2 = DenseVector([3.0, 4.0])
>>> hash(dv1) == hash(dv2)
False
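Note that equality itself is content-based (recent PySpark versions define DenseVector.__eq__ in terms of the wrapped arrays), so two keys can compare equal yet still land in different groups; a quick sanity check:
>>> dv1 == dv2
True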
What can you do? The simplest thing is to use an immutable data structure which provides a consistent hash implementation, for example a tuple:
rdd.groupBy(lambda kv: tuple(kv[0]))
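For completeness, here is a minimal end-to-end sketch of that workaround (the name grouped is mine, and an existing SparkContext sc is assumed); it groups by a tuple of the vector's values and rebuilds a DenseVector per group:
from pyspark.mllib.linalg import DenseVector

rdd1 = sc.parallelize([(DenseVector([3, 4]), 10), (DenseVector([3, 4]), 20)])
grouped = (rdd1
           .map(lambda kv: (tuple(kv[0]), kv[1]))  # tuples hash by content
           .groupByKey()
           .map(lambda kv: (DenseVector(kv[0]), list(kv[1]))))
print(grouped.collect())  # e.g. [(DenseVector([3.0, 4.0]), [10, 20])]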
Note: In practice, using arrays as keys is most likely a bad idea. With a large number of elements, the hashing process can be far too expensive to be useful. Still, if you really need something like this, Scala seems to work just fine, because its Vector implementations provide content-based equals and hashCode:
import org.apache.spark.mllib.linalg.Vectors
val rdd = sc.parallelize(
(Vectors.dense(3, 4), 10) :: (Vectors.dense(3, 4), 20) :: Nil)
rdd.groupByKey.collect
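Since these vectors hash by content, this should collect to a single grouped entry, roughly Array(([3.0,4.0],CompactBuffer(10, 20))) (groupByKey materializes each group as a CompactBuffer).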