Scikit - Combining scale and grid search

Posted 2020-06-03 04:28

I am new to scikit-learn and have two small issues combining data scaling with grid search.

  1. Efficient scaler

When cross-validating with K folds, I would like the data scaler (preprocessing.StandardScaler(), for instance) to be fit only on the K-1 training folds each time the model is trained, and then applied to the remaining fold.

My impression is that the following code fits the scaler on the entire dataset, so I would like to modify it to behave as described above:

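    from sklearn import preprocessing, svm
    from sklearn.model_selection import GridSearchCV
    from sklearn.pipeline import make_pipeline

    classifier = svm.SVC(C=1)
    clf = make_pipeline(preprocessing.StandardScaler(), classifier)
    tuned_parameters = [{'C': [1, 10, 100, 1000]}]
    my_grid_search = GridSearchCV(clf, tuned_parameters, cv=5)

  2. Retrieve the fitted inner scaler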

When refit=True, the best parameter combination is refit on the entire dataset after the grid search. My understanding is that the whole pipeline is used again for this, so the scaler will be fit on the entire dataset. Ideally I would like to reuse that fitted scaler to scale my 'test' dataset. Is there a way to retrieve it directly from the GridSearchCV object?
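The per-fold procedure described in item 1 can be written out by hand, which makes it easy to compare with what a pipeline does inside cross-validation. This is a minimal sketch, assuming synthetic data from make_classification and a shuffled 5-fold KFold; manual_cv_scores is a hypothetical helper, not a scikit-learn function. It fits StandardScaler on the K-1 training folds only, transforms the held-out fold with those statistics, and then compares the fold scores with cross_val_score run on make_pipeline(StandardScaler(), SVC()).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic data, purely for illustration (assumption)
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
# Fixed random_state so every call to cv.split yields the same folds
cv = KFold(n_splits=5, shuffle=True, random_state=0)


def manual_cv_scores(X, y, cv, C=1.0):
    """Hypothetical helper: fit the scaler on the K-1 training folds only,
    then apply it to the held-out fold before scoring."""
    scores = []
    for train_idx, test_idx in cv.split(X, y):
        # Mean/std are computed from the training folds only
        scaler = StandardScaler().fit(X[train_idx])
        model = SVC(C=C).fit(scaler.transform(X[train_idx]), y[train_idx])
        # The held-out fold is only transformed with the training statistics
        X_held_out = scaler.transform(X[test_idx])
        scores.append(model.score(X_held_out, y[test_idx]))
    return np.array(scores)


manual = manual_cv_scores(X, y, cv)

# The same procedure expressed as a pipeline evaluated by cross_val_score
piped = cross_val_score(make_pipeline(StandardScaler(), SVC(C=1.0)), X, y, cv=cv)

print(manual)
print(piped)
print(np.allclose(manual, piped))
```

If the two score arrays match, the pipeline is already refitting the scaler on each training split, which is the per-fold behaviour item 1 asks for.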

1 answer
别忘想泡老子
Post #2 · 2020-06-03 05:29
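  1. GridSearchCV knows nothing about the Pipeline object; it treats the provided estimator as atomic, in the sense that it cannot pick out a particular stage (StandardScaler, for example) and fit different stages on different data. All GridSearchCV does is call the fit(X, y) method of the provided estimator, where X, y are the training folds of the current split. So every stage, the scaler included, is fit on the same K-1 training folds, and the held-out fold is only transformed with those fitted statistics when it is scored. Your code therefore already scales per fold the way you describe.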
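  2. Try this (make_pipeline names each step after its class name in lowercase, so the scaler step is "standardscaler"):

    best_pipeline = my_grid_search.best_estimator_
    best_scaler = best_pipeline["standardscaler"]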

  3. When you wrap your transformers/estimators in a Pipeline, you have to prefix each parameter name with the name of its step and a double underscore, e.g. tuned_parameters = [{'svc__C': [1, 10, 100, 1000]}]. With the plain 'C' key, GridSearchCV raises a ValueError for an invalid parameter. See these scikit-learn examples for more details: "Concatenating multiple feature extraction methods" and "Pipelining: chaining a PCA and a logistic regression".
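Putting points 2 and 3 together, an end-to-end run might look like the following. This is a minimal sketch, assuming synthetic data from make_classification and an illustrative train/test split (X_train, X_test, y_train, y_test are hypothetical names). It uses the prefixed 'svc__C' grid, retrieves the refit scaler from best_estimator_, checks that its mean_ equals the column means of all the data passed to fit (i.e. it was refit on the whole training set), and applies it to the test set.

```python
import numpy as np
from sklearn import preprocessing, svm
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import make_pipeline

# Synthetic data and split, purely for illustration (assumption)
X, y = make_classification(n_samples=300, n_features=5, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# make_pipeline names the steps "standardscaler" and "svc"
clf = make_pipeline(preprocessing.StandardScaler(), svm.SVC(C=1))
# Parameters of a pipeline step are addressed as "<step name>__<parameter>"
tuned_parameters = [{"svc__C": [1, 10, 100, 1000]}]
my_grid_search = GridSearchCV(clf, tuned_parameters, cv=5)  # refit=True by default
my_grid_search.fit(X_train, y_train)
print(my_grid_search.best_params_)

best_pipeline = my_grid_search.best_estimator_
best_scaler = best_pipeline["standardscaler"]  # same as named_steps["standardscaler"]

# After the refit, the scaler's statistics come from the whole of X_train
print(np.allclose(best_scaler.mean_, X_train.mean(axis=0)))

# Scale the test set with the refit scaler, then classify with the refit SVC
X_test_scaled = best_scaler.transform(X_test)
manual_pred = best_pipeline["svc"].predict(X_test_scaled)

# Equivalent one-liner: the best pipeline scales and classifies in one call
y_pred = my_grid_search.predict(X_test)
print((manual_pred == y_pred).all())
```

In practice my_grid_search.predict(X_test) is usually enough, since the best pipeline applies its own refit scaler before the SVC; extracting the scaler is mainly useful when the scaled data itself is needed.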

Anyway, read the GridSearchCV documentation; it may help you.
