How to Roll a Custom Estimator in PySpark mllib

Posted 2019-01-07 15:38

I am trying to build a simple custom Estimator in PySpark MLlib. I have seen here that it is possible to write a custom Transformer, but I am not sure how to do the same for an Estimator. I also don't understand what @keyword_only does or why I need so many setters and getters. Scikit-learn seems to have proper documentation for custom models (see here), but PySpark doesn't.

Pseudo code of an example model:

class NormalDeviation():
    def __init__(self, threshold=3):
        self.threshold = threshold

    def fit(self, x, y=None):
        self.model = {'mean': x.mean(), 'std': x.std()}

    def predict(self, x):
        return (x - self.model['mean']) > self.threshold * self.model['std']

    def decision_function(self, x):  # does ml-lib support this?
        pass

1 answer
爷的心禁止访问
#2 · 2019-01-07 16:07

Generally speaking, there is no documentation because, as of Spark 1.6 / 2.0, most of the related API is not intended to be public. This should change in Spark 2.1.0 (see SPARK-7146).

The API is relatively complex because it has to follow specific conventions to make a given Transformer or Estimator compatible with the Pipeline API. Some of these methods may be required for features like reading and writing or grid search. Others, like keyword_only, are just simple helpers and are not strictly required.
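To make the role of keyword_only more concrete, here is a minimal plain-Python sketch of the idea. It is a simplified stand-in, not PySpark's own source: the names keyword_only_sketch and Demo are hypothetical, and the only behavior assumed is that the decorator rejects positional arguments and remembers the keyword arguments the caller passed explicitly.

```python
from functools import wraps


def keyword_only_sketch(func):
    """Simplified stand-in for pyspark's keyword_only (not the library code)."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if args:
            raise TypeError("Method %s forces keyword arguments." % func.__name__)
        # Remember only the arguments the caller passed explicitly.
        self._input_kwargs = kwargs
        return func(self, **kwargs)
    return wrapper


class Demo:
    @keyword_only_sketch
    def __init__(self, threshold=3.0, inputCol=None):
        # Defaults never appear in _input_kwargs, so they can be told apart
        # from values the user actually supplied.
        self.explicit = dict(self._input_kwargs)


demo = Demo(threshold=1.0)
print(demo.explicit)  # {'threshold': 1.0}
```

Keeping explicit arguments separate from defaults is what lets a constructor forward only user-supplied values to the parameter setters.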

Assuming you have defined the following mix-ins for the mean parameter:

from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasPredictionCol
from pyspark.sql.functions import avg, stddev_samp


class HasMean(Params):

    mean = Param(Params._dummy(), "mean", "mean", 
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasMean, self).__init__()

    def setMean(self, value):
        return self._set(mean=value)

    def getMean(self):
        return self.getOrDefault(self.mean)

the standard deviation parameter:

class HasStandardDeviation(Params):

    stddev = Param(Params._dummy(), "stddev", "stddev", 
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasStandardDeviation, self).__init__()

    def setStddev(self, value):
        return self._set(stddev=value)

    def getStddev(self):
        return self.getOrDefault(self.stddev)

and the threshold:

class HasCenteredThreshold(Params):

    centered_threshold = Param(Params._dummy(),
            "centered_threshold", "centered_threshold",
            typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasCenteredThreshold, self).__init__()

    def setCenteredThreshold(self, value):
        return self._set(centered_threshold=value)

    def getCenteredThreshold(self):
        return self.getOrDefault(self.centered_threshold)

you could create a basic Estimator as follows:

class NormalDeviation(Estimator, HasInputCol, 
        HasPredictionCol, HasCenteredThreshold):

    def _fit(self, dataset):
        c = self.getInputCol()
        mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first()
        return (NormalDeviationModel()
            .setInputCol(c)
            .setMean(mu)
            .setStddev(sigma)
            .setCenteredThreshold(self.getCenteredThreshold())
            .setPredictionCol(self.getPredictionCol()))

class NormalDeviationModel(Model, HasInputCol, HasPredictionCol,
        HasMean, HasStandardDeviation, HasCenteredThreshold):

    def _transform(self, dataset):
        x = self.getInputCol()
        y = self.getPredictionCol()
        threshold = self.getCenteredThreshold()
        mu = self.getMean()
        sigma = self.getStddev()

        return dataset.withColumn(y, (dataset[x] - mu) > threshold * sigma)

Finally, it could be used as follows:

df = sc.parallelize([(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]).toDF(["id", "x"])

normal_deviation = NormalDeviation().setInputCol("x").setCenteredThreshold(1.0)
model = Pipeline(stages=[normal_deviation]).fit(df)

model.transform(df).show()
## +---+----+----------+
## | id|   x|prediction|
## +---+----+----------+
## |  1| 2.0|     false|
## |  2| 3.0|     false|
## |  3| 0.0|     false|
## |  4|99.0|      true|
## +---+----+----------+
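The result can be checked by hand without Spark. The sketch below mirrors the fit/transform split in plain Python using the standard statistics module. The class names PlainNormalDeviation and PlainNormalDeviationModel are hypothetical, and statistics.stdev is used because it computes the sample standard deviation, as stddev_samp does.

```python
from statistics import mean, stdev


class PlainNormalDeviation:
    """Plain-Python analogue of the estimator: fit() learns the statistics."""

    def __init__(self, centered_threshold):
        self.centered_threshold = centered_threshold

    def fit(self, xs):
        # Sample standard deviation (n - 1 in the denominator).
        return PlainNormalDeviationModel(mean(xs), stdev(xs),
                                         self.centered_threshold)


class PlainNormalDeviationModel:
    """Analogue of the fitted model: transform() applies the learned rule."""

    def __init__(self, mu, sigma, centered_threshold):
        self.mu = mu
        self.sigma = sigma
        self.centered_threshold = centered_threshold

    def transform(self, xs):
        # One-sided rule, as in _transform: flag values far above the mean.
        return [(x - self.mu) > self.centered_threshold * self.sigma for x in xs]


xs = [2.0, 3.0, 0.0, 99.0]
model = PlainNormalDeviation(centered_threshold=1.0).fit(xs)
print(model.mu, round(model.sigma, 2))  # 26.0 48.68
print(model.transform(xs))              # [False, False, False, True]
```

With a mean of 26.0 and a sample standard deviation of about 48.68, only 99.0 lies more than one standard deviation above the mean. The rule is one-sided, so 0.0 is not flagged even though it is below the mean.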