sklearn: get feature names after L1-based feature selection

Posted 2020-06-04 06:54

This question and answer demonstrate that when feature selection is performed using one of scikit-learn's dedicated feature selection routines, then the names of the selected features can be retrieved as follows:

np.asarray(vectorizer.get_feature_names())[featureSelector.get_support()]

For example, in the above code, featureSelector might be an instance of sklearn.feature_selection.SelectKBest or sklearn.feature_selection.SelectPercentile, since these classes implement the get_support method which returns a boolean mask or integer indices of the selected features.
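As a minimal sketch of the two return forms of get_support, the snippet below fits a SelectKBest on a small made-up count matrix. The feature names, the data and the choice of chi2 as the scoring function are assumptions for illustration only; they do not come from the question.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, chi2

# Hypothetical feature names and count data, invented for this sketch.
feature_names = np.asarray(["apple", "banana", "cherry", "date"])
X = np.array([
    [3, 0, 1, 0],
    [2, 0, 0, 1],
    [0, 4, 1, 0],
    [0, 3, 0, 1],
])
y = np.array([0, 0, 1, 1])

selector = SelectKBest(chi2, k=2).fit(X, y)

# Default form: boolean mask with one entry per input feature.
mask = selector.get_support()
# Alternative form: integer indices of the selected features.
indices = selector.get_support(indices=True)

# Both forms index the name array to the same result.
names_from_mask = feature_names[mask]
names_from_indices = feature_names[indices]
```

Either form can be used to index the array of feature names, which is why the one-liner above works with any selector that exposes get_support.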

When one performs feature selection via linear models penalized with the L1 norm, it's unclear how to accomplish this. sklearn.svm.LinearSVC has no get_support method and the documentation doesn't make clear how to retrieve the feature indices after using its transform method to eliminate features from a collection of samples. Am I missing something here?

2 Answers
Juvenile、少年°
#2 · 2020-06-04 07:20

For sparse estimators you can generally find the support by checking where the non-zero entries are in the coefficient vector (provided the coefficient vector exists, which is the case for e.g. linear models):

support = np.flatnonzero(estimator.coef_)

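For your LinearSVC with an L1 penalty, that would be:

import numpy as np
from sklearn.svm import LinearSVC

# penalty='l1' requires dual=False; it drives some coefficients to exactly zero
svc = LinearSVC(C=1., penalty='l1', dual=False)
svc.fit(X, y)
# keep the names of the features whose coefficient is non-zero
selected_feature_names = np.asarray(vectorizer.get_feature_names())[np.flatnonzero(svc.coef_)]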
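For a version that runs end to end, here is a rough sketch on a toy binary problem. The corpus, the labels and the random_state value are invented for illustration. The getattr fallback is only there because newer scikit-learn releases replaced get_feature_names() with get_feature_names_out(); which one you have depends on your installed version.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.svm import LinearSVC

# Hypothetical toy corpus with binary labels (1 = spam, 0 = not spam).
docs = [
    "buy cheap pills now",
    "cheap watches buy now",
    "limited offer buy cheap",
    "meeting agenda for monday",
    "project meeting notes monday",
    "agenda for the project review",
]
y = np.array([1, 1, 1, 0, 0, 0])

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(docs)

# Use get_feature_names_out() when available, else the older get_feature_names().
get_names = getattr(vectorizer, "get_feature_names_out", None) or vectorizer.get_feature_names
feature_names = np.asarray(get_names())

svc = LinearSVC(C=1.0, penalty="l1", dual=False, random_state=0)
svc.fit(X, y)

# Binary problem: coef_ has one weight per feature, so positions in the
# flattened array are feature indices.
support = np.flatnonzero(svc.coef_)
selected_feature_names = feature_names[support]
print(selected_feature_names)
```

Which words survive depends on C: smaller values penalize harder and leave fewer non-zero coefficients. This sketch only covers the two-class case.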
放我归山
#3 · 2020-06-04 07:21

I've been using sklearn 0.15.2, and according to the LinearSVC documentation, coef_ is an array of shape [n_features] if n_classes == 2, else [n_classes, n_features]. So first, np.flatnonzero doesn't work for the multi-class case: it returns positions in the flattened array, which can exceed the number of features and cause an index-out-of-range error. Second, it should be np.where(svc.coef_ != 0)[1] rather than np.where(svc.coef_ != 0)[0], because axis 0 indexes classes, not features. I ended up using np.asarray(vectorizer.get_feature_names())[list(set(np.where(svc.coef_ != 0)[1]))]
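To make the multi-class point concrete, here is a small NumPy-only sketch. The coefficient matrix and feature names are hand-written stand-ins for a fitted svc.coef_ and the vectorizer output, not real model results. It also swaps list(set(...)) for np.unique, which gives the same distinct indices but in sorted order; that is a substitution of mine, not part of the answer above.

```python
import numpy as np

# Hypothetical stand-ins: five feature names and a 3-class coefficient
# matrix of shape (n_classes, n_features), laid out like LinearSVC.coef_.
feature_names = np.asarray(["f0", "f1", "f2", "f3", "f4"])
coef = np.array([
    [0.0, 1.2, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, -0.7, 0.0],
    [0.0, 0.4, 0.0, 0.0, 0.3],
])

# flatnonzero indexes the flattened (15-element) array, so some positions
# are larger than n_features - 1 and cannot index feature_names.
flat = np.flatnonzero(coef)

# Column indices of the non-zero entries, deduplicated and sorted.
support = np.unique(np.where(coef != 0)[1])

# Equivalent mask: keep a feature if any class gives it a non-zero weight.
mask = np.any(coef != 0, axis=0)

selected = feature_names[support]
print(flat)      # positions in the flattened matrix
print(selected)  # names of features used by at least one class
```

The boolean mask form has the same shape as what get_support() returns for the dedicated selectors, so it can be dropped into the one-liner from the question.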
