Perfect forwarding - in Python

Posted 2020-03-01 16:51

Question:

I am a maintainer of a Python project that makes heavy use of inheritance. There's an anti-pattern that has caused us a couple of issues and makes reading difficult, and I am looking for a good way to fix it.

The problem is forwarding very long argument lists from derived classes to base classes - mostly but not always in constructors.

Consider this artificial example:

class Base(object):
    def __init__(self, a=1, b=2, c=3, d=4, e=5, f=6, g=7):
        self.a = a
        # etc

class DerivedA(Base):
    def __init__(self, a=1, b=2, c=300, d=4, e=5, f=6, g=700, z=0):
        super().__init__(a=a, b=b, c=c, d=d, e=e, f=f, g=g)
        self.z = z

class DerivedB(Base):
    def __init__(self, z=0, c=300, g=700, **kwds):
        super().__init__(c=c, g=g, **kwds)
        self.z = z

At this time, everything looks like DerivedA - long argument lists, all of which are passed to the base class explicitly.

Unfortunately, we've had several bugs over the last couple of years from this pattern: some came from forgetting to pass an argument on, so the base class silently used its own default, and others from not noticing that a derived class declared a different default for a parameter than the base class does.

It also makes the code needlessly bulky and therefore hard-to-read.

DerivedB is better and fixes those problems, but it introduces a new one: the Python help/Sphinx HTML documentation for the method in the derived class is misleading, because many of the important parameters are hidden in the **kwds.
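To make the documentation problem concrete, here is a minimal sketch that inspects both constructors with the standard-library inspect module. The classes are trimmed copies of the ones above (three base parameters instead of seven, purely for brevity):

```python
import inspect

class Base(object):
    def __init__(self, a=1, b=2, c=3):
        self.a, self.b, self.c = a, b, c

class DerivedB(Base):
    def __init__(self, z=0, c=300, **kwds):
        super().__init__(c=c, **kwds)
        self.z = z

# The derived signature only shows what is spelled out explicitly;
# a and b are hidden behind **kwds.
base_sig = inspect.signature(Base.__init__)
derived_sig = inspect.signature(DerivedB.__init__)
print(base_sig)     # (self, a=1, b=2, c=3)
print(derived_sig)  # (self, z=0, c=300, **kwds)
```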

Is there some way to "forward" the correct signature - or at least the documentation of the correct signature - from the base class method to the derived class method?

Answer 1:

I haven't found a way to perfectly create a function with the same signature, but I think the downsides of my implementation aren't too serious. The solution I've come up with is a function decorator.

Usage example:

class Base(object):
    def __init__(self, a=1, b=2, c=3, d=4, e=5, f=6, g=7):
        self.a = a
        # etc

class DerivedA(Base):
    @copysig(Base.__init__)
    def __init__(self, args, kwargs, z=0):
        super().__init__(*args, **kwargs)
        self.z = z

All named inherited parameters will be passed to the function through the kwargs dict. The args parameter is only used to pass varargs to the function. If the parent function has no varargs, args will always be an empty tuple.

Known problems and limitations:

  • Doesn't work in Python 2! (Why are you still using Python 2?)
  • Not all attributes of the decorated function are perfectly preserved. For example, function.__code__.co_filename will be set to "<string>".
  • If the decorated function throws an exception, there will be an additional function call visible in the exception traceback, for example:

    >>> f2()
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "<string>", line 3, in f2
      File "untitled.py", line 178, in f2
        raise ValueError()
    ValueError

  • If a method is decorated, the first parameter must be called "self".
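Several of these limitations come from generating a new function with exec. If only the reported signature matters (for help output or generated docs), a lighter alternative is to keep the **kwds version and attach a merged signature through the __signature__ attribute, which inspect.signature honors. This is a different technique from the decorator in this answer, and advertise_kwargs is a hypothetical name. It is a sketch that assumes your documentation tool reads signatures through inspect.signature:

```python
import inspect

def advertise_kwargs(parent):
    """Hypothetical decorator: document what **kwds forwards to `parent`."""
    def wrap(func):
        # Parameters the derived method declares itself (minus **kwds).
        own = [p for p in inspect.signature(func).parameters.values()
               if p.kind is not p.VAR_KEYWORD]
        names = {p.name for p in own}
        # Parent parameters are only reachable by keyword through **kwds,
        # so they are advertised as keyword-only.
        inherited = [p.replace(kind=p.KEYWORD_ONLY)
                     for p in inspect.signature(parent).parameters.values()
                     if p.name not in names
                     and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)]
        func.__signature__ = inspect.Signature(own + inherited)
        return func  # same function object: no extra frame in tracebacks
    return wrap

class Base(object):
    def __init__(self, a=1, b=2, c=3):
        self.a, self.b, self.c = a, b, c

class DerivedB(Base):
    @advertise_kwargs(Base.__init__)
    def __init__(self, z=0, c=300, **kwds):
        super().__init__(c=c, **kwds)
        self.z = z

print(inspect.signature(DerivedB.__init__))  # (self, z=0, c=300, *, a=1, b=2)
```

The trade-off is that nothing is enforced at call time: the real function still accepts arbitrary **kwds, and a misspelled keyword is only rejected once it reaches the base class.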

Implementation

import inspect

def copysig(from_func, *args_to_remove):
    def wrap(func):
        #add and remove parameters
        oldsig= inspect.signature(from_func)
        oldsig= _remove_args(oldsig, args_to_remove)
        newsig= _add_args(oldsig, func)

        #write some code for a function that we can exec
        #the function will have the correct signature and forward its arguments to the real function
        code= '''
def {name}{signature}:
    return {func}({args})
'''.format(name=func.__name__,
            signature=newsig,
            func='_'+func.__name__,
            args=_forward_args(oldsig, newsig))
        globs= {'_'+func.__name__: func}
        exec(code, globs)
        newfunc= globs[func.__name__]

        #copy as many attributes as possible
        newfunc.__doc__= func.__doc__
        newfunc.__module__= func.__module__
        #~ newfunc.__closure__= func.__closure__
        #~ newfunc.__code__.co_filename= func.__code__.co_filename
        #~ newfunc.__code__.co_firstlineno= func.__code__.co_firstlineno
        return newfunc
    return wrap

def _collectargs(sig):
    """
    Writes code that gathers all parameters into "self" (if present), "args" and "kwargs"
    """
    arglist= list(sig.parameters.values())

    #check if the first parameter is "self"
    selfarg= ''
    if arglist:
        arg= arglist[0]
        if arg.name=='self':
            selfarg= 'self, '
            del arglist[0]

    #all named parameters will be passed as kwargs. args is only used for varargs.
    args= 'tuple(), '
    kwargs= ''
    kwarg= ''
    for arg in arglist:
        if arg.kind in (arg.POSITIONAL_ONLY,arg.POSITIONAL_OR_KEYWORD,arg.KEYWORD_ONLY):
            kwargs+= '("{0}",{0}), '.format(arg.name)
        elif arg.kind==arg.VAR_POSITIONAL:
            #~ assert not args
            args= arg.name+', '
        elif arg.kind==arg.VAR_KEYWORD:
            assert not kwarg
            kwarg= 'list({}.items())+'.format(arg.name)
        else:
            assert False, arg.kind
    kwargs= 'dict({}[{}])'.format(kwarg, kwargs[:-2])

    return '{}{}{}'.format(selfarg, args, kwargs)

def _forward_args(args_to_collect, sig):
    collect= _collectargs(args_to_collect)

    collected= {arg.name for arg in args_to_collect.parameters.values()}
    args= ''
    for arg in sig.parameters.values():
        if arg.name in collected:
            continue

        if arg.kind==arg.VAR_POSITIONAL:
            args+= '*{}, '.format(arg.name)
        elif arg.kind==arg.VAR_KEYWORD:
            args+= '**{}, '.format(arg.name)
        else:
            args+= '{0}={0}, '.format(arg.name)
    args= args[:-2]

    code= '{}, {}'.format(collect, args) if args else collect
    return code

def _remove_args(signature, args_to_remove):
    """
    Removes named parameters from a signature.
    """
    args_to_remove= set(args_to_remove)
    varargs_removed= False
    args= []
    for arg in signature.parameters.values():
        if arg.name in args_to_remove:
            if arg.kind==arg.VAR_POSITIONAL:
                varargs_removed= True
            continue

        if varargs_removed and arg.kind==arg.KEYWORD_ONLY:#if varargs have been removed, there are no more keyword-only parameters
            arg= arg.replace(kind=arg.POSITIONAL_OR_KEYWORD)

        args.append(arg)

    return signature.replace(parameters=args)

def _add_args(sig, func):
    """
    Merges a signature and a function into a signature that accepts ALL the parameters.
    """
    funcsig= inspect.signature(func)

    #find out where we want to insert the new parameters
    #parameters with a default value will be inserted before *args (if any)
    #if parameters with a default value exist, parameters with no default value will be inserted as keyword-only AFTER *args
    vararg= None
    kwarg= None
    insert_index_default= None
    insert_index_nodefault= None
    default_found= False
    args= list(sig.parameters.values())
    for index,arg in enumerate(args):
        if arg.kind==arg.VAR_POSITIONAL:
            vararg= arg
            insert_index_default= index
            if default_found:
                insert_index_nodefault= index+1
            else:
                insert_index_nodefault= index
        elif arg.kind==arg.VAR_KEYWORD:
            kwarg= arg
            if insert_index_default is None:
                insert_index_default= insert_index_nodefault= index
        else:
            if arg.default!=arg.empty:
                default_found= True

    if insert_index_default is None:
        insert_index_default= insert_index_nodefault= len(args)

    #find the new parameters
    #skip the first two parameters (args and kwargs)
    newargs= list(funcsig.parameters.values())
    if not newargs:
        raise Exception('The decorated function must accept at least 2 parameters')
    #if the first parameter is called "self", ignore the first 3 parameters
    if newargs[0].name=='self':
        del newargs[0]
    if len(newargs)<2:
        raise Exception('The decorated function must accept at least 2 parameters')
    newargs= newargs[2:]

    #add the new parameters
    if newargs:
        new_vararg= None
        for arg in newargs:
            if arg.kind==arg.VAR_POSITIONAL:
                if vararg is None:
                    new_vararg= arg
                else:
                    raise Exception('Cannot add varargs to a function that already has varargs')
            elif arg.kind==arg.VAR_KEYWORD:
                if kwarg is None:
                    args.append(arg)
                else:
                    raise Exception('Cannot add kwargs to a function that already has kwargs')
            else:
                #we can insert it as a positional parameter if it has a default value OR no other parameter has a default value
                if arg.default!=arg.empty or not default_found:
                    #do NOT change the parameter kind here. Leave it as it was, so that the order of varargs and keyword-only parameters is preserved.
                    args.insert(insert_index_default, arg)
                    insert_index_nodefault+= 1
                    insert_index_default+= 1
                else:
                    arg= arg.replace(kind=arg.KEYWORD_ONLY)
                    args.insert(insert_index_nodefault, arg)
                    if insert_index_default==insert_index_nodefault:
                        insert_index_default+= 1
                    insert_index_nodefault+= 1

        #if varargs need to be added, insert them before keyword-only arguments
        if new_vararg is not None:
            for i,arg in enumerate(args):
                if arg.kind not in (arg.POSITIONAL_ONLY,arg.POSITIONAL_OR_KEYWORD):
                    break
            else:
                i+= 1
            args.insert(i, new_vararg)

    return inspect.Signature(args, return_annotation=funcsig.return_annotation)

Short explanation:

The decorator creates a string of the form

def functionname(arg1, arg2, ...):
    real_function((arg1, arg2), {'arg3':arg3, 'arg4':arg4}, z=z)

then execs it and returns the dynamically created function.
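The core trick can be shown in isolation. The following is a minimal sketch, not the full decorator: make_forwarder is a hypothetical helper, and it assumes the signature contains only ordinary parameters whose defaults have a repr that is valid Python source:

```python
import inspect

def make_forwarder(name, signature, target):
    """Hypothetical helper: build `def name<signature>` that forwards by keyword."""
    call_args = ', '.join('{0}={0}'.format(p) for p in signature.parameters)
    source = 'def {}{}:\n    return _target({})\n'.format(name, signature, call_args)
    namespace = {'_target': target}  # the generated code looks up _target here
    exec(source, namespace)
    return namespace[name]

def real(a, b, z):
    return (a, b, z)

sig = inspect.signature(lambda a=1, b=2, z=0: None)
f = make_forwarder('f', sig, real)
print(inspect.signature(f))  # (a=1, b=2, z=0)
print(f(b=20))               # (1, 20, 0)
```

Because the function body comes from a string, f.__code__.co_filename is "<string>", which is the limitation listed earlier.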

Additional features:

If you don't want to "inherit" parameters x and y, use

@copysig(parentfunc, 'x', 'y')
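Removing parameters comes down to rebuilding the signature without them. A simplified sketch of that step follows; drop_params and parent are hypothetical names, and unlike _remove_args it does not convert keyword-only parameters when *args is removed:

```python
import inspect

def parent(x, y, a=1, b=2):
    pass

def drop_params(sig, *names):
    """Hypothetical helper: return a copy of sig without the named parameters."""
    kept = [p for p in sig.parameters.values() if p.name not in names]
    return sig.replace(parameters=kept)

trimmed = drop_params(inspect.signature(parent), 'x', 'y')
print(trimmed)  # (a=1, b=2)
```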


Answer 2:

Consider using the attrs module:

import attr

@attr.s
class Base(object):
    a = attr.ib(1)
    b = attr.ib(2)
    c = attr.ib(3)
    d = attr.ib(4)
    e = attr.ib(5)
    f = attr.ib(6)
    g = attr.ib(7)

@attr.s
class DerivedA(Base):
    z = attr.ib(0)

der_a = DerivedA()
print(der_a.a, der_a.z)
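If adding a third-party dependency is not an option, the standard-library dataclasses module (Python 3.7+) offers a similar declarative style. This is a sketch of the analogous approach, not the attrs API. The generated __init__ includes the inherited fields, so the signature that help and documentation tools see is complete, and a derived class can override one default without restating the rest:

```python
import inspect
from dataclasses import dataclass

@dataclass
class Base:
    a: int = 1
    b: int = 2
    c: int = 3

@dataclass
class DerivedA(Base):
    c: int = 300  # override one default; the field keeps its original position
    z: int = 0

der_a = DerivedA(b=20)
print(der_a)  # DerivedA(a=1, b=20, c=300, z=0)
print(list(inspect.signature(DerivedA).parameters))  # ['a', 'b', 'c', 'z']
```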