Optimize a recursive function in Julia

Posted 2019-08-28 00:25

I wrote Julia code that computes integrals over Gaussian functions, and it has a sort of kernel function that is called over and over again. According to Julia's built-in Profile module, this is where most of the time goes during the actual computation, so I would like to see whether there is any way to improve it.

It is a recursive function, and I implemented it in a fairly straightforward way. As I am not very used to recursive functions, maybe somebody out there has ideas or suggestions on how to improve it, whether from a purely algorithmic point of view or by exploiting specific optimizations of the JIT compiler.

Here you have it:

using LinearAlgebra  # provides norm (Julia ≥ 0.7)

"""Returns the integral of a Hermite Gaussian divided by the Coulomb operator."""
function Rtuv(t::Int, u::Int, v::Int, n::Int, p::Real, RPC::Vector{T}) where {T<:Real}
    if t == u == v == 0
        # Base case: scaled Boys function of order n
        return (-2.0*p)^n * boys(n, p*norm(RPC)^2)
    elseif u == v == 0
        # Only t is left: recurse on t using the x component of RPC
        if t > 1
            return (t-1)*Rtuv(t-2, u, v, n+1, p, RPC) +
                   RPC[1]*Rtuv(t-1, u, v, n+1, p, RPC)
        else
            return RPC[1]*Rtuv(t-1, u, v, n+1, p, RPC)
        end
    elseif v == 0
        # v is exhausted: recurse on u using the y component
        if u > 1
            return (u-1)*Rtuv(t, u-2, v, n+1, p, RPC) +
                   RPC[2]*Rtuv(t, u-1, v, n+1, p, RPC)
        else
            return RPC[2]*Rtuv(t, u-1, v, n+1, p, RPC)
        end
    else
        # Recurse on v first, using the z component
        if v > 1
            return (v-1)*Rtuv(t, u, v-2, n+1, p, RPC) +
                   RPC[3]*Rtuv(t, u, v-1, n+1, p, RPC)
        else
            return RPC[3]*Rtuv(t, u, v-1, n+1, p, RPC)
        end
    end
end

Don't pay too much attention to the function boys; according to the profiler it is not that expensive.
To give an idea of the range of values: in the first call, t+u+v usually ranges from 0 to 3, while n always starts at 0.
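To get a feel for how much work one top-level call does in the range quoted above, here is a small hypothetical helper, `ncalls` (not part of the original code), that mirrors the branch structure of `Rtuv` and counts how many times `Rtuv` would be invoked, including the top-level call. It is only a sketch for reasoning about the recursion, not a substitute for profiling.

```julia
# Hypothetical helper: number of Rtuv invocations (including the top-level one)
# triggered by Rtuv(t, u, v, ...). It follows the same branch order as Rtuv:
# reduce v first, then u, then t. n does not affect the count, so it is omitted.
function ncalls(t::Int, u::Int, v::Int)
    if t == u == v == 0
        return 1                                   # base case: one boys evaluation
    elseif u == v == 0
        return 1 + ncalls(t-1, u, v) + (t > 1 ? ncalls(t-2, u, v) : 0)
    elseif v == 0
        return 1 + ncalls(t, u-1, v) + (u > 1 ? ncalls(t, u-2, v) : 0)
    else
        return 1 + ncalls(t, u, v-1) + (v > 1 ? ncalls(t, u, v-2) : 0)
    end
end

for (t, u, v) in ((0, 0, 0), (1, 1, 1), (1, 2, 1), (3, 0, 0))
    println((t, u, v), " => ", ncalls(t, u, v), " calls")
end
```

For t+u+v ≤ 3 the count never exceeds 7, so the call tree itself is small; the cost comes from invoking the kernel very many times, paying for the calls and branches and re-evaluating `boys` at every leaf. Along t alone the count grows like the Fibonacci numbers, since each level spawns two sub-calls.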

Cheers

EDIT -- New information

At first the generated version looked slower for small values of t, u, v, and I suspected that the expressions were not being optimized by the compiler. In fact I was benchmarking incorrectly, without interpolating the arguments passed. Benchmarking properly, the approach explained in the accepted answer is always faster, so hurray!
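For reference, this is the kind of interpolation meant here, sketched with the BenchmarkTools package (an external package, not part of the standard library). The function `kernel` and the globals are hypothetical stand-ins; the point is the `$` in front of each argument, which splices the value into the benchmark so that it is not accessed as an untyped global variable.

```julia
using BenchmarkTools   # external package: install with `import Pkg; Pkg.add("BenchmarkTools")`
using LinearAlgebra    # for norm

# Hypothetical stand-in for a small kernel such as Rtuv
kernel(n::Int, p::Float64, r::Vector{Float64}) = (-2.0 * p)^n * norm(r)^2

n, p, r = 0, 0.1, rand(3)   # non-constant globals

@btime kernel(n, p, r)      # untyped globals: the timing can include dynamic dispatch
@btime kernel($n, $p, $r)   # interpolated values: times the call on concrete types
```

For very cheap kernels the two timings can differ noticeably, which is how a non-interpolated benchmark can make the faster method look slower.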

More generally, does the compiler identify trivial cases such as multiplication by zeros and ones and optimize those away?

Answering my own question: from a quick check of simple code with @code_llvm, it seems not to be the case.
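One likely reason, at least for floating-point operands, is that IEEE 754 arithmetic does not allow `0.0 * x` to be replaced by `0.0` in general: the result depends on `x` for NaN, infinities and negative values, so unless fast-math semantics (such as Julia's `@fastmath`) are enabled, the compiler has to keep the multiplication. The function `zero_times` below is a hypothetical example; `@code_llvm` (from InteractiveUtils, loaded automatically in the REPL) lets you check whether the multiply survives.

```julia
using InteractiveUtils  # provides @code_llvm outside the REPL

# Hypothetical example: a multiplication by a literal floating-point zero
zero_times(x::Float64) = 0.0 * x

println(zero_times(2.0))    # 0.0
println(zero_times(-1.0))   # -0.0: the sign of x is kept
println(zero_times(Inf))    # NaN
println(zero_times(NaN))    # NaN

# Inspect the generated LLVM IR to see whether the fmul was kept
@code_llvm zero_times(2.0)
```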

1 answer
beautiful°
Post #2 · 2019-08-28 00:56

Maybe this works in your case: you can "memoize" whole compiled methods using generated functions and get rid of all recursion after the first call.

Since t, u, and v stay small, you could generate fully expanded code for the recursion. Assume, for simplicity, a bogus implementation of

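using LinearAlgebra             # provides norm, used in the generated expressions
boys(n::Int, x::Real) = n + x   # placeholder, not the real Boys function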

Then

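# Recursively builds the fully expanded expression for R_tuv. `n`, `p` and `RPC`
# can be values or symbols; each level of recursion wraps `n` in another `+ 1`.
function Rtuv_expr(t::Int, u::Int, v::Int, n, p, RPC)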
    ninc = :($n + 1)

    if t == u == v == 0
        :((-2.0 * $p)^$n * boys($n, $p * norm($RPC)^2))
    elseif u == v == 0
        if t > 1
            :($(t-1) * $(Rtuv_expr(t-2, u, v, ninc, p, RPC)) +
              $RPC[1] * $(Rtuv_expr(t-1, u, v, ninc, p, RPC)))
        else
            :($RPC[1] * $(Rtuv_expr(t-1, u, v, ninc, p, RPC)))
        end
    elseif v == 0
        if u > 1
            :($(u-1) * $(Rtuv_expr(t, u-2, v, ninc, p, RPC)) +
              $RPC[2] * $(Rtuv_expr(t, u-1, v, ninc, p, RPC)))
        else
            :($RPC[2] * $(Rtuv_expr(t, u-1, v, ninc, p, RPC)))
        end
    else
        if v > 1 
            :($(v-1) * $(Rtuv_expr(t, u, v-2, ninc, p, RPC)) + 
              $RPC[3] * $(Rtuv_expr(t, u, v-1, ninc, p, RPC)))
        else
            :($RPC[3] * $(Rtuv_expr(t, u, v-1, ninc, p, RPC)))
        end
    end
end

generates fully expanded expressions like this:

julia> Rtuv_expr(1, 2, 1, 0, 0.1, rand(3))
:(([0.868194, 0.928591, 0.295344])[3] * (1 * (([0.868194, 0.928591, 0.295344])[1] * ((-2.0 * 0.1) ^ (((0 + 1) + 1) + 1) * boys(((0 + 1) + 1) + 1, 0.1 * norm([0.868194, 0.928591, 0.295344]) ^ 2))) + ([0.868194, 0.928591, 0.295344])[2] * (([0.868194, 0.928591, 0.295344])[2] * (([0.868194, 0.928591, 0.295344])[1] * ((-2.0 * 0.1) ^ ((((0 + 1) + 1) + 1) + 1) * boys((((0 + 1) + 1) + 1) + 1, 0.1 * norm([0.868194, 0.928591, 0.295344]) ^ 2))))))

We can put that into a generated function Rtuv taking Val types. For each distinct combination of T, U, and V, this function uses Rtuv_expr to build and compile the corresponding expression, and from then on it uses that method, with no recursion left:

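# Julia ≥ 0.7 syntax: type parameters go in a `where` clause.
@generated function Rtuv(::Type{Val{T}}, ::Type{Val{U}}, ::Type{Val{V}},
                         n::Int, p::Real, RPC::Vector{X}) where {T, U, V, X<:Real}
    # T, U, V are the integers carried by the Val types; the returned
    # expression refers to the arguments n, p and RPC by name.
    Rtuv_expr(T, U, V, :n, :p, :RPC)
end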

You have to call it with t, u, v wrapped in Val, though:

julia> Rtuv(Val{1}, Val{2}, Val{1}, 0, 0.1, rand(3))
-0.0007782250832001092
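To see the mechanism in isolation, here is a hypothetical one-dimensional cut-down (only the t index, i.e. u = v = 0, with RPC reduced to a single number `x`) using its own names `r_rec`, `r_expr` and `r_gen`. It compares the plain recursion with a generated function that unrolls it. As a small variation on the code above, `r_expr` tracks the recursion depth `k` as a plain integer, so the generated exponents read `n + 3` instead of `((n + 1) + 1) + 1`, and `r_gen` takes the instance form `Val(3)` (matched by `::Val{T}`) rather than the type form `Val{3}`.

```julia
# Placeholder for the Boys function, as in the answer
boys(n::Int, x::Real) = n + x

# Plain recursion for R_{t00}, with RPC reduced to one component x
function r_rec(t::Int, n::Int, p::Real, x::Real)
    t == 0 && return (-2.0 * p)^n * boys(n, p * x^2)
    t == 1 && return x * r_rec(0, n + 1, p, x)
    return (t - 1) * r_rec(t - 2, n + 1, p, x) + x * r_rec(t - 1, n + 1, p, x)
end

# Builds the unrolled expression for R_{t00}; k is the current recursion depth,
# spliced in as an integer literal added to the runtime argument n
function r_expr(t::Int, k::Int)
    if t == 0
        return :((-2.0 * p)^(n + $k) * boys(n + $k, p * x^2))
    elseif t == 1
        return :(x * $(r_expr(0, k + 1)))
    else
        return :($(t - 1) * $(r_expr(t - 2, k + 1)) + x * $(r_expr(t - 1, k + 1)))
    end
end

# One method body is generated (and compiled) per value of T; the expression
# refers to the arguments n, p and x by name
@generated function r_gen(::Val{T}, n::Int, p::Real, x::Real) where {T}
    r_expr(T, 0)
end

println(r_expr(2, 0))
println(r_gen(Val(3), 0, 0.1, 0.5), " ", r_rec(3, 0, 0.1, 0.5))
```

The two results should agree, since the unrolled expression performs the same arithmetic in the same order; the difference is that `r_gen` has no recursive calls or branches left at run time.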

If you test a small loop like this,

for t = 0:3, u = 0:3, v = 0:3
    println(Rtuv(Val{t}, Val{u}, Val{v}, 0, 0.1, [1.0, 2.0, 3.0]))
end

it takes some time on the first run, because a separate method is compiled for each (t, u, v) combination, but afterwards it is pretty fast, since those methods are already compiled.
