I have a data set that looks like this:
id a b
1 AA 2
1 AB 5
1 AA 1
2 AB 2
2 AB 4
3 AB 4
3 AB 3
3 AA 1
3 AA 4
I need to calculate the cumulative mean of b for each record within each id group, excluding rows where a == 'AA' from the mean (those rows just carry the previous mean forward). The expected output is:
id a b mean
1 AA 2 -
1 AB 5 5
1 AA 1 5
2 AB 2 2
2 AB 4 (4+2)/2
3 AB 4 4
3 AB 3 (4+3)/2
3 AA 1 (4+3)/2
3 AA 4 (4+3)/2
I tried to do this with dplyr and cummean but got an error:
df <- df %>%
group_by(id) %>%
mutate(mean = cummean(b[a != 'AA']))
Error: incompatible size (123), expecting 147 (the group size) or 1
Can you suggest a better way to do this in R?
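A small base-R sketch of why the attempt above fails, using group 3 of the sample data typed in by hand: b[a != 'AA'] drops the 'AA' rows, so the vector handed to cummean is shorter than the group, while mutate only accepts a result of the group's length or length 1. Here cumsum(x) / seq_along(x) is used as a stand-in for dplyr::cummean, which computes the same running mean.

```r
# Group 3 of the sample data: 4 rows, 2 of them 'AA'
a <- c("AB", "AB", "AA", "AA")
b <- c(4, 3, 1, 4)

kept <- b[a != "AA"]                       # subsetting drops the 'AA' rows
running <- cumsum(kept) / seq_along(kept)  # running mean, like dplyr::cummean(kept)

length(b)        # 4: the group size mutate() expects
length(running)  # 2: what the failing mutate() call produces
```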
The trick here is to reconstruct cummean by dividing an adjusted cumsum (one that skips the 'AA' rows) by an adjusted count (the number of non-'AA' rows seen so far). As a one-liner:
df %>% group_by(id) %>% mutate(mean = cumsum(b * (a != 'AA')) / cumsum(a != 'AA'))
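The same sum-over-count idea can be checked without dplyr. This sketch swaps in base R's ave() for group_by()/mutate() and rebuilds the sample data by hand; it illustrates the technique rather than reproducing the answer's code. A group whose first row is 'AA' gives 0/0 there, which R reports as NaN.

```r
# Sample data from the question, rebuilt as a plain data.frame
df <- data.frame(
  id = c(1L, 1L, 1L, 2L, 2L, 3L, 3L, 3L, 3L),
  a  = c("AA", "AB", "AA", "AB", "AB", "AB", "AB", "AA", "AA"),
  b  = c(2L, 5L, 1L, 2L, 4L, 4L, 3L, 1L, 4L)
)

keep <- as.numeric(df$a != "AA")              # 1 for rows that count, 0 for 'AA'
num  <- ave(df$b * keep, df$id, FUN = cumsum) # running sum of kept b within each id
den  <- ave(keep, df$id, FUN = cumsum)        # running count of kept rows within each id
df$mean <- num / den                          # 0/0 gives NaN before any kept row
print(df$mean)
# values: NaN 5 5 2 3 4 3.5 3.5 3.5
```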
We can make this a little nicer (the "multiply by a != 'AA'" magic is the ugly part, in my view) by pulling a != 'AA' out into its own column:
df %>%
  group_by(id) %>%
  # relevance is 1 for rows that count towards the mean and 0 for 'AA' rows
  mutate(relevance = 0 + (a != 'AA'),
         mean = cumsum(relevance * b) / cumsum(relevance))
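Both cumsum-based versions return NaN (0/0) rather than NA when a group starts with an 'AA' row. If you want NA, to match the '-' in the expected output, one option is to replace NaN afterwards. A base-R sketch on group 1 of the sample data:

```r
# Group 1 of the sample data
a <- c("AA", "AB", "AA")
b <- c(2, 5, 1)

relevance <- 0 + (a != "AA")                    # logical -> numeric: 0 1 0
m <- cumsum(relevance * b) / cumsum(relevance)  # NaN 5 5 (0/0 is NaN)
m[is.nan(m)] <- NA                              # NA 5 5
```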
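Here is an alternative approach. We group by 'id' and create a new column 'Mean' by first converting the elements of 'b' that correspond to 'AA' in 'a' to NA with b*NA^(a=='AA'). NA^(a=='AA') gives NA for 'AA' in 'a' and 1 for all other values, so multiplying by 'b' keeps the values of 'b' where the mask is 1 and leaves NA elsewhere. We then use na.aggregate to replace each NA with the mean of the non-NA elements in the group, and wrap that in cummean to get the cumulative mean. If the first value of 'a' in a group is 'AA', we can set that row to NA by multiplying with NA^(row_number()==1 & a=='AA').
Note that na.aggregate fills every NA with the mean of the whole group, not the running mean up to that row. The result therefore equals the cumulative mean when the 'AA' rows come after the group's other rows (as in group 3), but an 'AA' row that sits before later non-'AA' rows can shift the means that follow; the cumsum-based approach above does not have this limitation.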
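The two building blocks of this answer can be seen in base R on group 3 of the sample data. R defines x^0 as 1 even for NA, so NA^FALSE is 1 while NA^TRUE is NA. The NA-filling line is a hand-written stand-in for zoo::na.aggregate with its default mean, and cumsum(x) / seq_along(x) stands in for cummean.

```r
# Group 3 of the sample data
a <- c("AB", "AB", "AA", "AA")
b <- c(4, 3, 1, 4)

mask <- NA^(a == "AA")                 # 1 1 NA NA
x <- b * mask                          # 4 3 NA NA: 'AA' values blanked out
x[is.na(x)] <- mean(x, na.rm = TRUE)   # fill NAs with the mean of the rest: 4 3 3.5 3.5
running <- cumsum(x) / seq_along(x)    # 4 3.5 3.5 3.5
```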
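library(zoo)
library(dplyr)
df %>%
  group_by(id) %>%
  # NA^(a=='AA') is NA for 'AA' rows and 1 otherwise; na.aggregate then
  # fills those NAs with the mean of the group's other 'b' values
  mutate(Mean = cummean(na.aggregate(b*NA^(a=='AA'))) *
           # NA for a group's first row when that row is 'AA'
           NA^(row_number()==1 & a=='AA'))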
# Source: local data frame [9 x 4]
#Groups: id [3]
# id a b Mean
# (int) (chr) (int) (dbl)
#1 1 AA 2 NA
#2 1 AB 5 5.0
#3 1 AA 1 5.0
#4 2 AB 2 2.0
#5 2 AB 4 3.0
#6 3 AB 4 4.0
#7 3 AB 3 3.5
#8 3 AA 1 3.5
#9 3 AA 4 3.5
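A small hypothetical group (values invented for illustration, not taken from the question) shows where filling NAs with the group mean and then taking a cumulative mean can differ from a running mean that simply skips 'AA' rows. Both computations are base-R stand-ins: the first for the cumsum/count approach, the second for na.aggregate followed by cummean.

```r
# Hypothetical group: an 'AA' row sits between two non-'AA' rows
a <- c("AB", "AA", "AB")
b <- c(1, 10, 5)

# Running mean of the non-'AA' values seen so far (cumsum / count)
keep  <- as.numeric(a != "AA")
exact <- cumsum(b * keep) / cumsum(keep)   # 1 1 3

# Fill 'AA' rows with the group mean of the others, then take the cumulative mean
x <- b * NA^(a == "AA")                    # 1 NA 5
x[is.na(x)] <- mean(x, na.rm = TRUE)       # 1 3 5
filled <- cumsum(x) / seq_along(x)         # 1 2 3
```

Row 2 differs: the 'AA' row is filled with 3, the mean of the whole group, even though only the value 1 has been seen at that point.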
Data:
df <- structure(list(id = c(1L, 1L, 1L, 2L, 2L, 3L, 3L, 3L, 3L),
                     a = c("AA", "AB", "AA", "AB", "AB", "AB", "AB", "AA", "AA"),
                     b = c(2L, 5L, 1L, 2L, 4L, 4L, 3L, 1L, 4L)),
                .Names = c("id", "a", "b"),
                class = "data.frame", row.names = c(NA, -9L))