This question already has an answer here:
- How to plot multiple Seaborn Jointplot in Subplot 2 answers
Some minimal code of what I'm working on. Some parameters may seem redundant, but I didn't bother removing all of them.
import matplotlib
import matplotlib.gridspec as gridspec
matplotlib.use("macosx")
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
def plot_overlaid_2d_hist(data,
plot_axis_x,
plot_axis_y,
plot_axis_x_lab,
plot_axis_y_lab,
group_by = "group_name"):
# don't mind this for now
df = data
# Figure aspect
w, h = plt.figaspect(1)
fig = plt.figure(figsize = (w, h))
# Count the number of groups to make plots for
n_groups = len(df.groupby(group_by))
gs = gridspec.GridSpec(nrows = n_groups, ncols = 1)
subplot_id = 0
# Reshape data to make it work
for name, group in df.groupby(group_by, sort = False):
# Initialize subplot
fig.add_subplot(gs[subplot_id, 0])
# Check if we get subplots with pyplot
if subplot_id == 0:
col = "red"
else:
col = "blue"
plt.plot(x, y, color = col)
# instantiate JointGrid
# g = sns.JointGrid(group[plot_axis_x],
# group[plot_axis_y],
# data = group,
# space = 0,
# xlim = (0, 1.2),
# ylim = (0, 1))
#
# # Fix labels
# g = g.set_axis_labels(xlabel = str(plot_axis_x_lab),
# ylabel = str(plot_axis_y_lab))
#
# # center scatter plot on top
# g = g.plot_joint(plt.scatter,
# s = 0.5,
# alpha = 1,
# linewidth = 1)
#
# # marginals plot
# g = g.plot_marginals(sns.distplot,
# kde = True,
# kde_kws = dict(linewidth = 2,
# alpha = 1,
# bw = "Scott"),
# hist_kws = dict(alpha = 1))
# Next plot in row +1
subplot_id += 1
# Output
plt.tight_layout() # Attempts to fix alignment of subplot layout and axis titles
plt.show()
# quick data to check if the plots end up where they should
x = [0.5, 0.5, 0.4, 0.4]
y = [0.6, 0.4, 0.3, 0.4]
grp = ["a", "a", "b", "b"]
df = pd.DataFrame({"x":x,
"y":y,
"grp": grp})
plot_overlaid_2d_hist(data = df,
group_by = "grp",
plot_axis_x_lab = "x",
plot_axis_y_lab = "x",
plot_axis_y = "x",
plot_axis_x = "x")
Running the code with all seaborn plots (g) commented out shows that it works fine for native pyplot, but when I add in the multi-part seaborn plots, they show up in separate figures. What I would like is to have each 2D-histogram-with-marginals-and-scatter populate their own gridspec row/column.