Dealing with types in kwargs in Julia

Posted 2020-07-22 20:21

Question:

How can I use kwargs in a Julia function and declare their types for speed?

function f(x::Float64; kwargs...)
    kwargs = Dict(kwargs)
    if haskey(kwargs, :c)
        c::Float64 = kwargs[:c]
    else
        c::Float64 = 1.0
    end
    return x^2 + c
end

f(0.0, c=10.0)

yields:

ERROR: LoadError: syntax: multiple type declarations for "c"

Of course I can define the function as f(x::Float64, c::Float64=1.0) to achieve the result, but I have MANY optional arguments with default values to pass, so I'd prefer to use kwargs.

Thanks.


Answer 1:

As noted in another answer, this really only matters if you're going to have a type instability. If you do, the answer is to layer your functions. Have a top layer which does type checking and all sorts of setup, and then call a function which uses dispatch to be fast. For example,

function f(x::Float64; kwargs...)
    kwargs = Dict(kwargs)
    if haskey(kwargs, :c)
        c = kwargs[:c]
    else
        c = 1.0
    end

    return _f(x,c)
end
_f(x,c) = x^2 + c

If most of your time is spent in the inner function, then this will be faster (it might not be for very simple functions). This allows for very general usage too: you can have a keyword argument default to nothing and then do an if nothing ... check which sets up a complicated default, without having to worry about type stability, since the inner function is shielded from it.
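A minimal sketch of that nothing-default variant, assuming a recent Julia 1.x. The names g and _g are hypothetical, and the default shown is only a stand-in for a more complicated setup:

```julia
# Hypothetical names: `g` (wrapper) and `_g` (inner function) are illustrative only.

# Inner function: Julia compiles a specialized version for each concrete
# combination of argument types it is called with.
_g(x, c) = x^2 + c

# Outer wrapper: `c` may arrive as `nothing` or as any numeric type.
# That uncertainty stays here and ends at the call to `_g`.
function g(x::Float64; c=nothing)
    if c === nothing
        c = 1.0   # stand-in for a complicated default
    end
    return _g(x, c)
end

g(0.0)           # uses the default
g(0.0, c=10)     # an Int works too; `_g` is specialized for (Float64, Int)
g(2.0, c=1//2)   # so does a Rational
```

The wrapper accepts whatever the caller passes, and only the cheap setup code has to deal with the different possibilities.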

This kind of high-level type-checking wrapper above a performance-sensitive inner function is used a lot in DifferentialEquations.jl. Check out the high-level wrapper for the SDE solvers, which led to nice speedups by ensuring type stability (the inner function is sde_solve), or check out the solve for ODEProblem: it's much more complex since it handles conversions to different packages, but it's the same idea.

A simpler answer for small examples like yours may be possible after this PR merges.


To fix some confusion, here's a declaration form:

function f(x::Float64; kwargs...)
    local c::Float64 # Ensures the type of `c` will be `Float64`
    kwargs = Dict(kwargs)
    if haskey(kwargs, :c)
        c = float(kwargs[:c])
    else
        c = 1.0
    end

    return x^2 + c
end

This will force anything assigned to c to be converted to a Float64 (or throw an error if it can't be), which gives type stability, but it is not as general a solution. Which form you use really depends on what you're doing.
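As a rough illustration of the convert-or-error behaviour, here is a sketch assuming a recent Julia 1.x, where kwargs supports get directly. The name f_declared is hypothetical:

```julia
# Hypothetical name: `f_declared` is illustrative only.
function f_declared(x::Float64; kwargs...)
    local c::Float64             # declared once, so no "multiple type declarations" error
    c = get(kwargs, :c, 1.0)     # every assignment to `c` goes through convert(Float64, ...)
    return x^2 + c
end

f_declared(0.0, c=10)        # the Int 10 is converted to 10.0
# f_declared(0.0, c="ten")   # throws: a String cannot be converted to Float64
```

Because the declaration appears only once, this also avoids the syntax error from the question, where c::Float64 was written in both branches of the if.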

Lastly, there's also the type assert, as @TotalVerb showed:

function f(x::Float64; c::Float64=1.0, kwargs...)
    return x^2 + c
end

That's clean, or you could assert in the function:

function f(x::Float64; kwargs...)
    kwargs = Dict(kwargs)
    if haskey(kwargs, :c)
        c = float(kwargs[:c])::Float64
    else
        c = 1.0
    end

    return x^2 + c
end

which will only check (convert) on the lines where the assertion occurs. Note that the @TotalVerb form won't dispatch, so you can't make another function with c::Int, and it will only assert (convert) when the keyword arg is first read in.
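To make the difference between asserting and converting concrete, here is a small sketch, assuming a recent Julia 1.x. The name f_assert is hypothetical. On the right-hand side of an assignment, ::Float64 is a check rather than a conversion, so any conversion has to come from the float call:

```julia
# Hypothetical name: `f_assert` is illustrative only.
function f_assert(x::Float64; kwargs...)
    # `float(...)` does the converting; `::Float64` only checks the result.
    c = haskey(kwargs, :c) ? float(kwargs[:c])::Float64 : 1.0
    return x^2 + c
end

f_assert(0.0, c=10)       # float(10) is 10.0, so the assertion passes
# f_assert(0.0, c=1f0)    # float(1f0) stays a Float32, so the assertion throws a TypeError
```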


Summary

  1. The first solution will dispatch to be type stable in _f no matter what type the user makes c, and so if _f is a long calculation, this will get pretty much optimal performance, but for really quick calls it will have dispatch overhead.

  2. The second solution will fix any type instability by forcing anything you assign to c to be a Float64 (it will try to convert, and if it can't, error). Thus this gets speed by forcing type stability, or erroring.

  3. The assert in the keyword spot (@TotalVerb's answer) is the cleanest, but it won't auto-convert later, so you could still get a type instability. If you don't accidentally convert it later, then you have type stability, types can be inferred, and so you'll get optimal performance. However, you can't extend it to cases where the function has c passed in as other types (no dispatch).

  4. The last solution is pretty much the same as 3, except not as nice. I wouldn't recommend it. If you're doing something complicated with asserts, you likely are designing something wrong or really want to do something like the first (dispatch in a longer function call which is type stable).

But note that dispatch with version 3 may be fixed in the near future, which would allow you to have a different function with c::Float64 and c::Int (if necessary). Hopefully your solution is in here somewhere.



Answer 2:

Note that declaring types does not give you increased performance; you may wish to relax the type constraints on x and c for your code to be more generic. Anyway, this is probably what you want:

function f(x::Float64; c::Float64=1.0, kwargs...)
    return x^2 + c
end
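Since the question mentions many optional arguments, here is a sketch of how this form scales, assuming a recent Julia 1.x. The name f_kw and the extra keyword d are hypothetical and not part of the original question:

```julia
# Hypothetical names: `f_kw` and the keyword `d` are illustrative only.
function f_kw(x::Float64; c::Float64=1.0, d::Float64=0.0, kwargs...)
    return x^2 + c + d
end

f_kw(0.0)                   # both defaults are used
f_kw(0.0, c=10.0, d=2.0)    # both overridden
f_kw(0.0, c=10.0, e=3)      # the unknown keyword `e` lands in `kwargs` and is ignored
# f_kw(0.0, c=10)           # throws: an Int is rejected, not converted, by `c::Float64`
```

Each keyword carries its own default and type annotation, so no Dict lookup is needed. The trade-off is that callers must pass exactly a Float64 for the annotated keywords.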

See the keyword arguments section of the manual.



Tags: julia