Correlation matrix using pandas

Posted 2019-01-29 16:23

Question:

I have a data set with a huge number of features, so analysing the correlation matrix has become very difficult. I want to plot the correlation matrix that the dataframe.corr() function from the pandas library returns. Does pandas provide a built-in function to plot this matrix?

Answer 1:

You can use pyplot.matshow() from matplotlib:

import matplotlib.pyplot as plt

plt.matshow(dataframe.corr())


Answer 2:

Try this function, which also displays variable names for the correlation matrix:

import matplotlib.pyplot as plt

def plot_corr(df, size=10):
    '''Plot a graphical correlation matrix for each pair of columns in the dataframe.

    Input:
        df: pandas DataFrame
        size: vertical and horizontal size of the plot, in inches'''

    corr = df.corr()
    fig, ax = plt.subplots(figsize=(size, size))
    ax.matshow(corr)
    # Label both axes with the column names
    plt.xticks(range(len(corr.columns)), corr.columns)
    plt.yticks(range(len(corr.columns)), corr.columns)


Answer 3:

Seaborn's heatmap version:

import seaborn as sns
corr = dataframe.corr()
sns.heatmap(corr, 
            xticklabels=corr.columns.values,
            yticklabels=corr.columns.values)


Answer 4:

You can observe the relations between features either by drawing a heatmap with seaborn or a scatter matrix with pandas.

Scatter Matrix:

import pandas as pd
pd.plotting.scatter_matrix(dataframe, alpha=0.3, figsize=(14, 8), diagonal='kde')

If you want to visualize each feature's skewness as well, use seaborn pairplots:

sns.pairplot(dataframe)

Seaborn Heatmap:

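import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

f, ax = plt.subplots(figsize=(10, 8))
corr = dataframe.corr()
# An all-False mask hides nothing; dtype=bool works across NumPy versions (np.bool was deprecated in NumPy 1.20)
sns.heatmap(corr, mask=np.zeros_like(corr, dtype=bool), cmap=sns.diverging_palette(220, 10, as_cmap=True),
            square=True, ax=ax)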

The output is a correlation map of the features, as in the following example.

The correlation between grocery and detergents is high. Similarly:

Products With High Correlation:
  1. Grocery and Detergents
Products With Medium Correlation:
  1. Milk and Grocery
  2. Milk and Detergents_Paper
Products With Low Correlation:
  1. Milk and Deli
  2. Frozen and Fresh
  3. Frozen and Deli

From Pairplots: You can observe the same set of relations from pairplots or a scatter matrix. Unlike the heatmap, these also indicate whether the data is normally distributed.

Note: The plots above are drawn from the same data that was used for the heatmap.
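
The high / medium / low grouping above was read off the plot by eye. As a rough sketch of how the same grouping could be computed from the correlation matrix itself, the snippet below lists each column pair once and labels it by absolute correlation. The helper name rank_correlated_pairs, the thresholds 0.7 and 0.4, and the synthetic columns a, b, c are all assumptions made for illustration; they do not come from the answer or its data set.

```python
import numpy as np
import pandas as pd


def rank_correlated_pairs(df, high=0.7, medium=0.4):
    """List each column pair once, sorted by absolute correlation.

    The `high` and `medium` cut-offs are illustrative assumptions.
    """
    corr = df.corr()
    cols = corr.columns
    # Indices of the strict upper triangle: every pair once, no diagonal.
    i, j = np.triu_indices(len(cols), k=1)
    pairs = pd.DataFrame({
        "feature_a": cols[i],
        "feature_b": cols[j],
        "corr": corr.to_numpy()[i, j],
    })
    abs_corr = pairs["corr"].abs().to_numpy()
    pairs["level"] = np.select(
        [abs_corr >= high, abs_corr >= medium],
        ["high", "medium"],
        default="low",
    )
    # Strongest relationships first.
    order = np.argsort(-abs_corr, kind="stable")
    return pairs.iloc[order].reset_index(drop=True)


# Synthetic data: "b" is a noisy copy of "a", "c" is independent noise.
rng = np.random.default_rng(0)
base = rng.normal(size=200)
demo = pd.DataFrame({
    "a": base,
    "b": base + 0.1 * rng.normal(size=200),
    "c": rng.normal(size=200),
})

ranked = rank_correlated_pairs(demo)
print(ranked)
```

With many features, looking only at the first rows of the result (for example ranked.head(20)) may be easier to scan than a full heatmap.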



Answer 5:

If your main goal is to visualize the correlation matrix, rather than to create a plot per se, the convenient pandas styling options are a viable built-in solution:

import pandas as pd
import numpy as np

rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(10, 10))
corr = df.corr()
corr.style.background_gradient()

Note that this needs to run in a backend that supports rendering HTML, such as the JupyterLab Notebook. (The automatic light text on dark backgrounds is from an existing PR and is not in the latest released version, pandas 0.23.)


Styling

You can easily limit the digit precision:

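corr.style.background_gradient().set_precision(2)
# In pandas 1.3 and later, use .format(precision=2) instead; set_precision() was removed in pandas 2.0.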

Or get rid of the digits altogether if you prefer the matrix without annotations:

corr.style.background_gradient().set_properties(**{'font-size': '0pt'})

The styling documentation also includes instructions for more advanced styles, such as how to change the display of the cell the mouse pointer is hovering over. To save the output, you can get the HTML by appending the render() method (to_html() in newer pandas versions) and then write it to a file, or just take a screenshot for less formal purposes.


Time comparison

In my testing, style.background_gradient() was 4x faster than plt.matshow() and 120x faster than sns.heatmap() with a 10x10 matrix. Unfortunately it doesn't scale as well as plt.matshow(): the two take about the same time for a 100x100 matrix, and plt.matshow() is 10x faster for a 1000x1000 matrix.



Answer 6:

You can use the imshow() method from matplotlib:

import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# X is the pandas DataFrame whose columns are the features
plt.imshow(X.corr(), cmap=plt.cm.Reds, interpolation='nearest')
plt.colorbar()
tick_marks = list(range(len(X.columns)))
plt.xticks(tick_marks, X.columns, rotation='vertical')
plt.yticks(tick_marks, X.columns)
plt.show()