I am trying to solve a system of differential equations with Python.
In this system of two differential equations, if the value of the first variable (v) exceeds a threshold (30), it should be reset to another value (-65). My code is below. The problem is that after reaching 30, the first variable stays constant instead of resetting to -65. These equations describe the dynamics of a single neuron. The equations are taken from this website and this PDF file.
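For reference, this is the system as the code below implements it. It is read off the code rather than taken from the linked sources, so the constant 150 and the sign of i are simply what the code uses. Here v is `u[0]`, u is `u[1]`, and a, b, c, d, i are the entries of `p`:

```latex
\begin{aligned}
\frac{dv}{dt} &= (0.04\,v + 5)\,v + 150 - u - i, \\
\frac{du}{dt} &= a\,(b\,v - u), \\
\text{if } v \ge 30:&\quad v \leftarrow c, \quad u \leftarrow u + d.
\end{aligned}
```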
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from scipy.integrate import odeint
plt.close('all')
a = 0.02
b = 0.2
c = -65
d = 8
i = 0
p = [a,b,c,d,i]
def fun(u,tspan,*p):
    du = [0,0]
    if u[0] < 30: #Checking if the threshold has been reached
        du[0] = (0.04*u[0] + 5)*u[0] + 150 - u[1] - p[4]
        du[1] = p[0]*(p[1]*u[0]-u[1])
    else:
        u[0] = p[2] #reset to -65
        u[1] = u[1] + p[3]
    return du
p = tuple(p)
y0 = [0,0]
tspan = np.linspace(0,100,1000)
sol = odeint(fun, y0, tspan, args=p)
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
plt.plot(tspan,sol[:,0],'k',linewidth = 5)
plt.plot(tspan,sol[:,1],'r',linewidth = 5)
myleg = plt.legend(['v','u'],
                   loc='upper right',prop = {'size':28,'weight':'bold'}, bbox_to_anchor=(1,0.9))
The solution looks like:
Here is the correct solution from Julia, where u1 represents v:
This is the Julia code:
using DifferentialEquations
using Plots
a = 0.02
b = 0.2
c = -65
d = 8
i = 0
p = [a,b,c,d,i]
function fun(du,u,p,t)
    if u[1] <30
        du[1] = (0.04*u[1] + 5)*u[1] + 150 - u[2] - p[5]
        du[2] = p[1]*(p[2]*u[1]-u[2])
    else
        u[1] = p[3]
        u[2] = u[2] + p[4]
    end
end
u0 = [0.0;0.0]
tspan = (0.0,100)
prob = ODEProblem(fun,u0,tspan,p)
tic()
sol = solve(prob,reltol = 1e-8)
toc()
plot(sol)
Recommended solution
This uses events and integrates separately after each discontinuity.
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp
a = 0.02
b = 0.2
c = -65
d = 8
i = 0
p = [a,b,c,d,i]
# Define event function and make it a terminal event
def event(t, u):
    return u[0] - 30
event.terminal = True
# Define differential equation
def fun(t, u):
    du = [(0.04*u[0] + 5)*u[0] + 150 - u[1] - p[4],
          p[0]*(p[1]*u[0]-u[1])]
    return du
u = [0,0]
ts = []
ys = []
t = 0
tend = 100
while True:
    sol = solve_ivp(fun, (t, tend), u, events=event)
    ts.append(sol.t)
    ys.append(sol.y)
    if sol.status == 1: # Event was hit
        # New start time for integration
        t = sol.t[-1]
        # Reset initial state
        u = sol.y[:, -1].copy()
        u[0] = p[2] #reset to -65
        u[1] = u[1] + p[3]
    else:
        break
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# We have to stitch together the separate simulation results for plotting
ax.plot(np.concatenate(ts), np.concatenate(ys, axis=1).T)
myleg = plt.legend(['v','u'])
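If you want to reuse the event loop above, it can be wrapped in a function that also records the spike times. The following is only a sketch under a few assumptions: the names `rhs`, `threshold`, `simulate` and `V_THRESHOLD` are made up for illustration. It sets the event's `direction` attribute so that only upward crossings count, and it passes `max_step` to `solve_ivp` to get a denser, smoother-looking trace between spikes. Both are standard `solve_ivp` features, but the value 0.5 is an arbitrary choice.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Parameters copied from the code above
a, b, c, d, i = 0.02, 0.2, -65, 8, 0
V_THRESHOLD = 30  # spike threshold from the question


def rhs(t, u):
    """Right-hand side between resets (same expressions as fun above)."""
    return [(0.04*u[0] + 5)*u[0] + 150 - u[1] - i,
            a*(b*u[0] - u[1])]


def threshold(t, u):
    """Crosses zero when v reaches the threshold."""
    return u[0] - V_THRESHOLD


threshold.terminal = True   # stop the integration at the crossing
threshold.direction = 1     # only react to upward crossings


def simulate(u0, t_end, max_step=np.inf):
    """Hypothetical helper: integrate piecewise, resetting at each spike."""
    t, u = 0.0, np.asarray(u0, dtype=float)
    ts, ys, spikes = [], [], []
    while t < t_end:
        sol = solve_ivp(rhs, (t, t_end), u, events=threshold,
                        max_step=max_step)
        ts.append(sol.t)
        ys.append(sol.y)
        if sol.status != 1:      # reached t_end (or the solver failed)
            break
        t = sol.t[-1]            # restart from the event time
        spikes.append(t)
        u = sol.y[:, -1].copy()
        u[0] = c                 # reset v to -65
        u[1] += d                # increment u by d
    return np.concatenate(ts), np.concatenate(ys, axis=1), np.array(spikes)


t_all, y_all, spike_times = simulate([0, 0], 100, max_step=0.5)
print("spike times:", spike_times)
```

`t_all` and `y_all.T` can be passed straight to `ax.plot` as in the code above. Each reset shows up as two samples with the same time value, one just before and one just after the jump.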
Minimum change "solution"
It appears as though your approach works just fine with solve_ivp.
Warning: I think that in both Julia and solve_ivp, the correct way to handle this kind of thing is to use events. I believe the approach below relies on an implementation detail: the state vector passed to the function is the same object as the solver's internal state vector, which lets us modify it in place. If it were a copy, this approach wouldn't work. In addition, this approach gives no guarantee that the solver takes steps small enough to land on the exact point where the threshold is reached. Using events is more correct and generalises to other differential equations, which may have lower gradients before the discontinuity.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FormatStrFormatter
from scipy.integrate import solve_ivp
plt.close('all')
a = 0.02
b = 0.2
c = -65
d = 8
i = 0
p = [a,b,c,d,i]
def fun(t, u):
    du = [0,0]
    if u[0] < 30: #Checking if the threshold has been reached
        du[0] = (0.04*u[0] + 5)*u[0] + 150 - u[1] - p[4]
        du[1] = p[0]*(p[1]*u[0]-u[1])
    else:
        u[0] = p[2] #reset to -65
        u[1] = u[1] + p[3]
    return du
y0 = [0,0]
tspan = (0,100)
sol = solve_ivp(fun, tspan, y0)
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
plt.plot(sol.t,sol.y[0, :],'k',linewidth = 5)
plt.plot(sol.t,sol.y[1, :],'r',linewidth = 5)
myleg = plt.legend(['v','u'],loc='upper right',prop = {'size':28,'weight':'bold'}, bbox_to_anchor=(1,0.9))
Result