I am looking for a way to graph grid_scores_ from GridSearchCV in sklearn. In this example I am trying to grid search for the best gamma and C parameters for an SVR algorithm. My code looks as follows:
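import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVR

# Log-spaced candidates from 1e-4 to 1e3 for both parameters
C_range = 10.0 ** np.arange(-4, 4)
gamma_range = 10.0 ** np.arange(-4, 4)
param_grid = dict(gamma=gamma_range.tolist(), C=C_range.tolist())
# The gamma passed to SVR is overridden by the grid values during the search
grid = GridSearchCV(SVR(kernel='rbf', gamma=0.1), param_grid, cv=5)
grid.fit(X_train, y_train)
print(grid.grid_scores_)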
After I run the code and print the grid scores I get the following outcome:
[mean: -3.28593, std: 1.69134, params: {'gamma': 0.0001, 'C': 0.0001}, mean: -3.29370, std: 1.69346, params: {'gamma': 0.001, 'C': 0.0001}, mean: -3.28933, std: 1.69104, params: {'gamma': 0.01, 'C': 0.0001}, mean: -3.28925, std: 1.69106, params: {'gamma': 0.1, 'C': 0.0001}, mean: -3.28925, std: 1.69106, params: {'gamma': 1.0, 'C': 0.0001}, mean: -3.28925, std: 1.69106, params: {'gamma': 10.0, 'C': 0.0001},etc]
I would like to visualize all the scores (mean values) depending on gamma and C parameters. The graph I am trying to obtain should look as follows:
Here the x-axis is gamma, the y-axis is the mean score (root mean square error in this case), and the different lines represent different C values.
This worked for me when I was trying to plot mean scores vs no. of trees in the Random Forest. The reshape() function helps to find out the averages.
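A minimal sketch of the reshape idea, assuming the flat list of mean scores is ordered the way the printed output in the question is ordered (gamma varies fastest within each C value). The helper names `scores_to_grid` and `plot_grid` are hypothetical, introduced only for illustration:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove to open plot windows
import matplotlib.pyplot as plt


def scores_to_grid(mean_scores, n_C, n_gamma):
    """Reshape a flat sequence of mean scores into an (n_C, n_gamma) array.

    Assumes gamma varies fastest, as in the printed grid scores above.
    """
    return np.asarray(mean_scores, dtype=float).reshape(n_C, n_gamma)


def plot_grid(grid_2d, C_values, gamma_values):
    """Draw one line per C value, with gamma on a log-scaled x-axis."""
    fig, ax = plt.subplots()
    for row, C in zip(grid_2d, C_values):
        ax.plot(gamma_values, row, marker="o", label=f"C = {C:g}")
    ax.set_xscale("log")
    ax.set_xlabel("gamma")
    ax.set_ylabel("mean CV score")
    ax.legend(fontsize="small")
    return fig, ax
```

With a fitted search object, the flat scores could come from `[s.mean_validation_score for s in grid.grid_scores_]` on old scikit-learn releases, or from `grid.cv_results_["mean_test_score"]` on newer ones.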
Here's fully working code that will produce plots so you can fully visualize the varying of up to 3 parameters using GridSearchCV. This is what you will see when running the code:
Each plotted line also shows the standard deviation of the score across the cross-validation folds, so you can see how much the mean score can be expected to vary. Enjoy!
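As a rough illustration of lines with standard-deviation bars, here is a sketch that works from a `cv_results_`-style dictionary. It assumes exactly two parameters are varied, and the function names `group_series` and `plot_mean_with_std` are hypothetical, not taken from the answer's code:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove to open plot windows
import matplotlib.pyplot as plt


def group_series(cv_results, x_param, line_param):
    """Group (x, mean, std) points by the value of line_param.

    Returns {line_value: (x_list, mean_list, std_list)}, sorted by x.
    """
    grouped = {}
    rows = zip(cv_results["params"],
               cv_results["mean_test_score"],
               cv_results["std_test_score"])
    for params, mean, std in rows:
        point = (params[x_param], float(mean), float(std))
        grouped.setdefault(params[line_param], []).append(point)
    series = {}
    for key, points in grouped.items():
        points.sort()  # order each line by its x value
        x, mean, std = zip(*points)
        series[key] = (list(x), list(mean), list(std))
    return series


def plot_mean_with_std(cv_results, x_param, line_param):
    """Plot mean score vs. x_param, one error-bar line per line_param value."""
    fig, ax = plt.subplots()
    for key, (x, mean, std) in sorted(group_series(cv_results, x_param, line_param).items()):
        ax.errorbar(x, mean, yerr=std, marker="o", capsize=3,
                    label=f"{line_param} = {key}")
    ax.set_xscale("log")  # assumes positive, log-spaced x values such as gamma
    ax.set_xlabel(x_param)
    ax.set_ylabel("mean CV score (bars: std across folds)")
    ax.legend(fontsize="small")
    return fig, ax
```

For the search in the question, a call might look like `plot_mean_with_std(grid.cv_results_, "gamma", "C")`.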
The code shown by @sascha is correct. However, the grid_scores_ attribute will soon be deprecated. It is better to use the cv_results_ attribute. It can be implemented in a similar fashion to @sascha's method:
The above results in the following plot:
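To make the switch from grid_scores_ to cv_results_ concrete, here is a small end-to-end sketch. The synthetic data from `make_regression` stands in for `X_train` and `y_train`, and the 3x3 grid is deliberately tiny; both are assumptions for illustration, not the values used above:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVR

# Synthetic stand-in for X_train / y_train (assumption for this sketch)
X_train, y_train = make_regression(n_samples=60, n_features=4,
                                   noise=0.5, random_state=0)

C_range = [0.1, 1.0, 10.0]
gamma_range = [0.01, 0.1, 1.0]
param_grid = dict(gamma=gamma_range, C=C_range)

grid = GridSearchCV(SVR(kernel="rbf"), param_grid, cv=3)
grid.fit(X_train, y_train)

results = grid.cv_results_

# Print one line per parameter combination, similar to the old grid_scores_ output
for mean, std, params in zip(results["mean_test_score"],
                             results["std_test_score"],
                             results["params"]):
    print(f"mean: {mean:.5f}, std: {std:.5f}, params: {params}")

# Rows correspond to C, columns to gamma (gamma varies fastest in the flat arrays)
means = np.asarray(results["mean_test_score"]).reshape(len(C_range), len(gamma_range))
stds = np.asarray(results["std_test_score"]).reshape(len(C_range), len(gamma_range))
```

Each row of `means` can then be plotted against `gamma_range` to get one line per C value.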
I wanted to do something similar (but scalable to a large number of parameters) and here is my solution to generate swarm plots of the output:
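Swarm plots are usually drawn with seaborn's `swarmplot`. The sketch below swaps in a simpler jittered strip plot that needs only NumPy and matplotlib, so it is not the method from this answer, just an approximation of the idea: one panel per searched parameter, one dot per parameter combination. The function name `strip_plots` is hypothetical, and the code assumes every entry of `params` has the same keys and that each parameter's values can be sorted.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove to open plot windows
import matplotlib.pyplot as plt


def strip_plots(cv_results, seed=0):
    """Draw one jittered strip plot of mean scores per searched parameter."""
    params = cv_results["params"]
    scores = np.asarray(cv_results["mean_test_score"], dtype=float)
    names = sorted(params[0])
    rng = np.random.default_rng(seed)

    fig, axes = plt.subplots(1, len(names), sharey=True, squeeze=False,
                             figsize=(4 * len(names), 3))
    for ax, name in zip(axes[0], names):
        # Treat each distinct value as a category placed at an integer position
        levels = sorted({p[name] for p in params})
        position = {value: i for i, value in enumerate(levels)}
        x = np.array([position[p[name]] for p in params], dtype=float)
        x += rng.uniform(-0.15, 0.15, size=x.size)  # jitter to separate dots
        ax.scatter(x, scores, s=12, alpha=0.7)
        ax.set_xticks(range(len(levels)))
        ax.set_xticklabels([str(v) for v in levels], rotation=45)
        ax.set_xlabel(name)
    axes[0][0].set_ylabel("mean CV score")
    fig.tight_layout()
    return fig, axes[0]
```

Because each parameter gets its own panel, the figure grows linearly with the number of parameters rather than needing one line per combination.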