I want to perform cluster analysis using K-Means on the itemFactors produced by ALS. Although the itemFactors of an ALSModel is a DataFrame that contains the id and the features of each item, this data structure seems to be unsuitable for K-Means.
Here's the code for collaborative filtering using ALS:
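import org.apache.spark.ml.recommendation.ALS

// Configure ALS on the rating columns
val als = new ALS()
  .setRegParam(0.01)
  .setNonnegative(false)
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
// Fit on the training split and score the test split
val model = als.fit(training)
val predictions = model.transform(testing)
// Learned latent factors, one row per item
val item_factors = model.itemFactors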
The item_factors dataframe looks like this:
+---+-------------------------------------------------------------------------------------------------------------------------------------+
|id |features |
+---+-------------------------------------------------------------------------------------------------------------------------------------+
|10 |[-0.1317064, 0.07098049, -0.042259596, -0.28769347, 0.58783025, -0.33474237, 0.31248248, -0.34541374, 0.33257273, 0.06327486] |
|20 |[-0.0033912044, 0.31334892, -0.080896676, -0.75597364, -0.016326033, -0.34558973, 0.045129072, -0.38614395, -0.02269395, -0.16486467]|
|30 |[0.19784503, -0.313929, -0.67753965, -0.7700008, 0.08975326, -0.03427274, 0.49707127, 0.05604595, 0.078268416, 0.08767615] |
|40 |[0.29390565, -0.22765353, -0.9278744, -0.59953785, 0.184721, -0.061099682, 0.33711356, 0.094112396, 0.08261518, -0.30668002] |
|50 |[-0.4070981, -0.0013739555, -0.21247752, -0.3771588, 0.3029064, -0.3883846, 0.4752892, 0.30097932, 0.5130039, 0.2938855] |
|60 |[0.1413918, -0.074142076, -0.87392575, -0.07855377, -0.11006678, -0.44359666, 0.33419594, -0.16027139, -0.2440797, -0.1596081] |
|70 |[-0.26080364, -0.11437138, 0.046630252, -0.70999575, 0.014645281, -0.69176155, 0.05397229, -0.24038066, -0.429569, 0.5660369] |
|80 |[0.6104476, -0.35322133, -0.80230886, -0.5302148, -0.26538768, -0.25481275, 0.20784922, -0.10604211, 0.26007786, 0.47488773] |
|90 |[0.6976714, -0.5851011, -0.64844996, -0.82472694, 0.102610275, -0.45195442, 0.24074861, 0.2683314, 0.11396688, -0.52693856] |
|100|[-0.11564436, 0.21467225, -0.42873487, -0.54825515, 0.20628366, -0.28728506, 0.18303588, 0.11490151, -0.033433616, -0.08694091] |
|110|[-0.530162, 0.22694068, -0.30889827, -0.091455124, 0.52988344, -0.7247424, 0.029707031, 0.43658048, 0.21511139, -0.22376455] |
|120|[0.59780246, -0.3396686, -0.58882934, -0.11867501, -0.6055776, -0.82480395, -0.22715187, -0.4544479, 0.012708589, -0.22158282] |
|130|[0.9630984, -0.012603591, -0.37178686, -1.0995674, -0.57324636, -0.7460034, 1.2981551, 0.15384857, -1.0350431, -0.58156097] |
|140|[-0.1617866, 0.3927005, -0.26183906, -0.3666182, -0.015750444, -0.28372696, 0.3577147, -0.18155682, 0.22410324, -0.5632848] |
|150|[-0.20490485, 0.37170428, -0.47898963, 0.0686825, 0.31148073, -0.4663402, 0.2088939, -0.0071071014, 0.44748953, 0.0067634075] |
|160|[0.31892687, 0.30109385, -0.036033046, -0.58646286, 0.015361498, -0.5640331, 0.010378816, -0.52527076, -0.20914118, -0.07263985] |
|170|[0.13082151, -0.082676716, 0.15034986, -0.7333888, 0.14089121, -0.34780806, 0.51327425, -0.43825528, 0.2210635, -0.19778338] |
|180|[-0.45791233, -0.64516217, 0.3496911, -0.6879449, 0.11970334, -0.3473338, 0.30204558, -0.18284592, 0.5934964, 0.06711411] |
|190|[0.41464698, 0.04347724, -0.9297292, -1.2885705, -0.5567429, 0.2531382, 0.11184802, -0.46155334, -0.3385828, 0.789031] |
|200|[0.37707302, -0.023397477, -0.47769275, -0.99200153, -0.11546725, -0.125011, -0.07772487, -0.5624814, -0.026348682, -0.33438805] |
+---+-------------------------------------------------------------------------------------------------------------------------------------+
And here is the code for K-Means clustering:
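import org.apache.spark.ml.clustering.KMeans
val kmeans = new KMeans().setK(10).setSeed(1L)
// fit() is where the exception is thrown
val kmeans_model = kmeans.fit(item_factors)
// Renamed so it does not clash with the ALS `predictions` above
val cluster_predictions = kmeans_model.transform(item_factors)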
The error I get when the item_factors dataframe is fed into K-Means is shown below:
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Column features must be of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 but was actually ArrayType(FloatType,false).
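The message says that K-Means expects the features column to hold ML vectors, while itemFactors stores an array of floats. One possible way around this, as a minimal sketch only: convert the array column to a dense vector with a UDF before fitting. The names clusterItemFactors and toVector are hypothetical helpers introduced here for illustration, not part of Spark, and the sketch assumes the features column has the ArrayType(FloatType) shown in the error.

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical helper: clusters ALS item factors after converting the
// array<float> "features" column into the vector type KMeans requires.
def clusterItemFactors(itemFactors: DataFrame, k: Int = 10): DataFrame = {
  // Assumption: "features" is an array of floats, as in the error message.
  // Each element is widened to Double because ML vectors store doubles.
  val toVector = udf((xs: Seq[Float]) => Vectors.dense(xs.map(_.toDouble).toArray))

  // Overwrite "features" in place so KMeans picks it up by its default name.
  val itemVectors = itemFactors.withColumn("features", toVector(col("features")))

  val kmeans = new KMeans().setK(k).setSeed(1L)
  val kmeansModel = kmeans.fit(itemVectors)

  // Adds a "prediction" column holding the cluster index of each item.
  kmeansModel.transform(itemVectors)
}
```

With the code above in scope, calling clusterItemFactors(item_factors) should return the item factors together with a prediction column. Overwriting features keeps the K-Means defaults untouched; writing to a new column and calling setFeaturesCol on it would work just as well.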