How do I infer the class to which a @staticmethod belongs?

Published 2019-05-03 23:45

Question:

I am trying to implement an infer_class function that, given a method, figures out the class to which the method belongs.

So far I have something like this:

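import inspect

def infer_class(f):
    if inspect.ismethod(f):
        # Python 2: for a classmethod, im_self is the class itself and
        # im_class is its metaclass (type, assuming no custom metaclass);
        # for an ordinary method, im_class is the class it was looked up on.
        return f.im_self if f.im_class == type else f.im_class
    # elif ... what about staticmethod-s?
    else:
        raise TypeError("Can't infer the class of %r" % f)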

It does not handle @staticmethods, because I have not been able to come up with a way to do that.

Any suggestions?

Here's infer_class in action:

>>> class Wolf(object):
...     @classmethod
...     def huff(cls, a, b, c):
...         pass
...     def snarl(self):
...         pass
...     @staticmethod
...     def puff(k,l, m):
...         pass
... 
>>> print infer_class(Wolf.huff)
<class '__main__.Wolf'>
>>> print infer_class(Wolf().huff)
<class '__main__.Wolf'>
>>> print infer_class(Wolf.snarl)
<class '__main__.Wolf'>
>>> print infer_class(Wolf().snarl)
<class '__main__.Wolf'>
>>> print infer_class(Wolf.puff)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 6, in infer_class
TypeError: Can't infer the class of <function puff at ...>
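For anyone trying this on Python 3: im_self and im_class no longer exist (bound methods expose __self__ instead), and unbound methods are gone, so Wolf.snarl accessed on the class is a plain function there, just like Wolf.puff. A rough Python 3 port of infer_class, a sketch that keeps the original behavior for bound methods only, might look like this:

```python
import inspect


def infer_class(f):
    """Python 3 sketch: return the class a *bound* method belongs to."""
    if inspect.ismethod(f):
        owner = f.__self__
        # A classmethod is bound to the class itself; an ordinary
        # method is bound to an instance, whose type is the class.
        return owner if inspect.isclass(owner) else type(owner)
    raise TypeError("Can't infer the class of %r" % f)


class Wolf:
    @classmethod
    def huff(cls, a, b, c):
        pass

    def snarl(self):
        pass

    @staticmethod
    def puff(k, l, m):
        pass


print(infer_class(Wolf.huff) is Wolf)     # True
print(infer_class(Wolf().huff) is Wolf)   # True
print(infer_class(Wolf().snarl) is Wolf)  # True
# Wolf.snarl and Wolf.puff are plain functions in Python 3,
# so passing either of them raises TypeError.
```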

Answer 1:

That's because staticmethods really aren't methods. The staticmethod descriptor returns the original function as is, so there is no way to get the class through which the function was accessed. But there is no real reason to use staticmethods for methods anyway; always use classmethods.
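A short Python 3 sketch makes this descriptor behavior visible: the class dictionary holds a staticmethod wrapper, but looking the name up on the class or on an instance hands back the bare function, with nothing pointing back to the class. A classmethod lookup, by contrast, yields a method bound to the class, which is why classmethods are easy to trace back.

```python
class Wolf:
    @staticmethod
    def puff(k, l, m):
        pass

    @classmethod
    def huff(cls, a, b, c):
        pass


# In the class namespace, puff is stored as a staticmethod wrapper object...
raw = Wolf.__dict__["puff"]
print(isinstance(raw, staticmethod))   # True

# ...but attribute access goes through staticmethod.__get__, which returns
# the wrapped function unchanged, with no reference back to Wolf.
print(Wolf.puff is raw.__func__)       # True
print(Wolf().puff is raw.__func__)     # True
print(hasattr(Wolf.puff, "__self__"))  # False

# A classmethod is bound to the class on access, so the class is recoverable.
print(Wolf.huff.__self__ is Wolf)      # True
```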

The only use that I have found for staticmethods is to store function objects as class attributes without having them turn into methods.
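To illustrate that use case, here is a small Python 3 sketch; shout, Plain, Wrapped and formatter are hypothetical names chosen for the example. A plain function stored on a class gets bound to the instance on access, while the same function wrapped in staticmethod stays an ordinary function.

```python
def shout(text):
    return text.upper()


class Plain:
    # Illustrative: a plain function stored on a class becomes a method on
    # instance access, so the instance is passed as the first argument.
    formatter = shout


class Wrapped:
    # Illustrative: wrapping it in staticmethod keeps it a plain function.
    formatter = staticmethod(shout)


print(Wrapped().formatter("hi"))   # HI
print(Wrapped.formatter is shout)  # True

try:
    Plain().formatter("hi")  # really calls shout(<Plain instance>, "hi")
except TypeError as exc:
    print("TypeError:", exc)
```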



Answer 2:

I have trouble bringing myself to actually recommend this, but it does seem to work, at least for straightforward cases where the class is defined at the top level of the function's module:

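import inspect

def crack_staticmethod(sm):
    """
    Returns (class, attribute name) for `sm` if `sm` is a
    @staticmethod defined on a class at the top level of its
    module; otherwise returns None.
    """
    mod = inspect.getmodule(sm)
    for name in dir(mod):
        cls = getattr(mod, name, None)
        # Only classes can own static methods; skip everything else.
        if not inspect.isclass(cls):
            continue
        for attribute in inspect.classify_class_attrs(cls):
            # Match a staticmethod whose underlying function is `sm` itself,
            # not just any class that happens to have some staticmethod.
            if (attribute.kind == 'static method'
                    and getattr(cls, attribute.name) is sm):
                return (cls, attribute.name)
    return None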
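A different technique, not used by either answer above: on Python 3.3 and later, every function records its dotted __qualname__ (for example "Wolf.puff"), so the defining class can often be looked up by name from the function's module instead of scanning every class. The sketch below assumes the class is reachable from module level; classes created inside a function (whose qualified names contain "<locals>") cannot be resolved this way. The name infer_class_by_qualname is hypothetical.

```python
import inspect
import sys


def infer_class_by_qualname(f):
    """Sketch: return the class in which `f` was *defined*, located via
    its __qualname__ (Python 3.3+), e.g. "Wolf.puff" -> Wolf."""
    f = getattr(f, "__func__", f)  # unwrap bound methods to the function
    parts = getattr(f, "__qualname__", "").split(".")
    # A top-level function has no owner; "<locals>" means the class was
    # created inside a function body and cannot be reached by name.
    if len(parts) < 2 or "<locals>" in parts:
        raise TypeError("Can't infer the class of %r" % f)
    owner = sys.modules.get(getattr(f, "__module__", None))
    try:
        # Walk the dotted path from the module, e.g. module.Outer.Inner
        for name in parts[:-1]:
            owner = getattr(owner, name)
    except AttributeError:
        raise TypeError("Can't infer the class of %r" % f) from None
    if not inspect.isclass(owner):
        raise TypeError("Can't infer the class of %r" % f)
    return owner


class Wolf:
    @classmethod
    def huff(cls, a, b, c):
        pass

    def snarl(self):
        pass

    @staticmethod
    def puff(k, l, m):
        pass


class Cub(Wolf):
    pass


print(infer_class_by_qualname(Wolf.puff) is Wolf)   # True
print(infer_class_by_qualname(Wolf.snarl) is Wolf)  # True
print(infer_class_by_qualname(Cub.puff) is Wolf)    # True: defining class, not Cub
```

Because __qualname__ records where a function was defined, this returns Wolf for Cub.puff. That differs from what im_class reports for ordinary methods in Python 2, where Cub.snarl has im_class Cub. It can also be fooled if the module-level class name is later rebound.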