Hexbin plot in PairGrid with Seaborn

Posted 2019-05-23 02:19

Question:

I am trying to get hexbin plots in a Seaborn PairGrid. I have the following code:

# Works in Jupyter with Python 2 Kernel.
%matplotlib inline

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

tips = sns.load_dataset("tips")

# Borrowed from http://stackoverflow.com/a/31385996/4099925
def hexbin(x, y, color, **kwargs):
    cmap = sns.light_palette(color, as_cmap=True)
    plt.hexbin(x, y, gridsize=15, cmap=cmap, extent=[min(x), max(x), min(y), max(y)], **kwargs)

g = sns.PairGrid(tips, hue='sex')
g.map_diag(plt.hist)
g.map_lower(sns.stripplot, jitter=True, alpha=0.5)
g.map_upper(hexbin)

However, that gives me the following image:

How can I fix the hexbin plots so that they cover the entire area of each subplot, not just part of it?

Answer 1:

There are (at least) three problems with what you are trying to do here.

  1. stripplot is meant for data where at least one variable is categorical, which is not the case here. Seaborn guesses that the x variable is the categorical one, which messes up the x axes of your subplots. From the docs for stripplot:

    Draw a scatterplot where one variable is categorical.

    In my suggested code below I have replaced it with a plain scatter plot (plt.scatter).

  2. When two hexbin plots are drawn on top of each other, only the one drawn last stays visible where they overlap. I added alpha=0.5 to the hexbin arguments, but the result is far from pretty.

  3. The extent parameter in your code fitted each hexbin plot to the x and y values of one sex at a time. But both hexbin plots need to cover the same area, so they should use the min/max of each column over both sexes. To achieve this, I pass the minimum and maximum values of all columns to the hexbin function, which then picks out the ones it needs.
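The difference between per-group and shared limits in point 3 can be illustrated with plain pandas. This is a minimal sketch on a small made-up DataFrame (the values are hypothetical stand-ins for the tips columns); shared_extent and per_group_extents are hypothetical helper names, not seaborn functions. The first returns the [xmin, xmax, ymin, ymax] list that the extent parameter of plt.hexbin expects, computed over all rows; the second mimics what the original code effectively did, one extent per hue level.

```python
import pandas as pd

# Hypothetical toy data standing in for the tips dataset.
df = pd.DataFrame({
    "total_bill": [10.0, 20.0, 30.0, 15.0, 45.0, 50.0],
    "tip": [1.0, 3.0, 4.0, 2.0, 6.0, 9.0],
    "sex": ["Female", "Female", "Female", "Male", "Male", "Male"],
})


def shared_extent(data, xcol, ycol):
    """Return [xmin, xmax, ymin, ymax] computed over *all* rows."""
    # numeric_only=True skips non-numeric columns such as "sex";
    # recent pandas versions raise on min()/max() of unordered categoricals.
    mins = data.min(numeric_only=True)
    maxs = data.max(numeric_only=True)
    return [mins[xcol], maxs[xcol], mins[ycol], maxs[ycol]]


def per_group_extents(data, xcol, ycol, hue):
    """One extent per hue level, as in the question's hexbin function."""
    return {
        level: [g[xcol].min(), g[xcol].max(), g[ycol].min(), g[ycol].max()]
        for level, g in data.groupby(hue)
    }


print(shared_extent(df, "total_bill", "tip"))
print(per_group_extents(df, "total_bill", "tip", "sex"))
```

The per-group extents differ from each other, so each hexbin layer gets its own hexagon grid; the shared extent is the same for every hue level.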

Here is what I came up with:

# Works in Jupyter with Python 2 Kernel.
%matplotlib inline

import seaborn as sns
import matplotlib.pyplot as plt

tips = sns.load_dataset("tips")

# Borrowed from http://stackoverflow.com/a/31385996/4099925
def hexbin(x, y, color, max_series=None, min_series=None, **kwargs):
    # Light colormap based on the colour PairGrid passes in for this hue level
    cmap = sns.light_palette(color, as_cmap=True)
    # Look up the limits of the whole column (both sexes) by series name,
    # so every hue level is binned over the same extent
    xmin, xmax = min_series[x.name], max_series[x.name]
    ymin, ymax = min_series[y.name], max_series[y.name]
    plt.hexbin(x, y, gridsize=15, cmap=cmap, extent=[xmin, xmax, ymin, ymax], **kwargs)

g = sns.PairGrid(tips, hue='sex')
g.map_diag(plt.hist)
g.map_lower(plt.scatter, alpha=0.5)
# numeric_only=True skips the categorical columns (sex, smoker, day, time),
# whose min/max raises an error in recent pandas versions
g.map_upper(hexbin, min_series=tips.min(numeric_only=True),
            max_series=tips.max(numeric_only=True), alpha=0.5)

And here is the result:
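The effect of a shared extent can also be checked without seaborn. The sketch below uses made-up random data (the arrays and their ranges are hypothetical) and only matplotlib's Axes.hexbin: with the same extent and gridsize, the two layers get identical hexagon grids, so they line up cell for cell. It also tries mincnt=1, an option the answer does not use, which skips empty cells so that each layer hides less of the one below it.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical two-group data with different ranges, like two hue levels.
x_a, y_a = rng.uniform(0, 5, 50), rng.uniform(0, 5, 50)
x_b, y_b = rng.uniform(3, 10, 50), rng.uniform(2, 8, 50)

# One extent shared by both groups, computed over the combined data.
extent = [
    min(x_a.min(), x_b.min()), max(x_a.max(), x_b.max()),
    min(y_a.min(), y_b.min()), max(y_a.max(), y_b.max()),
]

fig, ax = plt.subplots()
# With mincnt left at its default, every cell inside the extent is drawn,
# so the hexagon centres depend only on extent and gridsize.
hb_a = ax.hexbin(x_a, y_a, gridsize=15, extent=extent, alpha=0.5)
hb_b = ax.hexbin(x_b, y_b, gridsize=15, extent=extent, alpha=0.5)
same_grid = np.allclose(hb_a.get_offsets(), hb_b.get_offsets())

# mincnt=1 keeps only cells containing at least one point.
hb_sparse = ax.hexbin(x_a, y_a, gridsize=15, extent=extent, mincnt=1)
print(same_grid, len(hb_a.get_offsets()), len(hb_sparse.get_offsets()))
plt.close(fig)
```

If the layers used different extents, as in the question, their grids would no longer coincide and the top layer would cover a different rectangle than the one below.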