Scaling of target causes Scikit-learn SVM regression to break down

Posted 2019-02-07 01:51

When training an SVM regression model, it is usually advisable to scale the input features before training.
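For reference, the usual way to scale the input features is to put a scaler in front of the SVR so it is fit on the training features only. This is a minimal sketch with hypothetical data (two features on very different ranges); the data and parameter values are illustrative assumptions, not taken from the question.

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

rng = np.random.RandomState(0)
# Hypothetical data: feature 0 lies in [0, 1], feature 1 lies in [0, 1000]
X = np.column_stack([rng.rand(50), 1000 * rng.rand(50)])
y = np.sin(3 * X[:, 0]) + X[:, 1] / 1000

# StandardScaler learns each feature's mean and std during fit and
# rescales the features before they reach the SVR
model = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=1.0))
model.fit(X, y)
pred = model.predict(X)
```

Note that this only rescales X; the targets y are passed to the SVR unchanged, which is exactly the situation the question is about.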

But how about scaling of the targets? Usually this is not considered necessary, and I do not see a good reason why it should be necessary.

However, in the scikit-learn SVM regression example (http://scikit-learn.org/stable/auto_examples/svm/plot_svm_regression.html), just introducing the line y=y/1000 before training makes the prediction break down to a constant value. Scaling the target variable before training solves the problem, but I do not understand why it is necessary.

What causes this problem?

import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt

# Generate sample data
X = np.sort(5 * np.random.rand(40, 1), axis=0)
y = np.sin(X).ravel()

# Add noise to targets
y[::5] += 3 * (0.5 - np.random.rand(8))

# Added line: this makes the prediction break down
y = y / 1000

# Fit regression model
svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1)
svr_lin = SVR(kernel='linear', C=1e3)
svr_poly = SVR(kernel='poly', C=1e3, degree=2)
y_rbf = svr_rbf.fit(X, y).predict(X)
y_lin = svr_lin.fit(X, y).predict(X)
y_poly = svr_poly.fit(X, y).predict(X)

# look at the results
plt.scatter(X, y, c='k', label='data')
plt.plot(X, y_rbf, c='g', label='RBF model')
plt.plot(X, y_lin, c='r', label='Linear model')
plt.plot(X, y_poly, c='b', label='Polynomial model')
plt.xlabel('data')
plt.ylabel('target')
plt.title('Support Vector Regression')
plt.legend()
plt.show()
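A minimal headless check of the same effect, as a sketch: it assumes a fixed random seed and looks only at the RBF model. The helper prediction_spread is hypothetical; it fits the question's RBF model and returns the range (max minus min) of its predictions plus the number of support vectors, so the original and the divided targets can be compared without plotting.

```python
import numpy as np
from sklearn.svm import SVR

rng = np.random.RandomState(0)  # fixed seed so the run is repeatable
X = np.sort(5 * rng.rand(40, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(8))

def prediction_spread(targets):
    """Fit the question's RBF model (default epsilon) and return
    (max - min of the predictions, number of support vectors)."""
    model = SVR(kernel='rbf', C=1e3, gamma=0.1)
    pred = model.fit(X, targets).predict(X)
    return np.ptp(pred), len(model.support_)

spread_orig, n_sv_orig = prediction_spread(y)
spread_scaled, n_sv_scaled = prediction_spread(y / 1000)
print("original targets:", spread_orig, n_sv_orig)
print("targets / 1000:  ", spread_scaled, n_sv_scaled)
```

A spread close to zero for the divided targets is the "constant prediction" described in the question.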

1 answer
劫难
Post #2 · 2019-02-07 02:43

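Support vector regression uses an epsilon-insensitive loss function: the loss is zero while the difference between the predicted value and the target stays within some threshold, and grows linearly once the difference exceeds it. Below the threshold, the prediction is considered "good enough". When you scale the targets down, every target ends up within the threshold of a single constant value, so a flat model no longer incurs any loss. Since the regularization term of the SVR objective penalizes non-flat functions, the flat model then becomes the optimum.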
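The loss described above can be written as L(r) = max(0, |r| - epsilon) for a residual r. This sketch (the function name eps_insensitive_loss is hypothetical; epsilon=0.1 matches the SVR default) evaluates the total loss of a flat model placed at the midpoint of the target range, once for the question's original targets and once for the targets divided by 1000, using the same synthetic data with a fixed seed.

```python
import numpy as np

def eps_insensitive_loss(residual, epsilon=0.1):
    """Epsilon-insensitive loss: zero inside the tube |r| <= epsilon, linear outside."""
    return np.maximum(0.0, np.abs(residual) - epsilon)

rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(40, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(8))

losses = {}
for scale in (1, 1000):
    targets = y / scale
    # A flat model at the midpoint of the target range
    const = (targets.max() + targets.min()) / 2
    losses[scale] = eps_insensitive_loss(targets - const).sum()
    print(scale, losses[scale])
```

After dividing by 1000 all targets lie within a few thousandths of each other, far inside the 0.1-wide tube, so the flat model's total loss is exactly zero and nothing pushes the solver away from it.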

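The threshold parameter is called epsilon in sklearn.svm.SVR. Its default value of 0.1 is much larger than the targets after y=y/1000, which are at most a few thousandths in magnitude; set it to a lower value for smaller targets. The math behind this is explained here.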
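Two possible fixes, sketched on the question's synthetic data with a fixed seed. Option 1 shrinks epsilon by the same factor as the targets; additionally dividing C and the solver tolerance tol by that factor turns it into the same optimization problem as the original C=1e3, epsilon=0.1 fit, only expressed in smaller units (this equivalence is our own rescaling argument, not something stated in the answer). Option 2 uses scikit-learn's TransformedTargetRegressor with a StandardScaler, one way to do the target scaling mentioned in the question, so that the default epsilon is measured in units of the target's standard deviation.

```python
import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(40, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(8))
scale = 1000
y_small = y / scale  # the rescaled targets from the question

# Option 1: shrink epsilon with the targets. Dividing C and tol by the same
# factor makes this equivalent to the original C=1e3, epsilon=0.1 problem.
svr_eps = SVR(kernel='rbf', gamma=0.1,
              C=1e3 / scale, epsilon=0.1 / scale, tol=1e-3 / scale)
pred_eps = svr_eps.fit(X, y_small).predict(X)

# Option 2: standardize the targets inside the estimator; predictions are
# transformed back to the original (small) units automatically.
svr_ttr = TransformedTargetRegressor(
    regressor=SVR(kernel='rbf', C=1e3, gamma=0.1),
    transformer=StandardScaler(),
)
pred_ttr = svr_ttr.fit(X, y_small).predict(X)
```

Option 2 has the advantage that it does not require knowing the target scale in advance when choosing epsilon.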
