Comparison of R and scikit-learn for a classification…

Posted 2019-04-10 15:56

I am fitting a logistic regression model as described in the book 'An Introduction to Statistical Learning with Applications in R' by James, Witten, Hastie and Tibshirani (2013).

More specifically, I am fitting the binary classification model to the 'Wage' dataset from the R package 'ISLR', as described in §7.8.1.

The predictor 'age', expanded as a degree-4 polynomial, is used to model the binary response wage > 250. Age is then plotted against the predicted probability that the response is 'True'.
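Written out with raw powers of age, the model has the following form (a sketch of the model's shape only; R's poly() actually uses an orthogonal polynomial basis, which spans the same set of degree-4 curves, so the fitted probabilities of an unpenalized fit are the same while the individual β_j differ):

$$\Pr(\text{wage} > 250 \mid \text{age}) = \frac{\exp(\beta_0 + \beta_1\,\text{age} + \beta_2\,\text{age}^2 + \beta_3\,\text{age}^3 + \beta_4\,\text{age}^4)}{1 + \exp(\beta_0 + \beta_1\,\text{age} + \beta_2\,\text{age}^2 + \beta_3\,\text{age}^3 + \beta_4\,\text{age}^4)}$$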

The model in R is fit as follows:

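library(ISLR)
attach(Wage)  # as in the book's lab, so that 'age' refers to Wage$age
fit=glm(I(wage>250)~poly(age,4),data=Wage,family=binomial)

agelims=range(age)
age.grid=seq(from=agelims[1],to=agelims[2])
# predictions are on the logit (link) scale; se=T also returns standard errors
preds=predict(fit,newdata=list(age=age.grid),se=T)
# inverse logit: convert log-odds to probabilities
pfit=exp(preds$fit)/(1+exp(preds$fit))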
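The last R line undoes the logit link: predict() on a glm returns values on the link (log-odds) scale by default, and se=T adds their standard errors. Below is a minimal Python sketch of the same transformation. The link-scale numbers are made up and stand in for preds$fit and preds$se.fit, and probability_bands() is a hypothetical helper that also builds pointwise bands as fit ± 2·se on the logit scale before transforming (an assumption; check the linked lab code for the exact band construction).

```python
import numpy as np

def inverse_logit(eta):
    """Map log-odds to probabilities; same value as exp(eta) / (1 + exp(eta))."""
    eta = np.asarray(eta, dtype=float)
    return 1.0 / (1.0 + np.exp(-eta))

def probability_bands(fit, se_fit, z=2.0):
    """Hypothetical helper: turn a link-scale fit and its standard errors into
    a probability curve plus pointwise bands (fit +/- z * se on the logit scale)."""
    fit = np.asarray(fit, dtype=float)
    se_fit = np.asarray(se_fit, dtype=float)
    pfit = inverse_logit(fit)
    lower = inverse_logit(fit - z * se_fit)
    upper = inverse_logit(fit + z * se_fit)
    return pfit, lower, upper

# Made-up link-scale values standing in for preds$fit and preds$se.fit
fit = np.array([-6.0, -3.5, -2.0, -3.0])
se_fit = np.array([1.2, 0.3, 0.2, 0.6])
pfit, lower, upper = probability_bands(fit, se_fit)
print(pfit, lower, upper)
```

Building the bands on the logit scale and transforming afterwards keeps them inside (0, 1).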

Complete code (author's site): http://www-bcf.usc.edu/~gareth/ISL/Chapter%207%20Lab.txt
The corresponding plot from the book (right-hand panel): http://www-bcf.usc.edu/~gareth/ISL/Chapter7/7.1.pdf

I tried to fit the same model to the same data in scikit-learn:

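import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import PolynomialFeatures

df = pd.read_csv('Wage.csv')  # hypothetical: the ISLR Wage data exported from R
poly = PolynomialFeatures(4)
X = poly.fit_transform(df[['age']].to_numpy())
y = (df.wage > 250).astype(int).to_numpy()
clf = LogisticRegression()
clf.fit(X, y)

# +1 so the grid includes the maximum age, like seq() in R; transform, do not refit
X_test = poly.transform(np.arange(df.age.min(), df.age.max() + 1).reshape(-1, 1))
prob = clf.predict_proba(X_test)[:, 1]  # column 1 = P(wage > 250)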

I then plotted the predicted probability of the 'True' class against age, but the curve looks quite different from the one in the book. (I am not referring to the confidence bands or the rug plot, only the probability curve.) Am I missing something?

1 Answer

chillily · answered 2019-04-10 16:51

After some more reading, I understand that scikit-learn's LogisticRegression fits a regularized model (an L2 penalty with C=1.0 by default), whereas glm in R fits an unpenalized maximum-likelihood model. The GLM implementation in statsmodels (Python) is unregularized and gives the same results as R.

http://statsmodels.sourceforge.net/stable/generated/statsmodels.genmod.generalized_linear_model.GLM.html#statsmodels.genmod.generalized_linear_model.GLM
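The effect of the penalty can be seen without either library. The sketch below is an illustration under stated assumptions, not either package's code: a hypothetical fit_logistic_newton() runs Newton-Raphson (equivalent to IRLS, the algorithm R's glm uses) on synthetic data, once without a penalty and once with an L2 penalty on the non-intercept coefficients. With lam = 1/C, the penalized objective should correspond to scikit-learn's default lbfgs setup, which leaves the intercept unpenalized.

```python
import numpy as np

def fit_logistic_newton(X, y, lam=0.0, n_iter=100, tol=1e-10):
    """Hypothetical helper: logistic regression by Newton-Raphson (same steps as IRLS).

    Column 0 of X must be the intercept (all ones).
    lam = 0.0 -> unpenalized maximum likelihood, what R's glm() estimates.
    lam > 0.0 -> adds (lam / 2) * ||beta[1:]||^2, an L2 penalty that leaves
    the intercept alone (assumed to match scikit-learn lbfgs with lam = 1 / C).
    """
    p = X.shape[1]
    penalty = lam * np.eye(p)
    penalty[0, 0] = 0.0  # never shrink the intercept
    beta = np.zeros(p)
    for _ in range(n_iter):
        mu = 1.0 / (1.0 + np.exp(-(X @ beta)))      # fitted probabilities
        grad = X.T @ (mu - y) + penalty @ beta       # gradient of the objective
        hess = X.T @ (X * (mu * (1.0 - mu))[:, None]) + penalty
        step = np.linalg.solve(hess, grad)
        beta = beta - step
        if np.max(np.abs(step)) < tol:
            break
    return beta

# Synthetic stand-in for the Wage data (NOT the real ISLR data).
rng = np.random.default_rng(0)
age = rng.uniform(18, 80, size=500)
z = (age - age.mean()) / age.std()             # standardize before taking powers
X = np.column_stack([z**k for k in range(5)])  # columns 1, z, z^2, z^3, z^4
true_logit = -2.0 + 0.8 * z - 0.6 * z**2
y = (rng.uniform(size=age.size) < 1.0 / (1.0 + np.exp(-true_logit))).astype(float)

beta_glm = fit_logistic_newton(X, y, lam=0.0)    # analogue of glm(..., family = binomial)
beta_ridge = fit_logistic_newton(X, y, lam=1.0)  # analogue of LogisticRegression(C=1.0)

grid = np.linspace(z.min(), z.max(), 50)
X_grid = np.column_stack([grid**k for k in range(5)])
p_glm = 1.0 / (1.0 + np.exp(-(X_grid @ beta_glm)))
p_ridge = 1.0 / (1.0 + np.exp(-(X_grid @ beta_ridge)))
print("max |difference| in predicted probability:", np.max(np.abs(p_glm - p_ridge)))
```

The penalty pulls the non-intercept coefficients toward zero and flattens the fitted curve; how much depends on C, the number of observations, and the scale of the columns of X.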

The R package LiblineaR, an interface to the LIBLINEAR library, behaves like scikit-learn's logistic regression with the 'liblinear' solver, which is built on the same library.

https://cran.r-project.org/web/packages/LiblineaR/
