I am using PyMC3 to calculate something I won't get into here, but you can get the idea from this link if interested. The '2-lambdas' case is basically a switch function, which needs to be compiled to a Theano function to avoid dtype errors, and looks like this:
```python
import theano
from numpy import zeros
from theano.tensor import lscalar, dscalar, lvector, dvector, argsort


# num_observations is a global defined elsewhere in the model
@theano.compile.ops.as_op(itypes=[lscalar, dscalar, dscalar], otypes=[dvector])
def lambda_2_distributions(tau, lambda_1, lambda_2):
    """
    Return values of `lambda_` for each observation based on the
    transition value `tau`.
    """
    out = zeros(num_observations)
    out[:tau] = lambda_1  # lambda before tau is lambda1
    out[tau:] = lambda_2  # lambda after (and including) tau is lambda2
    return out
```
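For comparison, the two-lambda switch can also be written without slice assignment, by comparing an index vector against `tau`. This NumPy sketch passes `num_observations` in as an argument (an assumption; the code above reads it from a global), and `lambda_2_vectorized` is a hypothetical name. The Theano analogue would be roughly `theano.tensor.switch(theano.tensor.arange(num_observations) < tau, lambda_1, lambda_2)`, which keeps everything inside the graph; treat that one-liner as untested.

```python
import numpy as np


def lambda_2_vectorized(tau, lambda_1, lambda_2, num_observations):
    """Rate per observation: lambda_1 before index tau, lambda_2 from tau on."""
    idx = np.arange(num_observations)
    # Elementwise switch: same result as the two slice assignments above.
    return np.where(idx < tau, lambda_1, lambda_2)


print(lambda_2_vectorized(3, 1.0, 5.0, 6))  # [1. 1. 1. 5. 5. 5.]
```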
I am trying to generalize this to 'n-lambdas', where `taus.shape[0] = lambdas.shape[0] - 1`, but I can only come up with this horribly slow NumPy implementation:
```python
@theano.compile.ops.as_op(itypes=[lvector, dvector], otypes=[dvector])
def lambda_n_distributions(taus, lambdas):
    out = zeros(num_observations)
    # argsort builds a new Theano graph and .eval() compiles it on every call (slow)
    np_tau_indices = argsort(taus).eval()
    num_taus = taus.shape[0]
    for t in range(num_taus):
        start = taus[np_tau_indices[t]]
        if t == 0:
            out[:start] = lambdas[0]  # before the first switch point
        if t == num_taus - 1:
            out[start:] = lambdas[t + 1]  # from the last switch point on
        else:
            out[start: taus[np_tau_indices[t + 1]]] = lambdas[t + 1]
    return out
```
Any ideas on how to speed this up using pure Theano (avoiding the call to `.eval()`)? It's been a few years since I've used it, so I don't know the right approach.
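A loop-free version is possible because each observation's level is simply the number of switch points at or before its index. The NumPy sketch below (the name `lambda_n_vectorized` and the explicit `num_observations` argument are illustrative assumptions) counts that with a broadcast comparison and then indexes `lambdas` with the count. Indices before the smallest tau get `lambdas[0]`, and indices from the largest tau on get `lambdas[-1]`, matching the slicing version.

```python
import numpy as np


def lambda_n_vectorized(taus, lambdas, num_observations):
    """Piecewise-constant rate with len(taus) switch points and len(lambdas) levels.

    Observation i gets lambdas[k], where k is the number of switch points <= i.
    """
    taus = np.asarray(taus)
    lambdas = np.asarray(lambdas, dtype=float)
    if taus.shape[0] != lambdas.shape[0] - 1:
        raise ValueError("need exactly one more lambda than switch points")
    idx = np.arange(num_observations)
    # (num_observations, num_taus) boolean matrix: is index i at or past switch j?
    past_switch = idx[:, None] >= np.sort(taus)[None, :]
    segment = past_switch.sum(axis=1)  # which level each observation falls in
    return lambdas[segment]


print(lambda_n_vectorized([4, 2], [1.0, 2.0, 3.0], 6))  # [1. 1. 2. 2. 3. 3.]
```

The same comparison should translate to Theano with `theano.tensor.sort`, `theano.tensor.arange`, `dimshuffle` and `.sum(axis=1)`, followed by integer indexing `lambdas[segment]`, which would remove both the Python loop and `.eval()`; that translation is untested here. It is still a hard switch, though.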
Using a hard switch function is not recommended: it breaks the nice geometry of the parameter space, and the switch points themselves are discrete, which makes sampling with modern gradient-based samplers like NUTS difficult.
Instead, you can try to model it using a continuous relaxation of the switch function. The main idea is to model the rate before the first switch point as a baseline and add the output of a logistic function centered at each switch point.
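A minimal NumPy sketch of that relaxation, assuming the baseline-plus-increments form described above: the rate starts at `baseline`, and each switch point `taus[k]` adds `deltas[k]` scaled by a logistic step. The names `relaxed_rate`, `deltas` and `steepness` are hypothetical and not taken from the linked gist. Larger `steepness` values approach the hard switch, while the rate stays differentiable in `taus`, which is what a gradient-based sampler needs. In a PyMC3 model the sigmoid would be applied to Theano tensors instead (for example `theano.tensor.nnet.sigmoid`).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def relaxed_rate(t, baseline, deltas, taus, steepness=10.0):
    """Smooth piecewise rate: baseline plus one logistic step per switch point.

    steepness is a hypothetical tuning constant; larger values approach
    the hard switch.
    """
    t = np.asarray(t, dtype=float)[:, None]          # (num_observations, 1)
    taus = np.asarray(taus, dtype=float)[None, :]    # (1, num_taus)
    deltas = np.asarray(deltas, dtype=float)[None, :]
    # Column k contributes ~0 well before taus[k] and ~deltas[k] well after it.
    steps = deltas * sigmoid(steepness * (t - taus))
    return baseline + steps.sum(axis=1)


t = np.arange(10)
print(np.round(relaxed_rate(t, 1.0, [2.0, -1.5], [3.5, 6.5], steepness=50.0), 3))
```

With switch points placed at half-integers, the midpoint of each logistic step falls between observations, so the printed rate is close to 1.0 up to t = 3, 3.0 for t = 4 to 6, and 1.5 afterwards.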
I used a few other tricks here as well, for example a composite transformation that is not in the PyMC3 code base yet. You can find the full code here: https://gist.github.com/junpenglao/f7098c8e0d6eadc61b3e1bc8525dd90d
If you have more questions, please post them to https://discourse.pymc.io with your model and (simulated) data. I check and answer questions on the PyMC3 Discourse much more regularly.