Find nearest matches for each row and sum based on

Posted 2020-02-26 11:39

Question:

Consider the following data.table of events:

library(data.table)
breaks <- data.table(id = 1:8,
                     Channel = c("NP1", "NP1", "NP2", "NP2", "NP3", "NP3", "AT4", "AT4"),
                     Time = c(1000, 1100, 975, 1075, 1010, 1080, 1000, 1050),
                     Day = c(1, 1, 1, 1, 1, 1, 1, 1),
                     ZA = c(15, 12, 4, 2, 1, 2, 23, 18),
                     stringsAsFactors = F)

breaks
   id Channel Time Day ZA
1:  1     NP1 1000   1 15
2:  2     NP1 1100   1 12
3:  3     NP2  975   1  4
4:  4     NP2 1075   1  2
5:  5     NP3 1010   1  1
6:  6     NP3 1080   1  2
7:  7     AT4 1000   1 23
8:  8     AT4 1050   1 18

For each unique event in breaks I want to find the nearest events in all other channels using the Time variable where Day == Day and then sum the values of ZA for these events.

This is the result I want to achieve:

   id Channel Time Day ZA Sum
1:  1     NP1 1000   1 15  28
2:  2     NP1 1100   1 12  22
3:  3     NP2  975   1  4  39
4:  4     NP2 1075   1  2  32
5:  5     NP3 1010   1  1  42
6:  6     NP3 1080   1  2  32
7:  7     AT4 1000   1 23  20
8:  8     AT4 1050   1 18  19

So for the first row the channel is NP1. The closest events in all other channels to Time = 1000 are rows 3, 5 and 7, so the sum is 4 + 1 + 23 = 28.
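
For instance, that first row can be checked by hand with a quick filter (a small sanity check, not part of the solution itself):

# for every other channel on Day 1, take the row whose Time is closest to 1000
# and sum the ZA values
breaks[Channel != "NP1" & Day == 1,
       .SD[which.min(abs(Time - 1000))],
       by = Channel][, sum(ZA)]
# [1] 28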

I got this to work using data.table with the following code:

breaks[breaks[, c("Day", "Time", "Channel", "ZA")], on = "Day", allow.cartesian = TRUE][
  Channel != i.Channel][
    order(id)][
      , delta := abs(Time - i.Time)][
        , .SD[delta == min(delta)], by = .(Channel, Time, Day, i.Channel)][
          , unique(.SD, by = c("id", "i.Channel"))][
            , .(Sum = sum(i.ZA)), by = .(id, Channel, Time, Day, ZA)]

However, this creates a dataset with 64 rows (the 8 × 8 cartesian join on Day) in the first step, and I'd like to do this with a dataset of more than a million rows.

Can anyone help me find a more efficient way of doing this?

Edit:

I tried out the solutions of G. Grothendieck (sqldf), eddi (data.table) and MarkusN (dplyr) once on the full dataset of 1.4 million rows with 39 different channels. The dataset was in-memory.

sqldf:      54 minutes
data.table: 11 hours
dplyr:      29 hours

Answer 1:

In the inner select, self-join each row of breaks to the rows on the same day in a different channel, and among all the rows joined to a particular original row keep only the one with the minimum absolute time difference. In the outer select, sum the ZA values from the other channels within each id, giving the result.

Note that we are assuming the default SQLite backend to sqldf here and are using a feature specific to that database, namely that if min is used in a select, then the other columns specified in that select are also populated from the minimizing row.
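
To see that SQLite behaviour in isolation (a small demo against the same in-memory SQLite backend, not part of the solution):

library(sqldf)

# for each Channel, the bare Time column is filled in from the row that attains
# the minimum, i.e. the event closest to Time = 1000
sqldf("select Channel, Time, min(abs(Time - 1000)) diff from breaks group by Channel")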

By default it uses an in-memory database, which is best if the data fits, but if you pass dbname = tempfile() as an argument to sqldf it will use a file-backed database instead. It would also be possible to add one or more indexes, which may or may not speed it up. See the sqldf GitHub home page for more examples.

library(sqldf)

sqldf("select id, Channel, Time, Day, ZA, sum(bZA) Sum
 from (
   select a.*, b.ZA bZA, min(abs(a.Time - b.Time))
   from breaks a join breaks b on a.Day = b.Day and a.Channel != b.Channel
   group by a.id, b.Channel)
 group by id")

giving:

  id Channel Time Day ZA Sum
1  1     NP1 1000   1 15  28
2  2     NP1 1100   1 12  22
3  3     NP2  975   1  4  39
4  4     NP2 1075   1  2  32
5  5     NP3 1010   1  1  42
6  6     NP3 1080   1  2  32
7  7     AT4 1000   1 23  20
8  8     AT4 1050   1 18  19

This is slightly faster than the data.table code in the question on a problem of this size, but for larger problems the comparison would have to be redone.

Also, it may be able to handle larger sizes because it does not have to materialize the intermediate results (depending on the query optimizer) and because it can work out of memory (if need be).
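
For reference, a hedged sketch of those two options together (the index name and the use of a vector of SQL statements with the main. prefix follow the examples on the sqldf GitHub page; adjust as needed):

library(sqldf)

# same query as above, but run against a file-backed database with an index on
# the join columns; only the result of the last statement is returned
sqldf(c("create index breaks_idx on breaks(Day, Channel)",
        "select id, Channel, Time, Day, ZA, sum(bZA) Sum
         from (
           select a.*, b.ZA bZA, min(abs(a.Time - b.Time))
           from main.breaks a join main.breaks b
             on a.Day = b.Day and a.Channel != b.Channel
           group by a.id, b.Channel)
         group by id"),
      dbname = tempfile())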

library(data.table)
library(dplyr)
library(sqldf)
library(rbenchmark)

benchmark(sqldf = 
sqldf("select id, Channel, Time, Day, ZA, sum(bZA) Sum
 from (
   select a.*, b.ZA bZA, min(abs(a.Time - b.Time))
   from breaks a join breaks b on a.Day = b.Day and a.Channel != b.Channel
   group by a.id, b.Channel)
 group by id"),

data.table = breaks[breaks[, c("Day", "Time", "Channel", "ZA")], on = "Day",
     allow.cartesian = TRUE][
  Channel != i.Channel][
    order(id)][
      , delta := abs(Time - i.Time)][
        , .SD[delta == min(delta)], by = .(Channel, Time, Day, i.Channel)][
          , unique(.SD, by = c("id", "i.Channel"))][
            , .(Sum = sum(i.ZA)), by = .(id, Channel, Time, Day, ZA)],

dplyr = { breaks %>% 
  inner_join(breaks, by=c("Day"), suffix=c("",".y")) %>%
  filter(Channel != Channel.y) %>%
  group_by(id, Channel, Time, Day, ZA, Channel.y) %>%
  arrange(abs(Time - Time.y)) %>%
  filter(row_number()==1) %>%
  group_by(id, Channel, Time, Day, ZA) %>%
  summarise(Sum=sum(ZA.y)) %>%                           
  ungroup() %>% 
  select(id:Sum) },

order = "elapsed")[1:4]

giving:

        test replications elapsed relative
1      sqldf          100    3.38    1.000
2 data.table          100    4.05    1.198
3      dplyr          100    9.23    2.731


Answer 2:

I'm not sure about the speed of this (probably slow), but it will be very conservative memory-wise:

Channels = breaks[, unique(Channel)]
breaks[, Sum := breaks[breaks[row,
                              .(Day, Channel = setdiff(Channels, Channel), Time)],
                       on = .(Day, Channel, Time), roll = 'nearest',
                       sum(ZA)]
       , by = .(row = 1:nrow(breaks))]

It'll probably help the speed to setkey(breaks, Day, Channel, Time) instead of using on.
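
A keyed sketch of the same idea (untested; with breaks keyed on Day, Channel, Time the join uses the key columns, and roll = 'nearest' rolls on Time, the last key column):

setkey(breaks, Day, Channel, Time)
Channels = breaks[, unique(Channel)]
breaks[, Sum := breaks[breaks[row,
                              .(Day, Channel = setdiff(Channels, Channel), Time)],
                       roll = 'nearest',
                       sum(ZA)]
       , by = .(row = 1:nrow(breaks))]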



Answer 3:

Here's a solution using dplyr and a self-join:

library(dplyr)
breaks %>% 
  inner_join(breaks, by=c("Day"), suffix=c("",".y")) %>%  # self-join
  filter(Channel != Channel.y) %>%                        # ignore events of same channel
  group_by(id, Channel, Time, Day, ZA, Channel.y) %>%     # build group for every event
  arrange(abs(Time - Time.y)) %>%                         # sort by minimal time-diff
  filter(row_number()==1) %>%                             # keep just row with minimal time-diff
  group_by(id, Channel, Time, Day, ZA) %>%                # group by all columns of original event
  summarise(Sum=sum(ZA.y)) %>%                            # sum ZA of other channels
  ungroup() %>% 
  select(id:Sum)

Maybe I have to be more specific about my answer. Unlike data.table, dplyr has the ability to translate code into SQL. So if your data is stored in a database, you can connect directly to the table containing your data. All (or most of the) dplyr code is evaluated in your DBMS. Since performing joins is a key task of every DBMS, you don't have to worry about performance.
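
For example, a minimal sketch of the database route (the SQLite connection, the table name "breaks" and the use of show_query() are just for illustration; dbplyr must be installed for tbl() to build the query lazily):

library(DBI)
library(dplyr)

con <- dbConnect(RSQLite::SQLite(), ":memory:")
dbWriteTable(con, "breaks", breaks)

# the same self-join, built lazily against the database; show_query() prints the
# SQL that would be sent to the DBMS instead of pulling the data into R
tbl(con, "breaks") %>%
  inner_join(tbl(con, "breaks"), by = "Day", suffix = c("", ".y")) %>%
  filter(Channel != Channel.y) %>%
  show_query()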

However, if your data is imported into R and you are worried about RAM limits, you have to iterate over every row of the data frame. This can be accomplished with dplyr as well:

library(dplyr)
breaks %>% 
rowwise() %>% 
do({
  row = as_data_frame(.)
  df =
    breaks %>%
    filter(Day == row$Day & Channel != row$Channel) %>% 
    mutate(time_diff = abs(Time-row$Time)) %>% 
    group_by(Channel) %>% 
    arrange(abs(Time-row$Time), .by_group=TRUE) %>% 
    filter(row_number()==1) %>% 
    ungroup() %>% summarise(sum(ZA))

  row %>% mutate(sumZA = df[[1]])
})


Answer 4:

Came across this and saw the timings in the OP's edit, hence proposing a possible Rcpp approach:

library(Rcpp)
#library(inline)
nearsum <- cppFunction('
NumericVector closestSum(NumericVector cid, NumericVector Time, NumericVector ZA) {
    int d, mintime, mintimeZA, prevChannel = 0, nextChannel = 0;
    int sz = cid.size();
    NumericVector sumvec(sz);

    for (int r = 0; r < sz; r++) {
        sumvec[r] = 0;
        mintime = 10000;
        //Rcpp::Rcout << "Beginning row = " << r << std::endl;

        for (int i = 0; i < sz; i++) {
            if (cid[r] != cid[i]) {
                //Rcpp::Rcout << "Current idx = " << i << std::endl;

                //handle boundary conditions
                if (i == 0) {
                    prevChannel = 0;    
                } else {
                    prevChannel = cid[i-1];
                }

                if (i == sz - 1) {
                    nextChannel = 0;    
                } else {
                    nextChannel = cid[i+1];
                }

                //calculate time difference
                d = abs(Time[i] - Time[r]);

                if (cid[i] != prevChannel) {
                    ///this is a new channel
                    mintime = d;
                    mintimeZA = ZA[i];
                } else {
                    if (d < mintime) {
                        //this is a new min in time diff
                        mintime = d;
                        mintimeZA = ZA[i];
                    }
                }

                //Rcpp::Rcout << "Time difference = " << d << std::endl;
                //Rcpp::Rcout << "ZA for current min time gap = " << mintimeZA << std::endl;

                if (cid[i] != nextChannel) {
                    //this is the last data point for this channel
                    mintime = 10000;
                    sumvec[r] += mintimeZA;
                    //Rcpp::Rcout << "Final sum for current row = " << sumvec[r] << std::endl;
                }
            }
        }
    }
    return sumvec;
}
')

Calling the cpp function:

library(data.table)
setorder(breaks, id, Channel, Day, Time)
breaks[, ChannelID := .GRP, by=Channel]
breaks[, Sum := nearsum(ChannelID, Time, ZA), by=.(Day)]

output:

   id Channel Time Day ZA ChannelID Sum
1:  1     NP1 1000   1 15         1  28
2:  2     NP1 1100   1 12         1  22
3:  3     NP2  975   1  4         2  39
4:  4     NP2 1075   1  2         2  32
5:  5     NP3 1010   1  1         3  42
6:  6     NP3 1080   1  2         3  32
7:  7     AT4 1000   1 23         4  20
8:  8     AT4 1050   1 18         4  19

timing code:

#create a larger dataset
largeBreaks <- rbindlist(lapply(1:1e5, function(n) copy(breaks)[, Day := n]))
setorder(largeBreaks, Day, Channel, Time)
largeBreaks[, id := .I]

library(sqldf)
mtd0 <- function() {
    sqldf("select id, Channel, Time, Day, ZA, sum(bZA) Sum
     from (
       select a.*, b.ZA bZA, min(abs(a.Time - b.Time))
       from largeBreaks a join largeBreaks b on a.Day = b.Day and a.Channel != b.Channel
       group by a.id, b.Channel)
     group by id")
}

mtd1 <- function() {
    setorder(largeBreaks, Day, Channel, Time)
    largeBreaks[, ChannelID := .GRP, by=Channel]
    largeBreaks[, Sum := nearsum(ChannelID, Time, ZA), by=.(Day)]
}

library(microbenchmark)
microbenchmark(mtd0(), mtd1(), times=3L)

Timings (add around 5 seconds, at least on my machine, to compile the C++ function):

Unit: milliseconds
   expr        min         lq       mean    median         uq        max neval
 mtd0() 10449.6696 10505.7669 10661.7734 10561.864 10767.8252 10973.7863     3
 mtd1()   365.4157   371.2594   386.6866   377.103   397.3221   417.5412     3