I'm looking for guidance on the issue below.
Background:
- I'm trying to get a basic PySpark script working, based on this example.
- I deploy on a Google Cloud Dataproc cluster.
- The cornerstone of my code is the function "recommendProductsForUsers", documented here, which returns the top X products for every user in the model.
The issue
Training with ALS.train runs smoothly and scales well on GCP (easily more than 1 million customers).
However, applying the predictions, i.e. using the functions 'predictAll' or 'recommendProductsForUsers', does not scale at all. My script runs fine on a small dataset (<100 customers with <100 products), but I can't get it to run at a business-relevant size (e.g. >50k customers and >10k products).
The error I then get is below:
16/08/16 14:38:56 WARN org.apache.spark.scheduler.TaskSetManager: Lost task 22.0 in stage 411.0 (TID 15139, productrecommendation-high-w-2.c.main-nova-558.internal): java.lang.StackOverflowError
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
    at scala.collection.immutable.$colon$colon.readObject(List.scala:362)
    at sun.reflect.GeneratedMethodAccessor11.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1058)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1909)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
    at scala.collection.immutable.$colon$colon.readObject(List.scala:362)
I even went as far as provisioning a roughly 300 GB cluster (1 main node with 108 GB + 2 worker nodes with 108 GB RAM each) to try to get it to run; it works for 50k customers but not for anything larger.
My goal is a setup that can run for >800k customers.
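For a sense of scale, here is the rough arithmetic for the target size. It assumes 800k customers, 10k products (the lower bound of what I described above), a top 10 per customer, and 8 bytes per score. The 8-byte figure is my own assumption for illustration, and I have not verified how Spark actually blocks or streams this computation internally, so the last number is only an upper-bound intuition, not a measurement.

```python
# Back-of-the-envelope sizing for the target run (all inputs are assumptions
# taken from the numbers quoted in this post, not measurements).
n_customers = 800_000      # target: >800k customers
n_products = 10_000        # catalogue: >10k products
top_n = 10                 # recommendProductsForUsers(10)
bytes_per_score = 8        # assumed: one double per (customer, product) score

# Every customer has to be scored against every product to find a top N.
pairs_to_score = n_customers * n_products

# Only top_n rows per customer survive into the final DataFrame.
output_rows = n_customers * top_n

# Size of all scores if they were ever held at once (they may well not be).
raw_score_gb = pairs_to_score * bytes_per_score / 1e9

print(f"pairs to score : {pairs_to_score:,}")
print(f"output rows    : {output_rows:,}")
print(f"raw scores (GB): {raw_score_gb:.1f}")
```

So the final result is small (millions of rows); it is the intermediate scoring step that is large.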
Details
Code line where it fails
import pprint
from pyspark.sql.types import StructType, StructField, StringType, FloatType

# model is the result of ALS.train; sqlContext is created earlier in the script
predictions = model.recommendProductsForUsers(10).flatMap(lambda p: p[1]).map(lambda p: (str(p[0]), str(p[1]), float(p[2])))
pprint.pprint(predictions.take(10))
schema = StructType([StructField("customer", StringType(), True), StructField("sku", StringType(), True), StructField("prediction", FloatType(), True)])
dfToSave = sqlContext.createDataFrame(predictions, schema).dropDuplicates()
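To make clear what the flatMap/map chain above is meant to do, here is the same transformation applied to a single record in plain Python, without Spark. As I understand it, each element returned by recommendProductsForUsers is a (user, ratings) pair in which every rating has the fields user, product and rating; the namedtuple below is only a stand-in for that Rating type, and the values are made up.

```python
from collections import namedtuple

# Stand-in for pyspark.mllib.recommendation.Rating so this runs without Spark.
Rating = namedtuple("Rating", ["user", "product", "rating"])


def to_row(r):
    """Same conversion as the .map(...) step: (customer, sku, prediction)."""
    return (str(r[0]), str(r[1]), float(r[2]))


# One record shaped like an element of recommendProductsForUsers(2).
# The IDs and scores are invented for illustration.
record = (7, (Rating(7, 101, 4.5), Rating(7, 205, 3.0)))

# flatMap(lambda p: p[1]) keeps only the ratings; map(...) converts each one.
rows = [to_row(r) for r in record[1]]
print(rows)
```

Each customer therefore contributes at most 10 rows to the DataFrame, which is why I did not expect this step to be the heavy one.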
How do you suggest I proceed? I suspect that the 'merging' step at the end of my script (i.e. when I build dfToSave) causes the error; is there a way to bypass this and save the results part by part?
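To make "part by part" concrete, this is roughly the control flow I have in mind. It is plain Python with no Spark in it: score_batch and save_batch are hypothetical placeholders for "predict for these customers" and "write this part out", and I don't know whether batching like this would actually avoid the StackOverflowError.

```python
def batches(ids, size):
    """Yield consecutive slices of `ids`, each with at most `size` elements."""
    for start in range(0, len(ids), size):
        yield ids[start:start + size]


def run_in_parts(customer_ids, size, score_batch, save_batch):
    """Score and save one batch of customers at a time.

    score_batch(batch) and save_batch(index, rows) are hypothetical hooks;
    in my script they would wrap the Spark prediction and write calls.
    Returns the number of parts written.
    """
    parts = 0
    for index, batch in enumerate(batches(customer_ids, size)):
        save_batch(index, score_batch(batch))
        parts += 1
    return parts


# Dummy hooks so the sketch runs on its own.
saved = {}
n_parts = run_in_parts(
    customer_ids=list(range(10)),
    size=4,
    score_batch=lambda batch: [(str(c), "sku-0", 1.0) for c in batch],
    save_batch=lambda index, rows: saved.update({index: rows}),
)
print(n_parts, [len(rows) for rows in saved.values()])
```

With 10 customers and a batch size of 4 this writes three parts of 4, 4 and 2 rows. Is something along these lines a sensible direction, or is there a more idiomatic Spark way?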