Conditional cumulative mean for each group in R

Posted 2019-04-16 08:25

Question:

I have a data set that looks like this:

id   a   b
1    AA  2
1    AB  5
1    AA  1
2    AB  2
2    AB  4
3    AB  4
3    AB  3
3    AA  1
3    AA  4

I need to calculate the cumulative mean of b for each record within each id group, excluding rows where a == 'AA'. The sample output should be:

id   a   b  mean
1    AA  2   -
1    AB  5   5
1    AA  1   5
2    AB  2   2
2    AB  4   (4+2)/2
3    AB  4   4
3    AB  3   (4+3)/2
3    AA  1   (4+3)/2
3    AA  4   (4+3)/2

I tried to achieve this using dplyr and cummean, but I get an error.

df <- df %>%
       group_by(id) %>%
       mutate(mean = cummean(b[a != 'AA']))

Error: incompatible size (123), expecting 147 (the group size) or 1

Can you suggest a better way to achieve this in R?

Answer 1:

The trick here is to reconstruct the cummean by dividing an adjusted cumsum (which adds 0 for 'AA' rows) by an adjusted count (which skips 'AA' rows). As a one-liner:

df %>% group_by(id) %>% mutate(cumsum(b * (a != 'AA')) / cumsum(a != 'AA'))
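
For context on the error in the question: subsetting with b[a != 'AA'] drops rows, so the result is shorter than the group, whereas the one-liner keeps one value per row. A small base R sketch of the difference, using the values of group 3 from the sample data typed in by hand:

```r
# Group 3 from the sample data.
a <- c("AB", "AB", "AA", "AA")
b <- c(4, 3, 1, 4)

length(b)             # 4, the group size mutate() expects
length(b[a != "AA"])  # 2, because the two 'AA' rows were dropped

# Multiplying by the logical mask keeps the length and zeroes out 'AA' rows.
b * (a != "AA")       # 4 3 0 0
cumsum(a != "AA")     # 1 2 2 2, the running count of rows that are kept
```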

We can make this a little nicer (to my mind, the ugly part is the "multiply by a != 'AA'" magic) by pulling a != 'AA' out into its own column:

df %>%
    group_by(id) %>%
    mutate(relevance = 0 + (a != 'AA'),
           mean = cumsum(relevance * b) / cumsum(relevance))
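
For reference, here is a self-contained sketch of the same idea that can be pasted into a fresh session. The data frame is retyped from the sample in the question, the column name keep and the result name res are made up for illustration, and the final NaN-to-NA step is an optional extra rather than part of the approach above.

```r
library(dplyr)

# Sample data, retyped from the question (nine rows).
df <- data.frame(
  id = c(1L, 1L, 1L, 2L, 2L, 3L, 3L, 3L, 3L),
  a  = c("AA", "AB", "AA", "AB", "AB", "AB", "AB", "AA", "AA"),
  b  = c(2L, 5L, 1L, 2L, 4L, 4L, 3L, 1L, 4L),
  stringsAsFactors = FALSE
)

res <- df %>%
  group_by(id) %>%
  mutate(
    keep = a != "AA",                       # TRUE for rows that count
    mean = cumsum(b * keep) / cumsum(keep)  # running sum / running count
  ) %>%
  ungroup()

# Leading 'AA' rows have a running count of 0, so 0 / 0 gives NaN there.
# Optionally turn those into NA (an extra step, not in the answer above).
res$mean[is.nan(res$mean)] <- NA

res$mean
```

On the sample data this should give NA for the first row of group 1 and otherwise match the expected output in the question.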


Answer 2:

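There may be an easier approach. We group by 'id' and create a new column 'Mean'. First, the elements of 'b' that correspond to 'AA' in 'a' are converted to NA with b*NA^(a=='AA'): NA^(a=='AA') returns NA where 'a' is 'AA' and 1 everywhere else, so multiplying by 'b' keeps the other values of 'b' and leaves the NAs in place. Next, na.aggregate (from zoo) replaces each NA with the mean of the non-NA elements in the group, and cummean takes the cumulative mean of the filled vector. Finally, if the first value of 'a' in a group is 'AA', multiplying by NA^(row_number()==1 & a=='AA') turns that row's result into NA.

Because na.aggregate fills with the mean of all non-NA values in the group, this reproduces the expected output for the sample data, but in a group where non-'AA' rows follow an 'AA' row the filled values also reflect those later rows, so the result can differ from the intended cumulative mean and is worth checking against the cumsum-based calculation.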
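
To see the NA^(...) mask on its own, here is a small base R sketch. The vectors are the first group of the sample data typed in by hand, and the names mask and masked_b are made up for illustration.

```r
a <- c("AA", "AB", "AA")
b <- c(2, 5, 1)

# TRUE is treated as 1 and FALSE as 0. In R, NA^0 is 1 while NA^1 stays NA.
mask <- NA^(a == "AA")
mask        # NA  1 NA

# Multiplying keeps b where the mask is 1 and gives NA where the mask is NA.
masked_b <- b * mask
masked_b    # NA  5 NA
```

The full pipeline from this answer: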

library(zoo)
library(dplyr)
df %>% 
   group_by(id) %>% 
   mutate(Mean= cummean(na.aggregate(b*NA^(a=='AA')))*
                 NA^(row_number()==1 & a=='AA'))
# Source: local data frame [9 x 4]
#Groups: id [3]

#      id     a     b  Mean
#   (int) (chr) (int) (dbl)
#1     1    AA     2    NA
#2     1    AB     5   5.0
#3     1    AA     1   5.0
#4     2    AB     2   2.0
#5     2    AB     4   3.0
#6     3    AB     4   4.0
#7     3    AB     3   3.5
#8     3    AA     1   3.5
#9     3    AA     4   3.5
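
One thing that may be worth checking before relying on this on other data is what happens when an 'AA' row sits between non-'AA' rows. The base R sketch below uses a hypothetical three-row group (not from the question), a hand-written stand-in for na.aggregate's default behaviour (fill each NA with the mean of the non-NA values), and a helper cum_mean that is made up here to mirror the idea of cummean. It compares the fill-then-average result with the running-sum-over-running-count calculation from the first answer.

```r
# Hypothetical group where an 'AA' row sits between two 'AB' rows.
a <- c("AB", "AA", "AB")
b <- c(4, 10, 2)

# Cumulative mean: running sum divided by position (assumed equivalent in
# spirit to dplyr::cummean; defined here to keep the sketch base R only).
cum_mean <- function(x) cumsum(x) / seq_along(x)

# Fill-then-average, with a stand-in for na.aggregate's default fill.
masked <- ifelse(a == "AA", NA, b)
filled <- ifelse(is.na(masked), mean(masked, na.rm = TRUE), masked)
fill_result <- cum_mean(filled)

# Running sum over running count of the non-'AA' rows.
keep <- a != "AA"
ratio_result <- cumsum(b * keep) / cumsum(keep)

fill_result   # 4.0 3.5 3.0
ratio_result  # 4 4 3
```

The two agree on the last row but not on the 'AA' row, because the fill value (3) already includes the later value 2. On the sample data in the question the two calculations coincide.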

data

df <- structure(list(id = c(1L, 1L, 1L, 2L, 2L, 3L, 3L, 3L, 3L), 
a = c("AA", 
"AB", "AA", "AB", "AB", "AB", "AB", "AA", "AA"), b = c(2L, 5L, 
1L, 2L, 4L, 4L, 3L, 1L, 4L)), .Names = c("id", "a", "b"),
class = "data.frame", row.names = c(NA, -9L))


Tags: r dplyr mean