Numba: calling jit with explicit signature using a

Posted 2020-07-11 09:06

I'm using numba to compile some functions containing loops over numpy arrays.

Everything is fine and dandy: I can use jit, and I learned how to define the signature.

Now I tried using jit on a function with optional arguments, e.g.:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float))'])
def fun(a, b=3):
    return a + b

This works, but if I use optional(float64) instead of optional(float), it doesn't (the same goes for int vs. int64). I lost an hour trying to figure out this syntax (actually, a friend of mine found this solution by chance because he forgot to write the 64 after float), but for the life of me I cannot understand why this is so. I can't find anything on the internet, and numba's docs on the topic are scarce at best (and they state that optional should take a numba type).

Does anyone know how this works? What am I missing?

1 answer
神经病院院长
Answered 2020-07-11 09:22

Ah, but the exception message should give a hint:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3.0)

That means optional is the wrong choice here. In fact, optional(T) represents "either None or T". But you want an optional argument, i.e. one that can be left out, not an argument that can be either a float or None, e.g.:

>>> fun(10, None)  # accepted by the signature; it only fails inside the function body
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

I suspect that it just "happens" to work for optional(float) because float is just an "arbitrary Python object" from numba's point of view, so with optional(float) you could pass anything in there (this apparently includes not giving the argument at all). With optional(float64), the argument can only be None or a float64, and that category isn't broad enough to cover an argument that isn't provided.

It works if you use the type Omitted instead:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
13.0

However, Omitted doesn't seem to be covered in the documentation, and it has some "rough edges". For example, a function with that signature can't be compiled in nopython mode, even though nopython compilation works without a signature:

from numba import njit

@njit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3):
    return a + b

TypingError: Failed at nopython (nopython frontend)
Invalid usage of + with parameters (float64, class(float64))

-----------

@njit(['float64(float64, float64)', 'float64(float64, Omitted(3.))'])
def fun(a, b=3):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3)

-----------

@njit
def fun(a, b=3):
    return a + b

>>> fun(10.)
13.0