Overloading (or alternatives) in Python API design

Posted 2020-02-29 02:35

Question:

I have a large existing program library that currently has a .NET binding, and I'm thinking about writing a Python binding. The existing API makes extensive use of signature-based overloading. So, I have a large collection of static functions like:

Circle(p1, p2, p3) -- Creates a circle through three points
Circle(p, r)       -- Creates a circle with given center point and radius
Circle(c1, c2, c3) -- Creates a circle tangent to three curves

There are a few cases where the same inputs must be used in different ways, so signature-based overloading doesn't work, and I have to use different function names instead. For example:

BezierCurve(p1,p2,p3,p4) -- Bezier curve using given points as control points
BezierCurveThroughPoints(p1,p2,p3,p4) -- Bezier curve passing through given points

I suppose this second technique (using different function names) could be used everywhere in the Python API. So, I would have

CircleThroughThreePoints(p1, p2, p3)
CircleCenterRadius(p, r)
CircleTangentThreeCurves(c1, c2, c3)

But the names look unpleasantly verbose (I don't like abbreviations), and inventing all of them will be quite a challenge, because the library has thousands of functions.

Low Priorities:
Effort (on my part) -- I don't care if I have to write a lot of code.
Performance

High Priorities:
Ease of use/understanding for callers (many will be programming newbies).
Easy for me to write good documentation.
Simplicity -- avoid the need for advanced concepts in caller's code.

I'm sure I'm not the first person who ever wished for signature-based overloading in Python. What work-arounds do people typically use?

Answer 1:

One option is to use keyword arguments exclusively in the constructor, and include logic to figure out which ones were supplied:

class Circle(object):
    def __init__(self, points=(), radius=None, curves=()):
        if radius is not None and len(points) == 1:
            center_point = points[0]
            # create from radius/center point
        elif curves and len(curves) == 3:
            pass  # create from curves
        elif points and len(points) == 3:
            pass  # create from points
        else:
            raise ValueError("Must provide a tuple of three points, a point and a radius, or a tuple of three curves")
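In Python 3 the "keyword arguments only" rule can be enforced by the language rather than by convention: a bare * in the parameter list makes every parameter after it keyword-only. A minimal sketch; the kind, points, radius and curves attributes are hypothetical placeholders for whatever the real binding would store.

```python
class Circle:
    # The bare "*" makes every parameter after it keyword-only (Python 3),
    # so callers must write Circle(points=...) rather than Circle((p1, p2, p3)).
    def __init__(self, *, points=(), radius=None, curves=()):
        if radius is not None and len(points) == 1:
            self.kind = "center_radius"
        elif len(curves) == 3:
            self.kind = "curves"
        elif len(points) == 3:
            self.kind = "points"
        else:
            raise ValueError("Must provide a tuple of three points, a point and "
                             "a radius, or a tuple of three curves")
        # Hypothetical: a real binding would call the native constructor here
        # instead of just remembering which form was used.
        self.points, self.radius, self.curves = tuple(points), radius, tuple(curves)


c = Circle(points=[(0, 0)], radius=2)
print(c.kind)  # center_radius
```

A positional call such as Circle([(0, 0)], 2) now fails with a TypeError before any of the branching logic runs.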

You can also use classmethods to make things easier for the users of the API:

class Circle(object):
    def __init__(self, points=(), radius=None, curves=()):
        pass  # same as above

    @classmethod
    def from_points(cls, p1, p2, p3):
        return cls(points=(p1, p2, p3))

    @classmethod
    def from_point_and_radius(cls, point, radius):
        return cls(points=(point,), radius=radius)

    @classmethod
    def from_curves(cls, c1, c2, c3):
        return cls(curves=(c1, c2, c3))

Usage:

c = Circle.from_points(p1, p2, p3)
c = Circle.from_point_and_radius(p1, r)
c = Circle.from_curves(c1, c2, c3)
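A variation worth considering is to make the plain constructor take one canonical form (center and radius) and have each classmethod compute it. The sketch below fills in from_points with the standard circumcircle formula, assuming points are plain (x, y) tuples; a from_curves method would follow the same pattern but needs the library's own tangency solver, so it is left out.

```python
import math


class Circle:
    def __init__(self, center, radius):
        # The primary constructor takes the canonical form: center + radius.
        self.center = center
        self.radius = radius

    @classmethod
    def from_point_and_radius(cls, point, radius):
        return cls(point, radius)

    @classmethod
    def from_points(cls, p1, p2, p3):
        # Circumcircle of three (x, y) points; undefined if they are collinear.
        (ax, ay), (bx, by), (cx, cy) = p1, p2, p3
        d = 2 * (ax * (by - cy) + bx * (cy - ay) + cx * (ay - by))
        if d == 0:
            raise ValueError("points are collinear; no circle passes through them")
        a2, b2, c2 = ax * ax + ay * ay, bx * bx + by * by, cx * cx + cy * cy
        ux = (a2 * (by - cy) + b2 * (cy - ay) + c2 * (ay - by)) / d
        uy = (a2 * (cx - bx) + b2 * (ax - cx) + c2 * (bx - ax)) / d
        return cls((ux, uy), math.hypot(ax - ux, ay - uy))


c = Circle.from_points((0, 0), (2, 0), (0, 2))
print(c.center, c.radius)  # (1.0, 1.0) 1.4142135623730951
```

Each factory method's docstring can then describe exactly one way of building a circle, which keeps the documentation simple for newcomers.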


Answer 2:

There are a couple of options.

You can have one constructor that accepts an arbitrary number of arguments (with the *args and/or **kwargs syntax) and does different things depending on the number and types of the arguments.

Or, you can write secondary constructors as class methods. These are known as "factory" methods. If you have multiple constructors that take the same number of objects of the same classes (as in your BezierCurve example), this is probably your only option.
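To make the BezierCurve case concrete: both factory methods take four points, so only their names can tell them apart. The sketch below assumes a cubic curve, points given as (x, y) tuples, and that "passing through" means hitting the four points at t = 0, 1/3, 2/3 and 1; the real library may parameterize differently. The two interior control points come from solving the cubic Bezier equations at t = 1/3 and t = 2/3.

```python
class BezierCurve:
    """Cubic Bezier curve stored by its four control points ((x, y) tuples)."""

    def __init__(self, p0, p1, p2, p3):
        self.control_points = (p0, p1, p2, p3)

    @classmethod
    def from_control_points(cls, p0, p1, p2, p3):
        return cls(p0, p1, p2, p3)

    @classmethod
    def through_points(cls, q0, q1, q2, q3):
        # Assumes the curve hits q0..q3 at t = 0, 1/3, 2/3, 1.
        def combine(w0, w1, w2, w3):
            return tuple((w0 * a + w1 * b + w2 * c + w3 * d) / 6
                         for a, b, c, d in zip(q0, q1, q2, q3))
        p1 = combine(-5, 18, -9, 2)
        p2 = combine(2, -9, 18, -5)
        return cls(q0, p1, p2, q3)

    def point_at(self, t):
        # Standard cubic Bernstein form of the curve.
        p0, p1, p2, p3 = self.control_points
        s = 1 - t
        return tuple(s**3 * a + 3 * s * s * t * b + 3 * s * t * t * c + t**3 * d
                     for a, b, c, d in zip(p0, p1, p2, p3))


curve = BezierCurve.through_points((0, 0), (1, 2), (2, 2), (3, 0))
print(curve.control_points)  # ((0, 0), (1.0, 3.0), (2.0, 3.0), (3, 0))
```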

If you don't mind overriding __new__ rather than __init__, you can even have both: the __new__ method handles one form of arguments by itself and hands the other forms to the factory methods, which compute that canonical form and call back into the constructor. Here's an example of what that might look like, including a class docstring that lists the signatures accepted by __new__:

class Circle(object):
    """Circle(center, radius) -> Circle object
       Circle(point1, point2, point3) -> Circle object
       Circle(curve1, curve2, curve3) -> Circle object

       Return a Circle with the provided center and radius. If three points are given,
       the center and radius will be computed so that the circle will pass through each
       of the points. If three curves are given, the circle's center and radius will
       be chosen so that the circle will be tangent to each of them."""

    def __new__(cls, *args):
        if len(args) == 2:
            self = super(Circle, cls).__new__(cls)
            self.center, self.radius = args
            return self
        elif len(args) == 3:
            if all(isinstance(arg, Point) for arg in args):
                return cls.through_points(*args)
            elif all(isinstance(arg, Curve) for arg in args):
                return cls.tangent_to_curves(*args)
        raise TypeError("Invalid arguments to Circle()")

    @classmethod
    def through_points(cls, point1, point2, point3):
        """through_points(point1, point2, point3) -> Circle object

        Return a Circle that passes through three points."""

        # compute center and radius from the points...
        # then call back to the main constructor:
        return cls(center, radius)

    @classmethod
    def tangent_to_curves(cls, curve1, curve2, curve3):
        """tangent_to_curves(curve1, curve2, curve3) -> Circle object

        Return a Circle that is tangent to three curves."""

        # here too, compute center and radius from curves ...
        # then call back to the main constructor:
        return cls(center, radius)
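On Python 3.5+, the multiple signatures listed in that docstring can also be declared with typing.overload, so IDEs and type checkers show each form separately. The decorator does no dispatching at runtime; the final __init__ still has to inspect its arguments, just as __new__ does above. A sketch, assuming Point is a simple class with x and y attributes:

```python
from typing import overload


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


class Circle:
    """Circle(center, radius) or Circle(point1, point2, point3)."""

    # The @overload stubs are read by IDEs and type checkers only; at runtime
    # Python keeps just the last, undecorated __init__ defined below.
    @overload
    def __init__(self, center: Point, radius: float) -> None: ...

    @overload
    def __init__(self, point1: Point, point2: Point, point3: Point) -> None: ...

    def __init__(self, *args):
        if len(args) == 2 and isinstance(args[0], Point) and isinstance(args[1], (int, float)):
            self.center, self.radius = args
        elif len(args) == 3 and all(isinstance(a, Point) for a in args):
            self.center, self.radius = self._circumcircle(*args)
        else:
            raise TypeError("Circle() takes (center, radius) or (point1, point2, point3)")

    @staticmethod
    def _circumcircle(p, q, r):
        # Center and radius of the circle through three non-collinear points.
        d = 2 * (p.x * (q.y - r.y) + q.x * (r.y - p.y) + r.x * (p.y - q.y))
        if d == 0:
            raise ValueError("points are collinear")
        p2, q2, r2 = p.x**2 + p.y**2, q.x**2 + q.y**2, r.x**2 + r.y**2
        ux = (p2 * (q.y - r.y) + q2 * (r.y - p.y) + r2 * (p.y - q.y)) / d
        uy = (p2 * (r.x - q.x) + q2 * (p.x - r.x) + r2 * (q.x - p.x)) / d
        return Point(ux, uy), ((p.x - ux) ** 2 + (p.y - uy) ** 2) ** 0.5
```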


Answer 3:

You could use a dictionary, like so

Circle({'points': [p1, p2, p3]})
Circle({'radius': r})
Circle({'curves': [c1, c2, c3]})

And the initializer would say

def __init__(self, args):
  if len(args) > 1:
    raise ValueError("only pass one of points, radius, curves")
  if 'points' in args: pass  # blah
  elif 'radius' in args: pass  # blahblah
  elif 'curves' in args: pass  # evenmoreblah
  else: raise ValueError("only pass one of points, radius, curves")
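Python can build that dictionary for you: a **kwargs parameter collects keyword arguments into a dict, so callers write Circle(points=[p1, p2, p3]) and the same "exactly one key" check still applies. A minimal sketch; the kind and data attributes are hypothetical names used only for illustration.

```python
class Circle:
    ALLOWED = ("points", "radius", "curves")

    def __init__(self, **kwargs):
        # **kwargs collects keyword arguments into a dict, so callers write
        # Circle(points=[...]) instead of Circle({'points': [...]}).
        if len(kwargs) != 1 or not kwargs.keys() <= set(self.ALLOWED):
            raise TypeError("pass exactly one of: " + ", ".join(self.ALLOWED))
        # Unpack the single (key, value) pair into two attributes.
        ((self.kind, self.data),) = kwargs.items()


c = Circle(points=[(0, 0), (1, 0), (0, 1)])
print(c.kind)  # points
```

A misspelled key such as Circle(centre=...) is rejected immediately instead of silently falling through to the final else.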


Answer 4:

One way would be to just write code to parse the args yourself. Then you wouldn't have to change the API at all. You could even write a decorator so it'd be reusable:

import functools

def overload(func):
  '''Creates a signature from the arguments passed to the decorated function and passes it as the first argument'''
  @functools.wraps(func)
  def inner(*args):
    signature = tuple(map(type, args))
    return func(signature, *args)
  return inner

def matches(collection, sig):
  '''Returns True if each type in collection is a subclass of its respective type in sig'''
  if len(sig)!=len(collection): 
    return False
  return all(issubclass(i, j) for i,j in zip(collection, sig))

@overload
def Circle1(sig, *args):  
  if matches(sig, (Point,)*3):
    #do stuff with args
    print "3 points"
  elif matches(sig, (Point, float)):
    #as before
    print "point, float"
  elif matches(sig, (Curve,)*3):
    #and again
    print "3 curves"
  else:
    raise TypeError("Invalid argument signature")

# or even better
@overload
def Circle2(sig, *args):
  valid_sigs = {(Point,)*3: CircleThroughThreePoints,
                (Point, float): CircleCenterRadius,
                (Curve,)*3: CircleTangentThreeCurves
               }
  try:  
    return (f for s,f in valid_sigs.items() if matches(sig, s)).next()(*args)
  except StopIteration:
    raise TypeError("Invalid argument signature")

How it appears to API users:

This is the best part. To an API user, they just see this:

>>> help(Circle)

Circle(*args)
  Whatever's in Circle's docstring. You should put info here about valid signatures.

They can just call Circle like you showed in your question.

How it works:

The whole idea is to hide the signature-matching from the API. This is accomplished by using a decorator to create a signature, basically a tuple containing the types of each of the arguments, and passing that as the first argument to the functions.

overload:

When you decorate a function with @overload, overload is called with that function as an argument. Whatever is returned (in this case inner) replaces the decorated function. functools.wraps ensures that the new function has the same name, docstring, etc.

Overload is a fairly simple decorator. All it does is make a tuple of the types of each argument and pass that tuple as the first argument to the decorated function.
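To see what the decorator does in isolation, the sketch below applies the same overload function by hand (which is exactly what the @ syntax does) and checks that functools.wraps kept the name and docstring. It uses Python 3 print(); circle is a hypothetical stand-in that simply returns the signature it receives.

```python
import functools


def overload(func):
    """Same decorator as above: prepend a tuple of the argument types."""
    @functools.wraps(func)
    def inner(*args):
        signature = tuple(map(type, args))
        return func(signature, *args)
    return inner


def circle(sig, *args):
    """Circle(point1, point2, point3) or Circle(center, radius)."""
    return sig

# Writing "@overload" above "def circle" is shorthand for this assignment:
circle = overload(circle)

print(circle.__name__)      # circle  (copied from the original by functools.wraps)
print(circle((0, 0), 2.5))  # (<class 'tuple'>, <class 'float'>)
```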

Circle take 1:

This is the simplest approach. At the beginning of the function, just test the signature against all valid ones.

Circle take 2:

This is a little more fancy. The benefit is that you can define all of your valid signatures together in one place. The return statement uses a generator to filter the matching valid signature from the dictionary, and .next() just gets the first one. Since that entire statement returns a function, you can just stick a () afterwards to call it. If none of the valid signatures match, .next() raises a StopIteration.

All in all, this function just returns the result of the function with the matching signature.
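On Python 3, Circle take 2 needs a small change: generators no longer have a .next() method, and the built-in next() accepts a default, which removes the need for the try/except StopIteration. The sketch below is self-contained; circle_through_three_points and the other two helpers are hypothetical stand-ins for the native functions, and it matches the radius against numbers.Real instead of float so that an int radius is accepted too (a deliberate change from the original).

```python
import functools
from numbers import Real


def overload(func):
    @functools.wraps(func)
    def inner(*args):
        return func(tuple(map(type, args)), *args)
    return inner


def matches(sig, valid):
    """True if sig has the same length as valid and each type is a subclass."""
    return len(sig) == len(valid) and all(issubclass(a, b) for a, b in zip(sig, valid))


class Point(tuple):
    pass


class Curve:
    pass


# Hypothetical stand-ins for the native constructors being wrapped.
def circle_through_three_points(p1, p2, p3):
    return ("three points", p1, p2, p3)


def circle_center_radius(center, radius):
    return ("center radius", center, radius)


def circle_tangent_three_curves(c1, c2, c3):
    return ("three curves", c1, c2, c3)


VALID_SIGS = {
    (Point, Point, Point): circle_through_three_points,
    (Point, Real): circle_center_radius,  # Real accepts int as well as float
    (Curve, Curve, Curve): circle_tangent_three_curves,
}


@overload
def Circle(sig, *args):
    """Circle(p1, p2, p3), Circle(center, radius) or Circle(c1, c2, c3)."""
    # next() with a default replaces the Python 2 .next() / StopIteration dance.
    func = next((f for s, f in VALID_SIGS.items() if matches(sig, s)), None)
    if func is None:
        raise TypeError("Invalid argument signature")
    return func(*args)
```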

final notes:

One thing you see a lot in this bit of code is the *args construct. When used in a function definition, it collects all the positional arguments into a tuple named "args". Elsewhere, it unpacks a sequence named args so that each item becomes a separate argument to a function (e.g. a = func(*args)).
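A quick illustration of both uses of *, in Python 3 syntax (collect is just an illustrative name):

```python
def collect(*args):
    # In a definition, *args gathers the positional arguments into a tuple.
    return args


print(collect(1, 2, 3))   # (1, 2, 3)
values = [4, 5]
print(collect(*values))   # (4, 5): at a call site, *values unpacks the list
print(type(collect()))    # <class 'tuple'>
```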

I don't think it's terribly uncommon to do odd things like this to present clean APIs in Python.



Answer 5:

There are a number of packages on PyPI that can help you with signature-based overloading and dispatch: multipledispatch, multimethods, Dispatching. I don't have real experience with any of them, but multipledispatch looks like what you want and it's well documented. Using your circle example:

from multipledispatch import dispatch

class Point(tuple):
    pass

class Curve(object):         
    pass

@dispatch(Point, Point, Point)
def Circle(point1, point2, point3):
    print "Circle(point1, point2, point3): point1 = %r, point2 = %r, point3 = %r" % (point1, point2, point3)

@dispatch(Point, int)
def Circle(centre, radius):
    print "Circle(centre, radius): centre = %r, radius = %r" % (centre, radius)

@dispatch(Curve, Curve, Curve)
def Circle(curve1, curve2, curve3):
    print "Circle(curve1, curve2, curve3): curve1 = %r, curve2 = %r, curve3 = %r" % (curve1, curve2, curve3)


>>> Circle(Point((10,10)), Point((20,20)), Point((30,30)))
Circle(point1, point2, point3): point1 = (10, 10), point2 = (20, 20), point3 = (30, 30)
>>> p1 = Point((25,10))
>>> p1
(25, 10)
>>> Circle(p1, 100)
Circle(centre, radius): centre = (25, 10), radius = 100

>>> Circle(*(Curve(),)*3)
Circle(curve1, curve2, curve3): curve1 = <__main__.Curve object at 0xa954d0>, curve2 = <__main__.Curve object at 0xa954d0>, curve3 = <__main__.Curve object at 0xa954d0>

>>> Circle()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mhawke/virtualenvs/urllib3/lib/python2.7/site-packages/multipledispatch/dispatcher.py", line 143, in __call__
    func = self.resolve(types)
  File "/home/mhawke/virtualenvs/urllib3/lib/python2.7/site-packages/multipledispatch/dispatcher.py", line 184, in resolve
    (self.name, str_signature(types)))
NotImplementedError: Could not find signature for Circle: <>

It's also possible to decorate instance methods, so you can provide multiple implementations of __init__(), which is quite nice. If you were implementing any actual behaviour within the class, e.g. Circle.draw(), you would need some logic to work out which values are available to draw the circle with (centre and radius, 3 points, etc.). But as this is just to provide a set of bindings, you probably only need to call the correct native code function and pass on the parameters:

from numbers import Number
from multipledispatch import dispatch

class Point(tuple):
    pass

class Curve(object):
    pass

class Circle(object):
    "A circle class"

    # dispatch(Point, (int, float, Decimal....))
    @dispatch(Point, Number)
    def __init__(self, centre, radius):
        """Circle(Point, Number): create a circle from a Point and radius."""

        print "Circle.__init__(): centre %r, radius %r" % (centre, radius)

    @dispatch(Point, Point, Point)
    def __init__(self, point1, point2, point3):
        """Circle(Point, Point, Point): create a circle from 3 points."""

        print "Circle.__init__(): point1 %r, point2 %r, point3 = %r" % (point1, point2, point3)

    @dispatch(Curve, Curve, Curve)
    def __init__(self, curve1, curve2, curve3):
        """Circle(Curve, Curve, Curve): create a circle from 3 curves."""

        print "Circle.__init__(): curve1 %r, curve2 %r, curve3 = %r" % (curve1, curve2, curve3)

    __doc__ = '' if __doc__ is None else '{}\n\n'.format(__doc__)
    __doc__ += '\n'.join(f.__doc__ for f in __init__.funcs.values())


>>> print Circle.__doc__
A circle class

Circle(Point, Number): create a circle from a Point and radius.
Circle(Point, Point, Point): create a circle from 3 points.
Circle(Curve, Curve, Curve): create a circle from 3 curves.

>>> from decimal import Decimal
>>> for num in 10, 10.22, complex(10.22), True, Decimal(100):
...     Circle(Point((10,20)), num)
... 
Circle.__init__(): centre (10, 20), radius 10
<__main__.Circle object at 0x1d42fd0>
Circle.__init__(): centre (10, 20), radius 10.22
<__main__.Circle object at 0x1e3d890>
Circle.__init__(): centre (10, 20), radius (10.22+0j)
<__main__.Circle object at 0x1d42fd0>
Circle.__init__(): centre (10, 20), radius True
<__main__.Circle object at 0x1e3d890>
Circle.__init__(): centre (10, 20), radius Decimal('100')
<__main__.Circle object at 0x1d42fd0>

>>> Circle(Curve(), Curve(), Curve())
Circle.__init__(): curve1 <__main__.Curve object at 0x1e3db50>, curve2 <__main__.Curve object at 0x1d42fd0>, curve3 = <__main__.Curve object at 0x1d4b1d0>
<__main__.Circle object at 0x1d4b4d0>

>>> p1=Point((10,20))
>>> Circle(*(p1,)*3)
Circle.__init__(): point1 (10, 20), point2 (10, 20), point3 = (10, 20)
<__main__.Circle object at 0x1e3d890>

>>> Circle()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mhawke/virtualenvs/urllib3/lib/python2.7/site-packages/multipledispatch/dispatcher.py", line 235, in __call__
    func = self.resolve(types)
  File "/home/mhawke/virtualenvs/urllib3/lib/python2.7/site-packages/multipledispatch/dispatcher.py", line 184, in resolve
    (self.name, str_signature(types)))
NotImplementedError: Could not find signature for __init__: <>
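If adding a dependency is not an option, the core idea behind these packages can be sketched with the standard library alone. The Dispatcher class and its register/signatures_doc methods below are hypothetical names, not multipledispatch's API. Dedicated packages typically do more (for example choosing the most specific matching signature); this sketch simply takes the first match in registration order.

```python
from numbers import Real


class Dispatcher:
    """Hypothetical minimal dispatcher; not the multipledispatch implementation."""

    def __init__(self, name):
        self.name = name
        self.funcs = {}  # maps a tuple of types to the function registered for it

    def register(self, *types):
        def decorator(func):
            self.funcs[types] = func
            return self  # the decorated name now refers to the dispatcher
        return decorator

    def __call__(self, *args):
        # First registered signature whose arity and isinstance checks pass wins.
        for types, func in self.funcs.items():
            if len(types) == len(args) and all(map(isinstance, args, types)):
                return func(*args)
        sig = ", ".join(type(a).__name__ for a in args)
        raise TypeError(f"Could not find signature for {self.name}: <{sig}>")

    def signatures_doc(self):
        # Join the per-signature docstrings, like the __doc__ trick above.
        return "\n".join(f.__doc__ for f in self.funcs.values() if f.__doc__)


class Point(tuple):
    pass


class Curve:
    pass


Circle = Dispatcher("Circle")


@Circle.register(Point, Point, Point)
def Circle(p1, p2, p3):
    """Circle(Point, Point, Point): create a circle from 3 points."""
    return ("three points", p1, p2, p3)


@Circle.register(Point, Real)
def Circle(centre, radius):
    """Circle(Point, Real): create a circle from a Point and radius."""
    return ("center radius", centre, radius)


@Circle.register(Curve, Curve, Curve)
def Circle(c1, c2, c3):
    """Circle(Curve, Curve, Curve): create a circle from 3 curves."""
    return ("three curves", c1, c2, c3)


print(Circle.signatures_doc())
```

Because each register call returns the dispatcher itself, repeating def Circle works the same way as in the multipledispatch example, and the collected docstrings give one place to document every signature.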