Optimize constants in differential equations in Python

Posted 2019-01-23 09:44

Okay, so how would I approach writing code to optimize the constants a and b in a differential equation like dy/dt = a*y^2 + b using curve_fit? I would use odeint to solve the ODE and then curve_fit to optimize a and b. If you could provide some input on this, I would greatly appreciate it!

3 answers
在下西门庆
#2 · 2019-01-23 10:29

You might be better served by solving the ODE symbolically with SymPy. SciPy and NumPy are fundamentally numerical packages and aren't really set up for algebraic/symbolic operations.
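As a minimal sketch of the symbolic route, assuming SymPy is installed and assuming a, b > 0 (an illustrative choice, not part of the question), sympy.dsolve can solve this particular ODE in closed form, and sympy.checkodesol verifies the result by substituting it back into the equation. By hand, separating variables gives y(t) = sqrt(b/a)·tan(sqrt(ab)·(t + C1)) for a, b > 0, so expect an expression of that kind (note that it blows up in finite time). A closed form like this could then be turned into a NumPy function (for example with sympy.lambdify) and passed to curve_fit directly, without a numerical ODE solver.

```python
import sympy as sp

# Symbols; a and b are assumed positive here purely for illustration
t = sp.symbols('t')
a, b = sp.symbols('a b', positive=True)
y = sp.Function('y')

# dy/dt = a*y**2 + b
ode = sp.Eq(y(t).diff(t), a * y(t)**2 + b)

# General solution, containing an integration constant C1
sol = sp.dsolve(ode, y(t))
print(sol)

# Substitute the solution back into the ODE to check it
ok, residual = sp.checkodesol(ode, sol)
print(ok, residual)
```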

三岁会撩人
#3 · 2019-01-23 10:31

To address exactly this type of problem, I wrote a wrapper package that unifies SymPy and SciPy. It's called symfit. Fitting your ODE would then look like this:

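import numpy as np
from symfit import variables, parameters, Fit, D, ODEModel

tdata = np.array([10, 26, 44, 70, 120])
ydata = 10e-4 * np.array([44, 34, 27, 20, 14])
y, t = variables('y, t')
a, b = parameters('a, b')

# Use ** for powers; in Python, ^ is XOR, not exponentiation
model_dict = {
    D(y, t): a*y**2 + b
}

# y(0) is held fixed here, not fitted. Solutions of this autonomous ODE are
# monotonic, so choose a y(0) consistent with the data (y(0) = 0.0 cannot
# reproduce the decreasing ydata above).
ode_model = ODEModel(model_dict, initial={t: 0.0, y: 0.0})

fit = Fit(ode_model, t=tdata, y=ydata)
fit_result = fit.execute()
print(fit_result)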

Because the model is defined as a dict, fitting systems of first-order ODEs works the same way: add one D(variable, t) entry per equation. Check out the symfit docs for more!
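For comparison, here is a minimal sketch of fitting a system of first-order ODEs with plain SciPy instead of symfit. The two-species system dA/dt = -k*A**2, dB/dt = k*A**2, the known initial conditions, the true k and the noise level are all hypothetical values chosen for illustration. The trick is to stack both solution components into one vector so that curve_fit sees a single residual; the analytic solution A(t) = A0/(1 + k*A0*t) makes the result easy to check.

```python
import numpy as np
from scipy.integrate import odeint
from scipy.optimize import curve_fit

# Hypothetical system (illustrative only):
#   dA/dt = -k*A**2,  dB/dt = k*A**2
A0, B0 = 1.0, 0.0  # assumed known initial conditions at t = 0


def rhs(state, t, k):
    A, B = state
    rate = k * A**2
    return [-rate, rate]


def stacked_model(t, k):
    # Integrate from t = 0 so the initial conditions really apply at t = 0
    t_full = np.concatenate(([0.0], t))
    sol = odeint(rhs, [A0, B0], t_full, args=(k,))
    # Drop the t = 0 row and stack [A(t_1..t_n), B(t_1..t_n)] into one vector
    return sol[1:].T.ravel()


# Synthetic data generated from an assumed true k, plus a little noise
true_k = 0.8
rng = np.random.default_rng(1)
t_data = np.linspace(0.5, 10.0, 20)
y_data = stacked_model(t_data, true_k) + rng.normal(0.0, 0.005, 2 * t_data.size)

# xdata has n points while ydata has 2n; curve_fit only compares the
# model output with ydata, so the lengths do not need to match
popt, pcov = curve_fit(stacked_model, t_data, y_data, p0=[0.5])
k_fit = popt[0]
print("k =", k_fit)
```

The same stacking works for any number of equations, as long as the model returns its components in the same order as the concatenated data.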

闹够了就滚
#4 · 2019-01-23 10:47

You can definitely do this with odeint and curve_fit:

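import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
from scipy.optimize import curve_fit

def f(y, t, a, b):
    """Right-hand side dy/dt = a*y**2 + b (odeint expects the signature f(y, t, ...))."""
    return a*y**2 + b

def y(t, a, b, y0):
    """
    Solution to the ODE y'(t) = f(y, t, a, b) with initial condition y(0) = y0,
    evaluated at the time points in t.
    """
    # odeint treats y0 as the value at the FIRST time point, so prepend t = 0
    # explicitly and drop that point from the result
    t_full = np.concatenate(([0.0], t))
    sol = odeint(f, y0, t_full, args=(a, b))
    return sol[1:, 0]

# Some random data to fit
data_t = np.sort(np.random.rand(200) * 10)
data_y = data_t**2 + np.random.rand(200)*10

# p0 holds the initial guesses for a, b and y0
popt, cov = curve_fit(y, data_t, data_y, p0=[-1.2, 0.1, 0])
a_opt, b_opt, y0_opt = popt

print("a = %g" % a_opt)
print("b = %g" % b_opt)
print("y0 = %g" % y0_opt)

# Plot the data and the fitted solution on a fine time grid
t = np.linspace(0, 10, 2000)
plt.plot(data_t, data_y, '.',
         t, y(t, a_opt, b_opt, y0_opt), '-')
plt.gcf().set_size_inches(6, 4)
plt.savefig('out.png', dpi=96)
plt.show()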
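Since the random data above is not generated by the ODE, a useful sanity check is to fit data that is. The sketch below is an illustrative addition, not part of the original answer: it assumes true values a = -0.5, b = 2.0, y0 = 0.0 (a < 0 and b > 0 keep the solution bounded; for a, b > 0 it blows up in finite time), adds small Gaussian noise, and checks that curve_fit recovers the parameters. For these values the exact solution is y(t) = 2*tanh(t), which also verifies the odeint wrapper. The names rhs and model are hypothetical.

```python
import numpy as np
from scipy.integrate import odeint
from scipy.optimize import curve_fit


def rhs(y, t, a, b):
    # odeint convention: state first, then time, then extra args
    return a * y**2 + b


def model(t, a, b, y0):
    # Integrate from t = 0 so that y0 really is y(0), then drop that point
    t_full = np.concatenate(([0.0], t))
    sol = odeint(rhs, y0, t_full, args=(a, b))
    return sol[1:, 0]


# Synthetic data from assumed "true" parameters plus small noise
true_a, true_b, true_y0 = -0.5, 2.0, 0.0
rng = np.random.default_rng(0)
t_data = np.linspace(0.1, 5.0, 50)
y_data = model(t_data, true_a, true_b, true_y0) + rng.normal(0.0, 0.005, t_data.size)

# Start reasonably close to the truth; a poor start can push a > 0,
# where the solution blows up and odeint emits warnings
popt, pcov = curve_fit(model, t_data, y_data, p0=[-0.4, 1.5, 0.1])
perr = np.sqrt(np.diag(pcov))  # one-standard-deviation parameter uncertainties
print("a, b, y0 =", popt)
print("std errors =", perr)
```

If the recovered values and their standard errors look sensible here, the same setup is more trustworthy on real measurements.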
