R - fastest way to get the indices of the max n elements

2019-04-20 13:28发布

问题:

This question already has an answer here:

  • How to find the indices of the top 10,000 elements in a symmetric matrix (12k x 12k) in R (4 answers)

Suppose I have a huge vector x with 1 million elements, and I would like to find the indices of the maximum 30 elements. I don't particularly care whether the results are sorted among these 30 elements, as long as they are the max 30 out of the entire vector. Using order(x, decreasing = TRUE)[1:30] seems quite expensive since it has to sort the entire vector. I thought about using the partial option in sort, but sort returns only the values, and its index.return option is not supported when partial is specified. Is there an efficient way to find the indices without sorting the entire vector?

回答1:

I want to add a hybrid approach using sort's partial argument and which:

whichpart <- function(x, n = 30) {
  # Indices of the n largest elements of `x`, found with a partial sort
  # around the (n+1)-th largest value instead of a full sort. Indices are
  # returned in ascending position order (as from which()), not ranked by
  # value. NA elements are never returned.
  #
  # sort() drops NAs, so ranks must be counted among the non-NA values only.
  nx <- if (anyNA(x)) sum(!is.na(x)) else length(x)
  if (n <= 0 || nx == 0) {
    return(integer(0))
  }
  if (n >= nx) {
    # Every non-NA element qualifies; `partial` would otherwise be < 1.
    return(which(!is.na(x)))
  }
  p <- nx - n
  xp <- sort(x, partial = p)[p]
  res <- which(x > xp)
  # Ties at the cut-off leave fewer than n values strictly above xp; top up
  # with the first elements equal to xp so exactly n indices come back.
  short <- n - length(res)
  if (short > 0) {
    res <- sort(c(res, which(x == xp)[seq_len(short)]))
  }
  res
}

Some benchmarking:

library("microbenchmark")
library("data.table")
library("compiler")

# Reproducible benchmark inputs: a million standard-normal doubles (x) and a
# random permutation of 1..1e6 stored as integers (y).
set.seed(seed = 123)
x <- rnorm(n = 1e6, mean = 0, sd = 1)
y <- sample.int(1e6, size = 1e6)


whichpart <- function(x, n = 30) {
  # Indices of the n largest elements of `x`, found with a partial sort
  # around the (n+1)-th largest value instead of a full sort. Indices are
  # returned in ascending position order (as from which()), not ranked by
  # value. NA elements are never returned.
  #
  # sort() drops NAs, so ranks must be counted among the non-NA values only.
  nx <- if (anyNA(x)) sum(!is.na(x)) else length(x)
  if (n <= 0 || nx == 0) {
    return(integer(0))
  }
  if (n >= nx) {
    # Every non-NA element qualifies; `partial` would otherwise be < 1.
    return(which(!is.na(x)))
  }
  p <- nx - n
  xp <- sort(x, partial = p)[p]
  res <- which(x > xp)
  # Ties at the cut-off leave fewer than n values strictly above xp; top up
  # with the first elements equal to xp so exactly n indices come back.
  short <- n - length(res)
  if (short > 0) {
    res <- sort(c(res, which(x == xp)[seq_len(short)]))
  }
  res
}

cpwhichpart <- cmpfun(whichpart)

# using quicksort
quicksort <- function(x, n = 30) {
  # Indices of the n largest elements via a full quicksort, largest first.
  # With index.return = TRUE, sort() drops NAs first and `ix` refers to
  # positions in the NA-free vector, so those positions are mapped back onto
  # `x` when NAs are present.
  keep <- if (anyNA(x)) which(!is.na(x)) else NULL
  if (!is.null(keep)) {
    x <- x[keep]
  }
  ix <- sort(x, method = "quick", decreasing = TRUE, index.return = TRUE)$ix
  # head() rather than [1:n]: correct for n = 0 and for n > length(x).
  ix <- head(ix, n)
  if (is.null(keep)) ix else keep[ix]
}

cpquicksort <- cmpfun(quicksort)

# @Mariam
whichsort <- function(x, n = 30) {
  # Indices of every element at least as large as the n-th largest value,
  # found via a full descending sort (NAs dropped). Ties at the cut-off can
  # yield more than n indices. The original hard-coded 30 here, ignoring `n`.
  sorted <- sort(x, decreasing = TRUE)
  if (n <= 0 || length(sorted) == 0) {
    return(integer(0))
  }
  # Clamp the rank so n beyond the number of values returns all non-NA
  # indices instead of comparing against NA.
  which(x >= sorted[min(n, length(sorted))], arr.ind = TRUE)
}

cpwhichsort <- cmpfun(whichsort)

# @Ferdinand.kraft
top <- function(x, n = 30) {
  # Indices (as doubles) of the n largest elements, largest first, found by
  # repeated which.max(). Each winner is blanked out with NA, which
  # which.max() skips. This keeps genuine -Inf values selectable, keeps
  # integer input integer, and never returns an index twice.
  k <- max(0, min(n, sum(!is.na(x))))
  result <- numeric(k)
  for (i in seq_len(k)) {
    j <- which.max(x)
    result[i] <- j
    x[j] <- NA
  }
  result
}

cptop <- cmpfun(top)

# @Tony Breyal
dtable <- function(x, n = 30) {
  # Indices of the n largest elements using a keyed data.table, largest first.
  # setkey() sorts ascending by x, so the largest values are the LAST rows.
  # The original took the first n rows and returned the n smallest. NA rows
  # are filtered out explicitly rather than relying on where setkey() places
  # them. seq_along() is used because seq.int(x) means 1:x for length-1 x.
  dt <- data.table(x = x, x.index = seq_along(x))
  setkey(dt, "x")
  idx <- dt$x.index[!is.na(dt$x)]
  rev(tail(idx, n))
}

cpdtable <- cmpfun(dtable)

# @Roland
roland <- cmpfun(function(x, n = 30) {
  # Returns the n largest VALUES of `x` (ascending), not their indices. It
  # keeps a running buffer `y` sorted ascending, so y[1] is the smallest
  # value kept so far. NAs are skipped, and n is capped at the number of
  # non-NA values.
  n <- min(n, sum(!is.na(x)))
  if (n <= 0) {
    return(numeric(0))
  }
  y <- rep(-Inf, n)
  for (xi in x) {
    # The is.na() guard matters: `if (NA > y[1])` is an error.
    if (!is.na(xi) && xi > y[1]) {
      y[1] <- xi
      y <- y[order(y)]
    }
  }
  y
})

## rnorm
# Time each candidate on the double vector `x`, 10 runs each, both as written
# and byte-compiled with cmpfun() (the `cp*` variants). `roland` has a single
# entry because it is already wrapped in cmpfun() where it is defined.
microbenchmark(whichpart(x), cpwhichpart(x),
               quicksort(x), cpquicksort(x),
               whichsort(x), cpwhichsort(x),
               top(x), cptop(x),
               dtable(x), cpdtable(x),
               roland(x), times=10)

# Unit: milliseconds
#            expr        min         lq     median         uq        max neval
#    whichpart(x)   45.63544   46.05638   47.09077   49.68452   51.42065    10
#  cpwhichpart(x)   45.65996   45.77212   47.02808   48.07482   82.20458    10
#    quicksort(x)  100.90936  103.00783  105.17506  109.31784  139.83518    10
#  cpquicksort(x)  100.53958  102.78017  107.64470  138.96630  142.52882    10
#    whichsort(x)  148.86010  151.04350  155.80871  159.47063  184.56697    10
#  cpwhichsort(x)  149.05578  150.21183  151.36918  166.58342  173.87567    10
#          top(x)  146.10757  182.42089  184.53050  191.37293  193.62272    10
#        cptop(x)  155.14354  179.14847  184.52323  196.80644  220.21222    10
#       dtable(x) 1041.32457 1042.54904 1049.26096 1065.40606 1080.89969    10
#     cpdtable(x) 1042.08247 1043.54915 1051.76366 1084.14360 1310.26485    10
#       roland(x)  251.42885  261.47608  273.20838  295.09733  323.96257    10

## integer
# Same set of candidates, timed on the integer permutation `y` (10 runs each).
microbenchmark(whichpart(y), cpwhichpart(y),
               quicksort(y), cpquicksort(y),
               whichsort(y), cpwhichsort(y),
               top(y), cptop(y),
               dtable(y), cpdtable(y),
               roland(y), times=10)

# Unit: milliseconds
#            expr       min        lq    median        uq       max neval
#    whichpart(y)  11.60703  11.76857  12.03704  12.52871  47.88526    10
#  cpwhichpart(y)  11.62885  11.75006  12.53724  13.88563  46.93677    10
#    quicksort(y)  88.14924  89.47630  92.42414 103.53439 137.44335    10
#  cpquicksort(y)  88.11544  89.15334  92.63420  94.42244 133.78006    10
#    whichsort(y) 122.34675 123.13634 124.91990 127.79134 131.43400    10
#  cpwhichsort(y) 121.85618 122.91653 125.45211 127.14112 158.61535    10
#          top(y) 163.06669 181.19004 211.11557 224.19237 239.63139    10
#        cptop(y) 163.37903 173.55113 209.46770 218.59685 226.81545    10
#       dtable(y) 499.50807 505.45513 514.55338 537.84129 604.86454    10
#     cpdtable(y) 491.70016 498.62664 525.05342 527.14666 580.19429    10
#       roland(y) 235.44664 237.52200 242.87925 268.34080 287.71196    10


# Sanity check: the partial-sort approach and the full quicksort pick the same
# index set once the quicksort result is put into ascending order.
identical(whichpart(x), sort(quicksort(x)))
# [1] TRUE

Edit: testing @flodel's suggestion:

# @flodel
whichpartrev <- function(x, n = 30) {
  # Indices of the n largest elements, in ascending position order. The
  # threshold (the n-th largest value) comes from partially sorting -x at
  # rank n. Ranks count non-NA values only, because sort() drops NAs.
  nx <- if (anyNA(x)) sum(!is.na(x)) else length(x)
  if (n <= 0 || nx == 0) {
    return(integer(0))
  }
  if (n >= nx) {
    # Every non-NA element qualifies; `partial` would be out of range.
    return(which(!is.na(x)))
  }
  thr <- -sort(-x, partial = n)[n]
  res <- which(x >= thr)
  # Ties at the threshold can select more than n; drop the surplus tied
  # indices from the end so exactly n remain.
  extra <- length(res) - n
  if (extra > 0) {
    tied <- res[x[res] == thr]
    res <- res[!res %in% tail(tied, extra)]
  }
  res
}

# Compare the original partial-sort approach with @flodel's variant on `x`
# (100 runs each).
microbenchmark(whichpart(x), whichpartrev(x), times=100)

# Unit: milliseconds
#             expr      min       lq   median       uq      max neval
#     whichpart(x) 45.44940 46.15011 46.51321 48.67986 80.63286   100
#  whichpartrev(x) 28.84482 31.30661 32.87695 62.37843 67.84757   100

# Compare whichpart() with whichpartrev() on the integer permutation `y`
# (100 runs each).
microbenchmark(whichpart(y), whichpartrev(y), times=100)

# Unit: milliseconds
#             expr      min       lq   median       uq      max neval
#     whichpart(y) 11.56135 12.26539 13.05729 13.75199 43.78484   100
#  whichpartrev(y) 16.00612 16.73690 17.71687 19.04153 49.02842   100


回答2:

vec <- runif(1000000)
# Keep every index whose value is at least the 30th largest. TRUE is spelled
# out because `T` can be reassigned. arr.ind is dropped because it only has an
# effect when `vec` has a dim attribute.
index <- which(vec >= sort(vec, decreasing = TRUE)[30])
vec[index]


回答3:

I'm not sure how you'd avoid sorting. I think ?which.max might help.

Anyway, I'd do something like the following:

# library() fails loudly if data.table is missing; require() would only warn
# and let the next line fail with a less helpful error.
library(data.table)
set.seed(42)
n <- 1000000
x <- rnorm(n)
dt <- data.table(x = x, x.index = seq_len(n))
# setkey() sorts the table ascending by x, so the 30 largest values are the
# last 30 rows.
setkey(dt, "x")
# NOTE(review): the printed output below shows values just under 1. That looks
# like it was produced with runif(n) rather than rnorm(n); the largest of a
# million standard-normal draws lie far above 1.
tail(dt, 30)

#            x x.index
# 1: 0.9999712  270177
# 2: 0.9999715  521060
# 3: 0.9999723  863876
# 4: 0.9999757  622734
# 5: 0.9999761   48337
# 6: 0.9999764  699984
# 7: 0.9999766  264473
# 8: 0.9999770  212981
# 9: 0.9999782  911943
# 10: 0.9999874  330250
# 11: 0.9999876  695213
# 12: 0.9999879  219101
# 13: 0.9999880  144000
# 14: 0.9999880  459676
# 15: 0.9999887  910525
# 16: 0.9999894  902172
# 17: 0.9999900  474633
# 18: 0.9999905  360481
# 19: 0.9999920  985058
# 20: 0.9999925   17169
# 21: 0.9999926  424703
# 22: 0.9999927  448196
# 23: 0.9999929  254084
# 24: 0.9999932  468090
# 25: 0.9999940  480390
# 26: 0.9999961  765489
# 27: 0.9999966  556407
# 28: 0.9999968  860100
# 29: 0.9999982  879843
# 30: 0.9999989  507889


回答4:

I've managed to save a couple of seconds on vectors of length 10^7 or more with this simple function (note that it returns the top values themselves, not their indices):

top <- function(x, n = 30) {
  # Returns the n largest VALUES of `x` (largest first), not their indices.
  # Each winner is blanked out with NA, which which.max() skips. n is capped
  # at the number of non-NA values, so results are never padded with -Inf
  # and n = 0 gives an empty result (1:0 would loop twice).
  k <- max(0, min(n, sum(!is.na(x))))
  result <- numeric(k)
  for (i in seq_len(k)) {
    j <- which.max(x)
    result[i] <- x[j]
    x[j] <- NA
  }
  result
}

My results:

> x <- runif(1e7)
> system.time(y <- sort(x,decreasing=TRUE)[1:30])
   user  system elapsed 
   3.30    0.04    3.39 
> system.time(z <- top(x))
   user  system elapsed 
   2.49    0.58    3.12 
> x <- runif(1e8)
> system.time(y <- sort(x,decreasing=TRUE)[1:30])
   user  system elapsed 
  41.74    1.15   43.62 
> system.time(z <- top(x))
   user  system elapsed 
  25.96    7.61   34.43 


回答5:

# Reproducible input: a million standard-normal draws.
set.seed(seed = 42)
x <- rnorm(n = 1e6, mean = 0, sd = 1)

# The n largest values of `x`, largest first, via a full order(). NAs are
# dropped by na.last = NA. head() avoids padding with NA when fewer than n
# values exist. n defaults to the 30 that used to be hard-coded.
fun1 <- function(x, n = 30) {
  head(x[order(x, decreasing = TRUE, na.last = NA)], n)
}

library(compiler)
fun2 <- cmpfun(function(x, n = 30) {
  # Returns the n largest VALUES of `x` in ascending order. It keeps a running
  # buffer `y` sorted ascending, so y[1] is the smallest value kept so far.
  # n defaults to the 30 that used to be hard-coded. It is capped at the
  # number of non-NA values.
  n <- min(n, sum(!is.na(x)))
  if (n <= 0) {
    return(numeric(0))
  }
  y <- rep(-Inf, n)
  for (xi in x) {
    # The is.na() guard matters: `if (NA > y[1])` is an error.
    if (!is.na(xi) && xi > y[1]) {
      y[1] <- xi
      y <- y[order(y)]
    }
  }
  y
})

library(microbenchmark)
# Time the full order() approach (fun1) against the compiled running buffer
# (fun2), 5 runs each. The assignments inside the timed expressions are what
# make y1 and y2 available to the identical() check that follows.
microbenchmark(
y1 <- fun1(x),
y2 <- fun2(x),
times=5)
#Unit: milliseconds
#         expr      min       lq   median       uq      max neval
#y1 <- fun1(x) 400.1574 411.8172 418.7872 425.8027 426.2981     5
#y2 <- fun2(x) 255.7817 258.2374 258.8088 259.4630 290.6068     5


# Both approaches should keep the same 30 values, regardless of order.
identical(sort(y2), sort(y1))
#[1] TRUE


回答6:

I would sort in descending order and then just take the first n (note that this gives the values, not their indices):

 x <- rnorm(1000000, 10, 5)
 # Sort in decreasing order and print the n largest values. This yields the
 # values themselves; order(x, decreasing = TRUE)[seq_len(n)] on the unsorted
 # vector would give their indices instead.
 x <- sort(x, decreasing = TRUE)
 n <- 30
 print(head(x, n))


标签: r sorting vector