Spark KMeans clustering: get the number of samples per cluster

Posted 2019-02-18 07:42

I am using Spark MLlib for k-means clustering. I have a set of vectors from which I want to determine the most likely cluster center, so I will run k-means training on this set and select the cluster with the highest number of vectors assigned to it.

Therefore I need to know the number of vectors assigned to each cluster after training (i.e. after KMeans.run(...)). But I cannot find a way to retrieve this information from the resulting KMeansModel. I probably need to run predict on all the training vectors and count the label that appears most often.

Is there another way to do this?

Thank you

1 Answer
地球回转人心会变
Answered 2019-02-18 08:34

You are right: this information is not provided by the model, and you have to run predict. Here is an example of doing so in a parallelized way (Spark v. 1.5.1):

 from pyspark.mllib.clustering import KMeans
 from numpy import array

 # 5 two-dimensional datapoints
 data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0, 10.0, 9.0]).reshape(5, 2)
 data
 # array([[  0.,   0.],
 #        [  1.,   1.],
 #        [  9.,   8.],
 #        [  8.,   9.],
 #        [ 10.,   9.]])

 k = 2  # number of clusters
 # sc is the SparkContext available in the pyspark shell
 model = KMeans.train(
     sc.parallelize(data), k, maxIterations=10, runs=30,
     initializationMode="random", seed=50, initializationSteps=5, epsilon=1e-4)

 # predict the cluster index of each datapoint, in parallel (on an RDD)
 cluster_ind = model.predict(sc.parallelize(data))
 cluster_ind.collect()
 # [1, 1, 0, 0, 0]

cluster_ind is an RDD with the same cardinality as our initial data, and it shows which cluster each datapoint belongs to. So here we have two clusters: one with 3 datapoints (cluster 0) and one with 2 datapoints (cluster 1). Notice that we have run the prediction method in a parallel fashion (i.e. on an RDD); collect() is used here only for demonstration purposes, and it is not needed in a 'real' situation.

Now, we can get the cluster sizes with

 cluster_sizes = cluster_ind.countByValue().items()
 cluster_sizes
 # [(0, 3), (1, 2)]
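
As a side note: if you are on a newer Spark version and can use the DataFrame-based API (pyspark.ml, whose training summary was added in Spark 2.1), the fitted model exposes the cluster sizes directly, without a separate predict-and-count pass. Here is a minimal sketch, assuming a SparkSession named spark and reusing the data array from above:

 # hypothetical sketch for the DataFrame-based API (pyspark.ml, Spark 2.1+)
 from pyspark.ml.clustering import KMeans as MLKMeans
 from pyspark.ml.linalg import Vectors

 # wrap each datapoint in a one-column DataFrame of feature vectors
 df = spark.createDataFrame([(Vectors.dense(row),) for row in data], ["features"])
 ml_model = MLKMeans(k=2, seed=50).fit(df)
 ml_model.summary.clusterSizes  # e.g. [3, 2] -- sizes indexed by cluster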

From the cluster_sizes pairs above, we can get the index and size of the largest cluster:

 from operator import itemgetter
 # pick the (cluster index, size) pair with the largest size
 max(cluster_sizes, key=itemgetter(1))
 # (0, 3)

That is, our biggest cluster is cluster 0, with a size of 3 datapoints, which can easily be verified by inspecting the output of cluster_ind.collect() above.
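
Finally, since the original question is ultimately after the most likely cluster center itself, you can look up its coordinates through the model's clusterCenters attribute. A minimal sketch, continuing the session above (biggest_idx is just an illustrative name for the index we just found):

 # unpack the index and size of the biggest cluster
 biggest_idx, biggest_size = max(cluster_sizes, key=itemgetter(1))
 model.clusterCenters[biggest_idx]
 # e.g. array([ 9.        ,  8.66666667]) -- the mean of the 3 points in cluster 0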
