Plot multivariate Gaussian contours with ggplot2

Posted 2019-05-03 09:27

Question:

I'm trying to augment a plot with contours from a 2D Gaussian distribution with known mean and covariance. Ideally I would just specify the density function and have it plotted in 2D (like stat_function(), but for two dimensions). I can do it with geom_raster() by evaluating the density on a grid of points. Can I use geom_density_2d() somehow instead?

library(ggplot2)

m <- c(.5, -.5)                               # mean vector
sigma <- matrix(c(1, .5, .5, 1), nrow = 2)    # covariance matrix
# 200 x 200 evaluation grid over [-3, 3] x [-3, 3]
data.grid <- expand.grid(s.1 = seq(-3, 3, length.out = 200), s.2 = seq(-3, 3, length.out = 200))
q.samp <- cbind(data.grid, prob = mvtnorm::dmvnorm(data.grid, mean = m, sigma = sigma))
ggplot(q.samp, aes(x = s.1, y = s.2)) +
    geom_raster(aes(fill = prob)) +
    coord_fixed(xlim = c(-3, 3), ylim = c(-3, 3), ratio = 1)

Answer 1:

I would use the ellipse package to construct the contour data directly. This still requires a separate call to construct the data, but it is much more efficient (in both space and time) than evaluating the density over an entire grid and then extracting contours from it: each ellipse is just a path of 100 points by default.
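For reference, the curve ellipse() draws for a given level is the standard probability contour of a bivariate normal: the set of points whose squared Mahalanobis distance from the mean equals the chi-squared quantile with 2 degrees of freedom (this is ellipse()'s default t = sqrt(qchisq(level, 2))). In two dimensions that quantile has a closed form, so the density along each contour is explicit too (standard textbook result, stated here as background rather than taken from the answer):

$$
\{\mathbf{x} : (\mathbf{x}-\mathbf{m})^\top \Sigma^{-1} (\mathbf{x}-\mathbf{m}) = c_\alpha\}, \qquad c_\alpha = F^{-1}_{\chi^2_2}(\alpha) = -2\ln(1-\alpha)
$$

$$
f(\mathbf{x}) = \frac{1}{2\pi\sqrt{\det\Sigma}}\,\exp\!\left(-\tfrac{c_\alpha}{2}\right) = \frac{1-\alpha}{2\pi\sqrt{\det\Sigma}} \quad \text{on that contour.}
$$

For the question's sigma, det Σ = 0.75, so the 95% contour lies at density 0.05 / (2π √0.75) ≈ 0.0092.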

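library(ellipse)
library(plyr)  ## not necessary, but convenient
m <- c(.5, -.5)
sigma <- matrix(c(1, .5, .5, 1), nrow = 2)
alpha_levels <- seq(0.5, 0.95, by = 0.05) ## or whatever you want
names(alpha_levels) <- alpha_levels ## to get an .id column in the result
## ldply() passes each level as the first unnamed argument, so x, scale and centre
## must all be named for it to land in `level`. When scale is given explicitly,
## ellipse() treats x as a correlation matrix, hence cov2cor().
contour_data <- ldply(alpha_levels, ellipse, x = cov2cor(sigma),
      scale = sqrt(diag(sigma)),  ## standard deviations; c(1, 1) for this sigma
      centre = m)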

(you could use lapply and rbind from base R; plyr::ldply is just a shortcut)
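A base-R sketch of the same construction, using lapply() and do.call(rbind, ...) in place of ldply(). Because level is passed by name here, scale can be left at its default, and ellipse() then derives the standard deviations and correlation from the covariance matrix itself. The .id column is built by hand to mirror what ldply() produces, so the plotting call that follows works unchanged.

```r
library(ellipse)

m <- c(.5, -.5)
sigma <- matrix(c(1, .5, .5, 1), nrow = 2)
alpha_levels <- seq(0.5, 0.95, by = 0.05)

# One data frame per level; each carries its level as a character .id
# so that geom_path() can draw one closed curve per level.
contour_list <- lapply(alpha_levels, function(a) {
  pts <- ellipse(sigma, centre = m, level = a)  # 100 x 2 matrix, columns x and y
  data.frame(pts, .id = as.character(a))
})
contour_data <- do.call(rbind, contour_list)
```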

Now plot:

library(ggplot2)
ggplot(contour_data, aes(x, y, group = .id)) + geom_path()  ## one closed path per level
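If you would rather not depend on the ellipse package, the same curves can be traced by hand. This is a minimal sketch, assuming the standard result that the contour enclosing probability `level` is where the squared Mahalanobis distance equals qchisq(level, 2): take a circle of that radius and map it through the Cholesky factor of sigma. gaussian_ellipse() is a hypothetical helper written for illustration, not part of any package.

```r
library(ggplot2)

# Hypothetical helper: points on the contour of N(m, sigma) that encloses
# probability `level`, i.e. where the squared Mahalanobis distance is qchisq(level, 2).
gaussian_ellipse <- function(m, sigma, level = 0.95, npoints = 100) {
  r <- sqrt(qchisq(level, df = 2))                 # radius in "whitened" coordinates
  theta <- seq(0, 2 * pi, length.out = npoints)
  circle <- rbind(r * cos(theta), r * sin(theta))  # 2 x npoints circle of radius r
  L <- t(chol(sigma))                              # lower factor: L %*% t(L) equals sigma
  pts <- t(L %*% circle + m)                       # stretch/rotate the circle, shift by m
  data.frame(x = pts[, 1], y = pts[, 2], level = level)
}

m <- c(.5, -.5)
sigma <- matrix(c(1, .5, .5, 1), nrow = 2)
alpha_levels <- seq(0.5, 0.95, by = 0.05)

# Each element of alpha_levels is passed positionally as `level`
contour_data <- do.call(rbind, lapply(alpha_levels, gaussian_ellipse, m = m, sigma = sigma))

p <- ggplot(contour_data, aes(x, y, group = level)) +
  geom_path() +
  coord_fixed(ratio = 1)
print(p)
```

The Cholesky mapping works because if x − m = L u, then (x − m)ᵀ Σ⁻¹ (x − m) = uᵀu, so every point of the circle lands exactly on the target contour.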


Answer 2:

I was barking up the wrong tree looking at the 2d geoms. You can get what I want with geom_contour() by mapping the density values to the z aesthetic:

library(ggplot2)

m <- c(.5, -.5)
sigma <- matrix(c(1, .5, .5, 1), nrow = 2)
data.grid <- expand.grid(s.1 = seq(-3, 3, length.out = 200), s.2 = seq(-3, 3, length.out = 200))
q.samp <- cbind(data.grid, prob = mvtnorm::dmvnorm(data.grid, mean = m, sigma = sigma))
# z holds the density; geom_contour() chooses the contour levels automatically
ggplot(q.samp, aes(x = s.1, y = s.2, z = prob)) +
    geom_contour() +
    coord_fixed(xlim = c(-3, 3), ylim = c(-3, 3), ratio = 1)
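To get closer to the "stat_function but in 2D" wish, the grid step can be wrapped in a function that returns a ready-made layer. This is a sketch; gaussian_contour_layer() is a hypothetical helper, not part of ggplot2 or mvtnorm. It also uses the fact that on the contour enclosing probability alpha, a bivariate normal density equals (1 − alpha) / (2π √det(sigma)), so passing those values as `breaks` makes geom_contour() draw the same probability contours that ellipse() traces in Answer 1, instead of its automatically chosen levels.

```r
library(ggplot2)

# Hypothetical helper: a geom_contour() layer for the density of N(m, sigma),
# with one contour per probability level in alpha_levels.
gaussian_contour_layer <- function(m, sigma, alpha_levels = c(0.5, 0.9, 0.95),
                                   xlim = c(-3, 3), ylim = c(-3, 3), n = 200, ...) {
  grid <- expand.grid(x = seq(xlim[1], xlim[2], length.out = n),
                      y = seq(ylim[1], ylim[2], length.out = n))
  grid$z <- mvtnorm::dmvnorm(as.matrix(grid), mean = m, sigma = sigma)
  # Density on the alpha contour: exp(-qchisq(alpha, 2) / 2) / (2 * pi * sqrt(det(sigma))),
  # which simplifies to the expression below.
  breaks <- (1 - alpha_levels) / (2 * pi * sqrt(det(sigma)))
  geom_contour(data = grid, mapping = aes(x = x, y = y, z = z),
               breaks = breaks, inherit.aes = FALSE, ...)
}

m <- c(.5, -.5)
sigma <- matrix(c(1, .5, .5, 1), nrow = 2)

p <- ggplot() +
  gaussian_contour_layer(m, sigma, alpha_levels = seq(0.5, 0.95, by = 0.05)) +
  coord_fixed(xlim = c(-3, 3), ylim = c(-3, 3), ratio = 1)
print(p)
```

Because the layer carries its own data and sets inherit.aes = FALSE, it can also be added on top of the geom_raster() plot from the question.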



Tags: r ggplot2