Efficient way of calculating quadratic forms: avoid for loops?

Posted 2019-08-02 22:47

Question:

I want to calculate N quadratic forms x_i' A x_i, one for each column x_i of a matrix x, where N is large. I am using the function 'quad.form' from the R package 'emulator'. How can I do this without a for loop?

So far, I am using

library(emulator)

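A <- matrix(1, ncol = 5, nrow = 5)    # the matrix of the quadratic form

x <- matrix(1:25, ncol = 5, nrow = 5) # each column x[, i] is one vector of interest

# for loop: QF[i] = t(x[, i]) %*% A %*% x[, i]
QF <- numeric(ncol(x))
for (i in seq_len(ncol(x))) {
  QF[i] <- quad.form(A, x[, i])
}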

Is there a more direct and efficient way to calculate these quadratic forms?

Interestingly,

quad.form(A,x)

is about 10 times faster than the for loop, but it returns the full N x N matrix t(x) %*% A %*% x, and I only need its diagonal. Computing all N^2 entries to keep only N of them still seems an inefficient way of getting the N quadratic forms.
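To make the "only the diagonal" point concrete: for real-valued inputs, quad.form(A, x) with a matrix x computes t(x) %*% A %*% x (the package documents quad.form as the quadratic form x' M x), so entry [i, i] is exactly the i-th quadratic form. The base-R sketch below assumes that equivalence and does not load emulator; the names loop_qf and full_qf are illustrative only. It checks that the diagonal of the full matrix matches the column-by-column loop.

```r
# Same toy data as in the question
A <- matrix(1, ncol = 5, nrow = 5)
x <- matrix(1:25, ncol = 5, nrow = 5)

# Loop: one quadratic form t(x[, i]) A x[, i] per column
loop_qf <- numeric(ncol(x))
for (i in seq_len(ncol(x))) {
  loop_qf[i] <- drop(t(x[, i]) %*% A %*% x[, i])
}

# Full N x N matrix t(x) A x; only its diagonal is needed
full_qf <- t(x) %*% A %*% x
stopifnot(isTRUE(all.equal(diag(full_qf), loop_qf)))

loop_qf  # returns c(225, 1600, 4225, 8100, 13225)
```

Because A is all ones here, each quadratic form is simply the squared column sum (15^2, 40^2, ...), which makes the toy example easy to verify by hand.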

Answer 1:

How about

colSums(x * (A %*% x))

? It gets the right answer for this example at least, and should be much faster.
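A short derivation of why colSums(x * (A %*% x)) gives the diagonal for any square A, not just this example. Here X is the n x N matrix whose columns are x_1, ..., x_N, and \odot is the elementwise product that R's * computes:

```latex
\operatorname{diag}(X^\top A X)_i
  = x_i^\top A x_i
  = \sum_{j=1}^{n} X_{ji}\,(AX)_{ji}
  = \Bigl[\mathbf{1}^\top \bigl(X \odot (AX)\bigr)\Bigr]_i ,
  \qquad i = 1, \dots, N
```

So the N quadratic forms are the column sums of X \odot (AX). Forming AX costs O(n^2 N); the elementwise product and colSums add only O(nN), whereas diag(t(x) %*% A %*% x) additionally pays O(nN^2) to build all N^2 entries of the full matrix before discarding the off-diagonal ones.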

library("rbenchmark")
A <- matrix(1, ncol=500, nrow=500)
x <- matrix(1:25, ncol=500, nrow=500)

library("emulator")
aa <- function(A,x) apply(x, 2, function (y) quad.form(A,y))  # one quad.form() call per column
cs <- function(A,x) colSums(x * (A %*% x))                     # column sums of x * (A x)
dq <- function(A,x) diag(quad.form(A,x))                       # diagonal of the full N x N matrix
all.equal(cs(A,x),dq(A,x))  ## TRUE
all.equal(cs(A,x),aa(A,x))  ## TRUE
benchmark(aa(A,x),
          cs(A,x),
          dq(A,x))
##       test replications elapsed relative user.self sys.self
## 1 aa(A, x)          100  13.121    1.346    13.085    0.024
## 2 cs(A, x)          100   9.746    1.000     9.521    0.224
## 3 dq(A, x)          100  26.369    2.706    25.773    0.592
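The equality checks above use an all-ones A, which is symmetric and makes every column's result a squared column sum. A base-R sketch that repeats the comparison on random data with a general, non-symmetric A; the helper names loop_qf, cs_qf and dq_qf and the sizes n = 50, N = 200 are illustrative assumptions, and crossprod(x, A %*% x) stands in for quad.form(A, x) so that no package is needed:

```r
set.seed(42)                      # reproducible random test data
n <- 50; N <- 200
A <- matrix(rnorm(n * n), n, n)   # general, non-symmetric matrix
x <- matrix(rnorm(n * N), n, N)   # N column vectors of length n

# One quadratic form per column, computed explicitly
loop_qf <- function(A, x) {
  vapply(seq_len(ncol(x)),
         function(i) sum(x[, i] * (A %*% x[, i])),
         numeric(1))
}
# Column sums of x * (A x): the approach from this answer
cs_qf <- function(A, x) colSums(x * (A %*% x))
# Diagonal of the full N x N matrix t(x) A x
dq_qf <- function(A, x) diag(crossprod(x, A %*% x))

stopifnot(isTRUE(all.equal(cs_qf(A, x), loop_qf(A, x))),
          isTRUE(all.equal(cs_qf(A, x), dq_qf(A, x))))
```

All three agree for non-symmetric A, as the derivation predicts: colSums(x * (A %*% x)) never assumes A = t(A).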


Answer 2:

Use the apply function:

apply(x, 2, function (y) quad.form(A,y))
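Note that apply() still loops over the columns internally (it is a convenience wrapper around an R-level loop, not vectorization), so any gain comes from skipping the full N x N product rather than from removing the loop. A dependency-free sketch of the same column-wise idea, assuming real-valued A and x; vapply() is swapped in for apply() because it fixes the result to one number per column, and the helper name qf_cols is hypothetical:

```r
# Hypothetical helper: quadratic form t(x[, i]) A x[, i] for every column i
qf_cols <- function(A, x) {
  vapply(seq_len(ncol(x)),
         function(i) drop(crossprod(x[, i], A %*% x[, i])),
         numeric(1))
}

A <- matrix(1, ncol = 5, nrow = 5)
x <- matrix(1:25, ncol = 5, nrow = 5)
qf_cols(A, x)  # returns c(225, 1600, 4225, 8100, 13225)
```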

If you make the matrices larger (500 x 500), it becomes clear that apply is noticeably faster (here about 1.7 times) than computing the full matrix with quad.form(A,x):

A <- matrix(1, ncol=500, nrow=500)
x <- matrix(1:25, ncol=500, nrow=500)

system.time(apply(x, 2, function (y) quad.form(A,y)))
# user  system elapsed 
# 0.183   0.000   0.183 

system.time(quad.form(A,x))
# user  system elapsed 
# 0.314   0.000   0.314 

EDIT

And @Ben Bolker's answer takes about a third less time than apply:

system.time(colSums(x * (A %*% x)))
# user  system elapsed 
# 0.123   0.000   0.123 


Tags: r quadratic