Python - how do I force the use of a factory method

Posted 2019-07-24 22:28

I have a set of related classes that all inherit from one base class. I would like to use a factory method to instantiate objects for these classes. I want to do this because then I can store the objects in a dictionary keyed by the class name before returning the object to the caller. Then if there is a request for an object of a particular class, I can check to see whether one already exists in my dictionary. If not, I'll instantiate it and add it to the dictionary. If so, then I'll return the existing object from the dictionary. This will essentially turn all the classes in my module into singletons.

I want to do this because the base class that they all inherit from does some automatic wrapping of the functions in the subclasses, and I don't want the functions to get wrapped more than once, which is what happens currently if two objects of the same class are created.

The only way I can think of doing this is to check the stack trace in the __init__() method of the base class, which will always be called, and to throw an exception if the stack trace does not show that the request to make the object is coming from the factory function.
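For reference, the stack-trace check could look roughly like this. It is a sketch assuming Python 3.5+, where `inspect.stack()` returns records with a `.function` attribute; `FactoryOnlyBase`, `SomePage` and `make_page` are hypothetical names. It works, but it depends on the factory's name and walks the whole call stack on every instantiation:

```python
import inspect


class FactoryOnlyBase(object):
    def __init__(self):
        # Collect the names of every function on the current call stack.
        # Fragile: renaming the factory, or calling it through another
        # helper with a different name, changes what this check sees.
        callers = [frame_info.function for frame_info in inspect.stack()]
        if "make_page" not in callers:
            raise RuntimeError(
                "%s must be created via make_page()" % type(self).__name__
            )


class SomePage(FactoryOnlyBase):
    pass


def make_page(cls):
    """Hypothetical factory; the only sanctioned way to build a page."""
    return cls()
```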

Is this a good idea?

Edit: Here is the source code for my base class. I've been told that I need to figure out metaclasses to accomplish this more elegantly, but this is what I have for now. All Page objects use the same Selenium Webdriver instance, which is in the driver module imported at the top. This driver is very expensive to initialize -- it is initialized the first time a LoginPage is created. After it is initialized the initialize() method will return the existing driver instead of creating a new one. The idea is that the user must begin by creating a LoginPage. There will eventually be dozens of Page classes defined and they will be used by unit testing code to verify that the behavior of a website is correct.

from driver import get_driver, urlpath, initialize
from settings import urlpaths

class DriverPageMismatchException(Exception):
    pass

class URLVerifyingPage(object):
    # we add logic in __init__() to check the expected urlpath for the page
    # against the urlpath that the driver is showing - we only want the page's
    # methods to be invokable if the driver is actually at the appropriate page.
    # If the driver shows a different urlpath than the page is supposed to
    # have, the method should throw a DriverPageMismatchException

    def __init__(self):
        self.driver = get_driver()
        # Note: this runs on every instantiation, so the class's methods
        # get wrapped again each time a new page object is created.
        self._adjust_methods(self.__class__)

    def _adjust_methods(self, cls):
        # Copy the items first, because setattr() updates cls.__dict__.
        for attr, val in list(cls.__dict__.items()):
            if callable(val) and not attr.startswith("_"):
                print("adjusting: %s - %s" % (attr, val))
                setattr(
                    cls,
                    attr,
                    self._add_wrapper_to_confirm_page_matches_driver(val)
                )
        for base in cls.__bases__:
            if base.__name__ == 'URLVerifyingPage': break
            self._adjust_methods(base)

    def _add_wrapper_to_confirm_page_matches_driver(self, page_method):
        def _wrapper(self, *args, **kwargs):
            expected = urlpaths[self.__class__.__name__]
            if urlpath() != expected:
                raise DriverPageMismatchException(
                    "path is '%s' but '%s' expected for %s"
                    % (urlpath(), expected, self.__class__.__name__)
                )
            return page_method(self, *args, **kwargs)
        return _wrapper


class LoginPage(URLVerifyingPage):
    # username and password are required arguments: the original defaults
    # (username=username, password=password) referred to undefined names.
    def __init__(self, username, password, baseurl="http://example.com/"):
        self.username = username
        self.password = password
        self.driver = initialize(baseurl)
        super(LoginPage, self).__init__()

    def login(self):
        # Use the driver stored on the page object.
        self.driver.find_element_by_id("username").clear()
        self.driver.find_element_by_id("username").send_keys(self.username)
        self.driver.find_element_by_id("password").clear()
        self.driver.find_element_by_id("password").send_keys(self.password)
        self.driver.find_element_by_id("login_button").click()
        return HomePage()

class HomePage(URLVerifyingPage):
    def some_method(self):
        ...
        return SomePage()

    def many_more_methods(self):
        ...
        return ManyMorePages()

It's no big deal if a page gets instantiated a handful of times -- the methods will just get wrapped a handful of times and a handful of unnecessary checks will take place, but everything will still work. But it would be bad if a page was instantiated dozens or hundreds or tens of thousands of times. I could just put a flag in the class definition for each page and check to see if the methods have already been wrapped, but I like the idea of keeping the class definitions pure and clean and shoving all the hocus-pocus into a deep corner of my system where no one can see it and it just works.

3 answers
倾城 Initia
#2 · 2019-07-24 22:45

It sounds like you want to provide a __new__ implementation. Something like:

class MySingletonBase(object):
    instance_cache = {}
    def __new__(cls, *args, **kwargs):
        # Reuse the instance already built for this exact class, if any.
        if cls in MySingletonBase.instance_cache:
            return MySingletonBase.instance_cache[cls]
        # object.__new__ takes only the class; constructor arguments go to __init__.
        self = super(MySingletonBase, cls).__new__(cls)
        MySingletonBase.instance_cache[cls] = self
        return self
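One caveat worth checking before relying on this: when `__new__` returns an existing instance of `cls`, Python still calls `__init__` on it, so wrapping done in `__init__` would keep happening. A small sketch (`CachedBase` and `Page` are hypothetical names) that counts the `__init__` calls:

```python
class CachedBase(object):
    _instances = {}

    def __new__(cls, *args, **kwargs):
        # Build the instance only the first time this class is requested.
        if cls not in CachedBase._instances:
            CachedBase._instances[cls] = super().__new__(cls)
        return CachedBase._instances[cls]


class Page(CachedBase):
    init_calls = 0

    def __init__(self):
        # Called on whatever __new__ returns, cached or not, as long as
        # it is an instance of the class being constructed.
        type(self).init_calls += 1


a = Page()
b = Page()
print(a is b)           # True
print(Page.init_calls)  # 2
```

So with this approach the method wrapping would also need to move out of `__init__`, or be guarded there.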
三岁会撩人
#3 · 2019-07-24 22:50

Rather than adding complex code to catch mistakes at runtime, I'd first try to use convention to guide users of your module to do the right thing on their own.

Give your classes "private" names (prefixed by an underscore), give them names that suggest they shouldn't be instantiated (e.g. _Internal...) and make your factory function "public".

That is, something like this:

class _BaseClass(object):
    ...

class _InternalSubClassOne(_BaseClass):
    ...

class _InternalSubClassTwo(_BaseClass):
    ...

# An example factory function.
def new_object(arg):
    return _InternalSubClassOne() if arg == 'one' else _InternalSubClassTwo()

I'd also add docstrings or comments to each class, like "Don't instantiate this class by hand, use the factory function new_object()."

可以哭但决不认输i
#4 · 2019-07-24 22:52

In Python, it's almost never worth trying to "force" anything. Whatever you come up with, someone can get around it by monkeypatching your class, copying and editing the source, fooling around with bytecode, etc.

So, just write your factory, document it as the right way to get an instance of your class, and expect anyone who writes code using your classes to understand TOOWTDI ("there's only one way to do it") and not to violate it unless she really knows what she's doing and is willing to figure out and deal with the consequences.

If you're just trying to prevent accidents, rather than intentional "misuse", that's a different story. In fact, it's just standard design-by-contract: check the invariant. Of course at this point, SillyBaseClass is already screwed up, and it's too late to repair it, and all you can do is assert, raise, log, or whatever else is appropriate. But that's what you want: it's a logic error in the application, and the only thing to do is get the programmer to fix it, so assert is probably exactly what you want.

So:

class SillyBaseClass:
    singletons = {}

class Foo(SillyBaseClass):
    def __init__(self):
        assert self.__class__ not in SillyBaseClass.singletons

def get_foo():
    if Foo not in SillyBaseClass.singletons:
        SillyBaseClass.singletons[Foo] = Foo()
    return SillyBaseClass.singletons[Foo]

If you really do want to stop things from getting this far, you can check the invariant earlier, in the __new__ method, but unless "SillyBaseClass got screwed up" is equivalent to "launch the nukes", why bother?
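A sketch of that earlier check, assuming Python 3: the assertion moves from `Foo.__init__` into `SillyBaseClass.__new__`, so a second direct `Foo()` fails before any new instance is created:

```python
class SillyBaseClass(object):
    singletons = {}

    def __new__(cls, *args, **kwargs):
        # Check the invariant before an instance exists, so a stray
        # direct Foo() never gets as far as __init__.
        assert cls not in SillyBaseClass.singletons, (
            "%s already exists; use its factory" % cls.__name__)
        return super().__new__(cls)


class Foo(SillyBaseClass):
    pass


def get_foo():
    if Foo not in SillyBaseClass.singletons:
        SillyBaseClass.singletons[Foo] = Foo()
    return SillyBaseClass.singletons[Foo]
```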
