Save Apache Spark MLlib model in Python [duplicate]

Posted 2020-03-03 05:36

Question:

I am trying to save a fitted model to a file in Spark. I have a Spark cluster that trains a RandomForest model, and I would like to save the fitted model and reuse it on another machine. Some posts I found on the web recommend Java serialization. I am doing the equivalent in Python, but it does not work. What is the trick?

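import pickle
from pyspark.mllib.tree import RandomForest

model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                    numTrees=nb_tree, featureSubsetStrategy="auto",
                                    impurity='variance', maxDepth=depth)
with open('model.ml', 'wb') as output:
    pickle.dump(model, output)  # this is the line that raises the error below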

I am getting this error:

TypeError: can't pickle lock objects

I am using Apache Spark 1.2.0.

Answer 1:

If you look at the source code, you'll see that RandomForestModel inherits from TreeEnsembleModel, which in turn inherits from the JavaSaveable class that implements the save() method. So you can save your model like this:

model.save(sc, file_path)  # sc is your SparkContext, file_path is where the model is written
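
Because the goal is to reuse the model on another machine, here is a minimal sketch of both sides of the round trip. It assumes a Spark version whose PySpark RandomForestModel provides save() and the matching RandomForestModel.load(sc, path) class method (the question mentions 1.2.0, so check the API docs for your release). The helpers save_model and load_random_forest are hypothetical names, and the path should point to storage that both machines can read, such as HDFS.

```python
# Minimal sketch: persist a fitted MLlib RandomForestModel with the built-in
# save()/load() methods instead of pickle. Helper names are hypothetical.


def save_model(sc, model, path):
    """Write a fitted MLlib model to `path` through the JVM.

    `path` should be on storage the other machine can read (e.g. HDFS).
    """
    model.save(sc, path)


def load_random_forest(sc, path):
    """Read back a RandomForestModel previously written by save_model()."""
    # Imported lazily so this module can be imported where pyspark is absent.
    from pyspark.mllib.tree import RandomForestModel

    return RandomForestModel.load(sc, path)
```

On the training cluster you would call save_model(sc, model, "hdfs:///models/rf") (a made-up path), and on the other machine, using its own SparkContext, load_random_forest(sc, "hdfs:///models/rf") returns a model whose predict() can be used as before.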

save() writes the model to file_path using the given SparkContext. You cannot (at least for now) use Python's native pickle for this. If you really want to, you'll need to implement the __getstate__ and __setstate__ methods yourself; see the pickle documentation for more information.
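
A minimal, generic sketch of the __getstate__/__setstate__ approach, using a hypothetical LockedCounter class rather than a Spark model: __getstate__ drops the unpicklable threading.Lock before pickling, and __setstate__ creates a fresh lock after unpickling. This only illustrates the pickle protocol. A PySpark MLlib model wraps an object that lives in the JVM, so a real implementation would also have to capture that state, which this sketch does not attempt.

```python
import pickle
import threading


class LockedCounter:
    """Hypothetical example class that holds an unpicklable lock."""

    def __init__(self, value=0):
        self.value = value
        self._lock = threading.Lock()  # pickle refuses lock objects

    def increment(self):
        with self._lock:
            self.value += 1

    def __getstate__(self):
        # Copy the instance dict and drop the lock before pickling.
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state):
        # Restore the picklable attributes, then recreate the lock.
        self.__dict__.update(state)
        self._lock = threading.Lock()


# A bare lock still cannot be pickled (the exact message varies by Python version).
try:
    pickle.dumps(threading.Lock())
except TypeError as exc:
    print("plain lock:", exc)

counter = LockedCounter(41)
counter.increment()
restored = pickle.loads(pickle.dumps(counter))
print(restored.value)  # prints 42
```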