How to write a wrapper to fix arbitrary parameters

Posted 2019-08-16 02:03

Question:

I would like to write a curve-fitting script that allows me to fix parameters of a function of the form:

import numpy as np

def func(x, *p):
    # Sum of p[j] * exp(-p[j+1] * x) over (amplitude, rate) pairs
    assert len(p) % 2 == 0
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j]*np.exp(-p[j+1]*x)
    return fval
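Written out with the question's 1-based names, func evaluates a sum of exponential decays whose parameters come in (amplitude, rate) pairs. Holding p2 = A and p3 = B fixed, as in the example that follows, leaves only p1 and p4 free:

```latex
f(x;\, p_1, \dots, p_n) = \sum_{k=1}^{n/2} p_{2k-1}\, e^{-p_{2k}\, x},
\qquad
f(x;\, p_1, A, B, p_4) = p_1\, e^{-A x} + B\, e^{-p_4 x}
```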

For example, say p = [p1, p2, p3, p4] and I want to hold p2 and p3 fixed at constants A and B, turning a 4-parameter fit into a 2-parameter fit. As I understand it, functools.partial can't do this (it only binds leading positional arguments or keyword arguments, and p2 and p3 sit in the middle of *p), which is why I want to write my own wrapper. I'm having some trouble with it, though. This is what I have so far:

def fix_params(f, t, pars, fix_pars):
    # fix_pars = ((ind1, A), (ind2, B))
    new_pars = [None]*(len(pars) + len(fix_pars))
    # put the fixed values at their indices
    for ind, fix in fix_pars:
        new_pars[ind] = fix
    # fill the remaining slots with the free parameters, in order
    for par in pars:
        for j, npar in enumerate(new_pars):
            if npar is None:
                new_pars[j] = par
                break
    assert None not in new_pars
    return f(t, *new_pars)

The problem with this, I think, is that scipy.optimize.curve_fit won't work well with a function passed through this kind of wrapper. How should I get around this?

Answer 1:

It sounds like what you want is currying. In Python, you can do this with inner functions (closures).

Example:

def foo(x):
    def bar(y):
        return x + y
    return bar

bar = foo(3)
print(type(bar))    # a function (of one variable with the other fixed to 3)
print(bar(8))       # 11
bar = foo(9)
print(bar(8))       # 17

In this way we can fix x in the function x + y. You can also put this into a decorator.
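A minimal sketch of the decorator form, using a hypothetical decorator factory named fix_first (not from any library): it fixes the first argument when the function is defined, which has the same effect as calling foo(3) above.

```python
import functools

def fix_first(value):
    """Hypothetical decorator factory: fix the first argument of f to `value`."""
    def decorator(f):
        # functools.wraps keeps f's name and docstring and stores f as __wrapped__
        @functools.wraps(f)
        def wrapper(*rest):
            return f(value, *rest)
        return wrapper
    return decorator

@fix_first(3)
def add(x, y):
    return x + y

print(add(8))  # 11
```

One trade-off of the decorator form: the name add now refers only to the one-argument version, and the original two-argument function is reachable only through add.__wrapped__. The inner-function version keeps both around.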

Here's a blog post someone wrote on doing this: https://mtomassoli.wordpress.com/2012/03/18/currying-in-python/

As for playing nicely with external libraries: foo returns a function, and in Python functions are first-class objects, so anything you pass the returned function to (scipy.optimize.curve_fit included) just sees an ordinary function.
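To illustrate the point about first-class functions, here is a small sketch. The names make_decay and sum_squared_residuals are hypothetical: the closure fixes the decay rate, and a generic consumer only needs to call model(x, *params), the same calling convention curve_fit uses for its model function.

```python
import numpy as np

def make_decay(rate):
    # Closure: `rate` is captured and fixed; only the amplitude stays free.
    def decay(x, amplitude):
        return amplitude * np.exp(-rate * x)
    return decay

def sum_squared_residuals(model, x, y, params):
    # Hypothetical generic consumer: it knows nothing about closures and
    # just calls model(x, *params).
    r = y - model(x, *params)
    return float(np.sum(r ** 2))

x = np.linspace(0.0, 4.0, 9)
y = 2.0 * np.exp(-0.5 * x)
model = make_decay(0.5)
print(sum_squared_residuals(model, x, y, [2.0]))  # 0.0: true amplitude reproduces y
```

Because decay declares an explicit (x, amplitude) signature, curve_fit can also work out by introspection that there is exactly one free parameter, so an initial guess p0 is optional in this case.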



Answer 2:

So I think I have something workable. Maybe there is a way to improve on this.

Here is my code (without all the exception handling):

import numpy as np

def func(x, *p):
    # Sum of p[j] * exp(-p[j+1] * x) over (amplitude, rate) pairs
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j]*np.exp(-p[j+1]*x)
    return fval

def fix_params(f, fix_pars):
    # fix_pars = ((1, A), (2, B)): (index, value) pairs to hold fixed
    def new_func(x, *pars):
        new_pars = [None]*(len(pars) + len(fix_pars))
        # put the fixed values at their indices
        for j, fp in fix_pars:
            new_pars[j] = fp
        # fill the remaining slots with the free parameters, in order
        for par in pars:
            for j, npar in enumerate(new_pars):
                if npar is None:
                    new_pars[j] = par
                    break
        return f(x, *new_pars)
    # return a new function rather than a value, so curve_fit can call it
    return new_func

p1 = [1, 0.5, 0.1, 1.2]
pfix = ((1, 0.5), (2, 0.1))
p2 = [1, 1.2]

new_func = fix_params(func, pfix)

x = np.arange(10)
dat1 = func(x, *p1)
dat2 = new_func(x, *p2)

if np.allclose(dat1, dat2):
    print("ALL GOOD")
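One possible improvement, sketched under the assumption that the fixed indices are distinct: convert fix_pars to a dict once, then build the full parameter list in a single pass, taking the fixed value where an index is fixed and the next free value otherwise. This avoids rescanning the list for every free parameter. The range check is an addition that is not in the code above.

```python
import numpy as np

def func(x, *p):
    # Sum of p[j] * exp(-p[j+1] * x) over (amplitude, rate) pairs.
    return sum(p[j] * np.exp(-p[j + 1] * x) for j in range(0, len(p), 2))

def fix_params(f, fix_pars):
    """Return new_func(x, *pars) that calls f with some parameters held fixed.

    fix_pars is a sequence of (index, value) pairs, e.g. ((1, 0.5), (2, 0.1)).
    """
    fixed = dict(fix_pars)

    def new_func(x, *pars):
        n_total = len(pars) + len(fixed)
        if any(j < 0 or j >= n_total for j in fixed):
            raise ValueError("fixed index out of range for %d parameters" % n_total)
        free = iter(pars)
        # Single pass: a fixed slot takes its value, every other slot takes
        # the next free parameter in order.
        full = [fixed[j] if j in fixed else next(free) for j in range(n_total)]
        return f(x, *full)

    return new_func

new_func = fix_params(func, ((1, 0.5), (2, 0.1)))
x = np.arange(10)
print(np.allclose(func(x, 1, 0.5, 0.1, 1.2), new_func(x, 1, 1.2)))  # True
```

Because new_func only declares *pars, scipy.optimize.curve_fit cannot infer how many free parameters it has, so pass an initial guess such as `curve_fit(new_func, x, y, p0=[1.0, 1.0])`; the length of p0 then sets the number of fitted parameters.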