scikit-learn: applying an arbitrary function as part of a pipeline

Published 2020-07-17 06:03

Question:

I've just discovered the Pipeline feature of scikit-learn, and I find it very useful for testing different combinations of preprocessing steps before training my model.

A pipeline is a chain of objects that implement the fit and transform methods. Until now, whenever I wanted to add a new preprocessing step, I would write a class that inherits from sklearn.base.BaseEstimator. However, I suspect there must be a simpler method. Do I really need to wrap every function I want to apply in an estimator class?

Example:

import sklearn.base


class Categorizer(sklearn.base.BaseEstimator):
    """
    Converts given columns into pandas dtype 'category'.
    """

    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        # Stateless: there is nothing to learn from the data.
        return self

    def transform(self, X):
        for column in self.columns:
            X[column] = X[column].astype("category")
        return X
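For context, a minimal sketch of how such a class plugs into a pipeline (the DataFrame and column names here are made up for illustration):

```python
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline


class Categorizer(BaseEstimator):
    """Converts given columns into pandas dtype 'category'."""

    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        # Stateless: nothing to learn.
        return self

    def transform(self, X):
        for column in self.columns:
            X[column] = X[column].astype("category")
        return X


# Toy data, purely for illustration
df = pd.DataFrame({"color": ["red", "blue", "red"], "value": [1, 2, 3]})
pipe = Pipeline([("categorize", Categorizer(columns=["color"]))])
out = pipe.fit_transform(df)
print(out["color"].dtype)  # category
```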

Answer 1:

For a general solution (one that works for many other use cases, not just transformers but also simple models etc.), you can write your own decorator for state-free functions (functions that do not need fit), for example:

import sklearn.base


class TransformerWrapper(sklearn.base.BaseEstimator):

    def __init__(self, func):
        self._func = func

    def fit(self, *args, **kwargs):
        # State-free: nothing to fit.
        return self

    def transform(self, X, *args, **kwargs):
        return self._func(X, *args, **kwargs)

and now you can do

@TransformerWrapper
def foo(x):
  return x*2

which is equivalent to doing

def foo(x):
  return x*2

foo = TransformerWrapper(foo)

which is what sklearn.preprocessing.FunctionTransformer is doing under the hood.
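To see the equivalence in action, here is the wrapper from above exercised end to end (the function name and data are illustrative):

```python
import numpy as np
from sklearn.base import BaseEstimator


class TransformerWrapper(BaseEstimator):
    def __init__(self, func):
        self._func = func

    def fit(self, *args, **kwargs):
        # State-free: nothing to fit.
        return self

    def transform(self, X, *args, **kwargs):
        return self._func(X, *args, **kwargs)


@TransformerWrapper  # foo is now a transformer object, not a plain function
def foo(x):
    return x * 2


result = foo.fit(None).transform(np.array([1, 2, 3]))
print(result)  # [2 4 6]
```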

Personally I find decorating simpler, since it gives you a nice separation between your preprocessors and the rest of the code, but it is up to you which path to follow.

In fact, you should be able to decorate with sklearn's own class directly:

from sklearn.preprocessing import FunctionTransformer

@FunctionTransformer
def foo(x):
  return x*2

too.



Answer 2:

The sklearn.preprocessing.FunctionTransformer class can be used to instantiate a scikit-learn transformer (usable e.g. in a pipeline) from a user-provided function.
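A minimal sketch, using numpy's log1p as the user-provided function (the step name "log" is arbitrary):

```python
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

# np.log1p is state-free; FunctionTransformer makes it pipeline-compatible
pipe = Pipeline([("log", FunctionTransformer(np.log1p))])

# log1p(0) = 0 and log1p(e - 1) = 1
out = pipe.fit_transform(np.array([[0.0], [np.e - 1.0]]))
print(out)
```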



Answer 3:

I think it's worth mentioning that sklearn.preprocessing.FunctionTransformer has a validate parameter:

validate : bool, optional default=True

Indicate that the input X array should be checked before calling func. If validate is false, there will be no input validation. If it is true, then X will be converted to a 2-dimensional NumPy array or sparse matrix. If this conversion is not possible or X contains NaN or infinity, an exception is raised.

So if you are going to pass non-numerical features to FunctionTransformer, make sure you explicitly set validate=False; otherwise it will fail with the following exception:

ValueError: could not convert string to float: 'your non-numerical value'
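A sketch of the workaround, assuming a pandas DataFrame with a string column (the categorize function and data are illustrative):

```python
import pandas as pd
from sklearn.preprocessing import FunctionTransformer


def categorize(X):
    # Operates on a DataFrame with string columns,
    # so numeric input validation must be switched off
    return X.astype("category")


ft = FunctionTransformer(categorize, validate=False)
df = pd.DataFrame({"color": ["red", "blue", "red"]})
out = ft.fit_transform(df)
print(out["color"].dtype)  # category
```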