Avoiding eval(parse()) in building fractional polynomials

Published 2019-07-03 11:35

Question:

My goal is to write a function in R that accepts coefficients for a fractional polynomial (FP) and returns a vectorized function which evaluates the specified FP for given input numbers. The FP definition has two important rules:

  • x^0 is defined as log(x)
  • powers can have multiple coefficients, where the 2nd coefficient for power p adds a factor of log(x) to the additive term (x^p*log(x)), the 3rd adds a factor of log(x)^2 (x^p*log(x)^2), and so on
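In formula form, the two rules give the expression below. The symbols are introduced here only for illustration: $p_1, \dots, p_K$ are the distinct powers, and power $p_k$ carries $m_k$ coefficients $\beta_{k,1}, \dots, \beta_{k,m_k}$.

$$
\mathrm{FP}(x) = \sum_{k=1}^{K} \sum_{j=1}^{m_k} \beta_{k,j}\, x^{(p_k)} (\log x)^{j-1},
\qquad
x^{(p)} = \begin{cases} x^{p}, & p \neq 0,\\ \log x, & p = 0. \end{cases}
$$

In the example further down, power 0.5 carries the two coefficients -13 and 6, so it contributes $-13\,x^{0.5} + 6\,x^{0.5}\log x$.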

My current solution below builds the FP function as a string, parses the string and returns a function that evaluates the resulting expression. My question is whether there is a better or faster way that avoids eval(parse()), possibly using some substitute() magic.

The function must handle a number of coefficients per power that is not known in advance but is specified when the function is called. The final FP evaluation needs to be fast, as it is called very often.

It would be nice not to be limited to the standard powers -2, -1, -0.5, 0, 0.5, 1, 2, 3. Ideally, the desired function would do both steps at once: accept the FP coefficients together with a vector of numbers and return the FP values for that input, while still being fast.

getFP <- function(p_2, p_1, p_0.5, p0, p0.5, p1, p2, p3, ...) {
    p <- as.list(match.call(expand.dots=TRUE)[-1])         # all args
    names(p) <- sub("^p", "", names(p))     # strip "p" from arg names
    names(p) <- sub("_", "-", names(p))     # replace _ by - in arg names

    ## for one power and the i-th coefficient: build string
    getCoefStr <- function(i, pow, coefs) {
        powBT  <- ifelse(as.numeric(pow), paste0("x^(", pow, ")"), "log(x)")
        logFac <- ifelse(i-1,             paste0("*log(x)^", i-1), "")
        paste0("(", coefs[i], ")*", powBT, logFac)
    }

    onePwrStr <- function(pow, p) { # for one power: build string for all coefs
        coefs  <- eval(p[[pow]])
        pwrStr <- sapply(seq(along=coefs), getCoefStr, pow, coefs)
        paste(pwrStr, collapse=" + ")
    }

    allPwrs <- sapply(names(p), onePwrStr, p)  # for each power: build string
    fpExpr  <- parse(text=paste(allPwrs, collapse=" + "))
    function(x) { eval(fpExpr) }
}

An example is -1.5*x^(-1) - 14*log(x) - 13*x^(0.5) + 6*x^0.5*log(x) + 1*x^3, which has the powers (-1, 0, 0.5, 0.5, 3) with coefficients (-1.5, -14, -13, 6, 1).

> fp <- getFP(p_1=-1.5, p0=-14, p0.5=c(-13, 6), p3=1)
> fp(1:3)
[1] -13.50000000 -14.95728798   0.01988127
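As a cross-check of the output above, and as one possible shape for the "two steps at once" variant the question asks for, the FP can also be evaluated directly from parallel vectors of powers and coefficients without building any expression. The helper eval_fp() below is hypothetical (it is not part of the question or of the answer), assumes x > 0 so that log(x) is defined, and borrows the ave() numbering trick that the answer uses for repeated powers.

```r
# Hypothetical helper, not part of the question or the answer: evaluates a
# fractional polynomial directly from parallel vectors of powers and
# coefficients, without building or parsing any expression.
eval_fp <- function(x, powers, coefs) {
  stopifnot(length(powers) == length(coefs), all(x > 0))
  # Repetition index of each entry: 1 for the first coefficient of a power,
  # 2 for its second, and so on (ave() numbers entries within each power).
  reps <- ave(powers, powers, FUN = seq_along)
  lx <- log(x)
  out <- numeric(length(x))
  for (k in seq_along(powers)) {
    # Rule 1: x^0 means log(x)
    base <- if (powers[k] == 0) lx else x^powers[k]
    # Rule 2: the j-th coefficient of a power adds a factor log(x)^(j - 1)
    out <- out + coefs[k] * base * lx^(reps[k] - 1)
  }
  out
}

eval_fp(1:3, powers = c(-1, 0, 0.5, 0.5, 3), coefs = c(-1.5, -14, -13, 6, 1))
```

The loop runs once per term rather than once per input value, so each iteration is a vectorized operation over all of x. For the example it should return the same three values as fp(1:3) above.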

Answer 1:

First, we create a function that generates a single term of the polynomial:

one <- function(p, c = 1, repeated = 1) {
  if (p == 0) {
    lhs <- substitute(c * log(x), list(c = c))
  } else {
    lhs <- substitute(c * x ^ p, list(c = c, p = p))
  }

  if (repeated == 1) return(lhs)
  substitute(lhs * log(x) ^ pow, list(lhs = lhs, pow = repeated - 1))
}
one(0)
# 1 * log(x)
one(2)
# 1 * x^2

one(2, 2)
# 2 * x^2

one(2, repeated = 2)
# 1 * x^2 * log(x)^1
one(2, repeated = 3)
# 1 * x^2 * log(x)^2

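The key tool here is substitute(): given an expression and a list, it replaces each symbol named in the list with the corresponding value and returns the modified expression unevaluated.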
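A minimal illustration of substitute() on its own, separate from one(): the symbols c and p are filled in, x is left alone, and the resulting call is only computed once eval() is given a value for x.

```r
# substitute() swaps the symbols named in the list for their values and
# returns an unevaluated call; nothing is computed yet.
term <- substitute(c * x^p, list(c = 2, p = 3))
print(term)        # 2 * x^3
class(term)        # "call"

# The call is evaluated later, once a value for x is supplied.
eval(term, list(x = 1:3))
# [1]  2 16 54
```

Because the coefficient and power are stored as numeric constants inside the call, no text is ever parsed, which is what lets this approach replace eval(parse()).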

Next, we write a function that adds two terms together. Again, this uses substitute():

add_expr_1 <- function(x, y) {
  substitute(x + y, list(x = x, y = y))
}

add_expr_1(one(0, 1), one(2, 1))
# 1 * log(x) + 1 * x^2

We can use this to make a function to add together any number of terms:

add_expr <- function(x) Reduce(add_expr_1, x)
add_expr(list(one(0, 1), one(1, 1), one(2, 3)))
# 1 * log(x) + 1 * x^1 + 3 * x^2

With these pieces in place, the final function is simple: we work out how many times each power is repeated, then use Map() to call one() once for each (power, coefficient, repetition) triple:

fp <- function(powers, coefs) {
  # Determine the repetition index of each power (1, 2, ... within each
  # power). This is perhaps a too-clever approach, but I think it works
  reps <- ave(powers, powers, FUN = seq_along)

  # Now generate a list of expressions using one
  components <- Map(one, powers, coefs, reps)

  # And combine them together with plus
  add_expr(components)
}
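The reps line is the least obvious step. A small illustration of what ave() returns, using the powers from the example below and a made-up vector in which a repeated power is not adjacent:

```r
# ave(x, g, FUN = seq_along) splits x by the groups in g, applies seq_along()
# within each group and puts the results back in the original order, so each
# entry gets its running count within its power.
powers <- c(-1, 0, 0.5, 0.5, 3)
ave(powers, powers, FUN = seq_along)
# [1] 1 1 1 2 1

# Repeated powers do not need to be next to each other.
ave(c(2, 0, 2, 2), c(2, 0, 2, 2), FUN = seq_along)
# [1] 1 1 2 3
```

So the order in which coefficients of the same power appear decides which one gets the extra log(x) factors.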

powers <- c(-1, 0, 0.5, 0.5, 3)
coefs <-  c(-1.5, -14, -13, 6, 1)
fp(powers, coefs)
# -1.5 * x^-1 + -14 * log(x) + -13 * x^0.5 + 6 * x^0.5 * log(x)^1 + 
#   1 * x^3
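The answer's fp() returns an unevaluated expression, whereas the question asked for a function of x. One way to bridge that gap without any string parsing is to install the expression as the body of a function. The sketch below repeats the answer's helpers so it runs on its own; the wrapper name fp_fun() is hypothetical, and it relies only on base R's body<- and environment<- replacement functions.

```r
# Self-contained copy of the answer's helpers, followed by a hypothetical
# wrapper fp_fun() that installs the generated expression as the body of a
# function of x, so no string is built or parsed at any point.
one <- function(p, c = 1, repeated = 1) {
  if (p == 0) {
    lhs <- substitute(c * log(x), list(c = c))
  } else {
    lhs <- substitute(c * x ^ p, list(c = c, p = p))
  }
  if (repeated == 1) return(lhs)
  substitute(lhs * log(x) ^ pow, list(lhs = lhs, pow = repeated - 1))
}
add_expr_1 <- function(x, y) substitute(x + y, list(x = x, y = y))
add_expr <- function(x) Reduce(add_expr_1, x)
fp <- function(powers, coefs) {
  reps <- ave(powers, powers, FUN = seq_along)
  add_expr(Map(one, powers, coefs, reps))
}

fp_fun <- function(powers, coefs) {
  f <- function(x) NULL               # placeholder body
  body(f) <- fp(powers, coefs)        # replace it with the FP expression
  environment(f) <- baseenv()         # body only needs x and base functions
  f
}

f <- fp_fun(c(-1, 0, 0.5, 0.5, 3), c(-1.5, -14, -13, 6, 1))
f(1:3)
```

Setting the environment to baseenv() is a design choice: the generated body only refers to x and base arithmetic, and this keeps the closure from holding on to fp_fun()'s local variables. f(1:3) should reproduce the values printed by the question's getFP() example.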


Tags: r eval