Skip to content

Commit

Permalink
Merge pull request #103 from the-virtual-brain/constraints
Browse files Browse the repository at this point in the history
Constraints to state variable handling by integrator's schemes
  • Loading branch information
liadomide authored Oct 11, 2019
2 parents 6f2bb78 + 33d2fdf commit 53d6512
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 60 deletions.
4 changes: 2 additions & 2 deletions tvb/basic/traits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@
"""

import core
import traited_interface, traited_interface2
from . import core
from . import traited_interface, traited_interface2
from tvb.basic.profile import TvbProfile

# Add interfaces based on configured parameter on classes
Expand Down
104 changes: 80 additions & 24 deletions tvb/simulator/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ class Integrator(core.Type):
_base_classes = ['Integrator', 'IntegratorStochastic', 'RungeKutta4thOrderDeterministic']

dt = basic.Float(
label = "Integration-step size (ms)",
default = 0.01220703125, #0.015625,
label="Integration-step size (ms)",
default=0.01220703125, #0.015625,
#range = basic.Range(lo= 0.0048828125, hi=0.244140625, step= 0.1, base=2.)
required = True,
doc = """The step size used by the integration routine in ms. This
required=True,
doc="""The step size used by the integration routine in ms. This
should be chosen to be small enough for the integration to be
numerically stable. It is also necessary to consider the desired sample
period of the Monitors, as they are restricted to being integral
Expand All @@ -88,17 +88,27 @@ class Integrator(core.Type):
it is consitent with Monitors using sample periods corresponding to
powers of 2 from 128 to 4096Hz.""")

bounded_state_variable_indices = arrays.IntegerArray(
label="indices of the state variables to be bounded by the integrators "
"within the boundaries in the boundaries' values array",
default=None,
order=-1)

state_variable_boundaries = arrays.FloatArray(
label="The boundary values of the state variables",
default=None,
order=-1)

clamped_state_variable_indices = arrays.IntegerArray(
label = "indices of the state variables to be clamped by the integrators to the values in the clamped_values array",
default = None,
label="indices of the state variables to be clamped by the integrators to the values in the clamped_values array",
default=None,
order=-1)

clamped_state_variable_values = arrays.FloatArray(
label = "The values of the state variables which are clamped ",
default = None,
label="The values of the state variables which are clamped ",
default=None,
order=-1)


def scheme(self, X, dfun, coupling, local_coupling, stimulus):
"""
The scheme of integrator should take a state and provide the next
Expand All @@ -109,13 +119,22 @@ def scheme(self, X, dfun, coupling, local_coupling, stimulus):
msg = "Integrator is a base class; please use a suitable subclass."
raise NotImplementedError(msg)

def bound_state(self, X):
for sv_ind, sv_bounds in \
zip(self.bounded_state_variable_indices,
self.state_variable_boundaries):
if sv_bounds[0] is not None:
X[sv_ind][X[sv_ind] < sv_bounds[0]] = sv_bounds[0]
if sv_bounds[1] is not None:
X[sv_ind][X[sv_ind] > sv_bounds[1]] = sv_bounds[1]

def clamp_state(self, X):
if self.clamped_state_variable_values is not None:
X[self.clamped_state_variable_indices] = self.clamped_state_variable_values
X[self.clamped_state_variable_indices] = self.clamped_state_variable_values

def __str__(self):
return simple_gen_astr(self, 'dt')


class IntegratorStochastic(Integrator):
r"""
The IntegratorStochastic class is a base class for the stochastic
Expand Down Expand Up @@ -177,13 +196,19 @@ def scheme(self, X, dfun, coupling, local_coupling, stimulus):
"""
#import pdb; pdb.set_trace()
m_dx_tn = dfun(X, coupling, local_coupling)
inter = X + self.dt * (m_dx_tn + stimulus)
self.clamp_state(inter)
inter = X + self.dt * (m_dx_tn + stimulus)
if self.state_variable_boundaries is not None:
self.bound_state(inter)
if self.clamped_state_variable_values is not None:
self.clamp_state(inter)

dX = (m_dx_tn + dfun(inter, coupling, local_coupling)) * self.dt / 2.0

X_next = X + dX + self.dt * stimulus
self.clamp_state(X_next)
if self.state_variable_boundaries is not None:
self.bound_state(X_next)
if self.clamped_state_variable_values is not None:
self.clamp_state(X_next)
return X_next


Expand Down Expand Up @@ -221,12 +246,19 @@ def scheme(self, X, dfun, coupling, local_coupling, stimulus):
noise *= noise_gfun

inter = X + self.dt * m_dx_tn + noise + self.dt * stimulus
self.clamp_state(inter)
if self.state_variable_boundaries is not None:
self.bound_state(inter)
if self.clamped_state_variable_values is not None:
self.clamp_state(inter)

dX = (m_dx_tn + dfun(inter, coupling, local_coupling)) * self.dt / 2.0

X_next = X + dX + noise + self.dt * stimulus
self.clamp_state(X_next)
if self.state_variable_boundaries is not None:
self.bound_state(X_next)
if self.clamped_state_variable_values is not None:
self.clamp_state(X_next)

return X_next


Expand Down Expand Up @@ -254,7 +286,10 @@ def scheme(self, X, dfun, coupling, local_coupling, stimulus):
self.dX = dfun(X, coupling, local_coupling)

X_next = X + self.dt * (self.dX + stimulus)
self.clamp_state(X_next)
if self.state_variable_boundaries is not None:
self.bound_state(X_next)
if self.clamped_state_variable_values is not None:
self.clamp_state(X_next)
return X_next


Expand Down Expand Up @@ -288,7 +323,10 @@ def scheme(self, X, dfun, coupling, local_coupling, stimulus):
dX = dfun(X, coupling, local_coupling) * self.dt
noise_gfun = self.noise.gfun(X)
X_next = X + dX + noise_gfun * noise + self.dt * stimulus
self.clamp_state(X_next)
if self.state_variable_boundaries is not None:
self.bound_state(X_next)
if self.clamped_state_variable_values is not None:
self.clamp_state(X_next)
return X_next


Expand Down Expand Up @@ -326,19 +364,31 @@ def scheme(self, X, dfun, coupling, local_coupling=0.0, stimulus=0.0):

k1 = dfun(X, coupling, local_coupling)
inter_k1 = X + dt2 * k1
self.clamp_state(inter_k1)
if self.state_variable_boundaries is not None:
self.bound_state(inter_k1)
if self.clamped_state_variable_values is not None:
self.clamp_state(inter_k1)
k2 = dfun(inter_k1, coupling, local_coupling)
inter_k2 = X + dt2 * k2
self.clamp_state(inter_k2)
if self.state_variable_boundaries is not None:
self.bound_state(inter_k2)
if self.clamped_state_variable_values is not None:
self.clamp_state(inter_k2)
k3 = dfun(inter_k2, coupling, local_coupling)
inter_k3 = X + dt * k3
self.clamp_state(inter_k3)
if self.state_variable_boundaries is not None:
self.bound_state(inter_k3)
if self.clamped_state_variable_values is not None:
self.clamp_state(inter_k3)
k4 = dfun(inter_k3, coupling, local_coupling)

dX = dt6 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

X_next = X + dX + self.dt * stimulus
self.clamp_state(X_next)
if self.state_variable_boundaries is not None:
self.bound_state(X_next)
if self.clamped_state_variable_values is not None:
self.clamp_state(X_next)
return X_next


Expand Down Expand Up @@ -401,15 +451,21 @@ class SciPyODE(SciPyODEBase):

def scheme(self, X, dfun, coupling, local_coupling, stimulus):
X_next = self._apply_ode(X, dfun, coupling, local_coupling, stimulus)
self.clamp_state(X_next)
if self.state_variable_boundaries is not None:
self.bound_state(X_next)
if self.clamped_state_variable_values is not None:
self.clamp_state(X_next)
return X_next

class SciPySDE(SciPyODEBase):

def scheme(self, X, dfun, coupling, local_coupling, stimulus):
X_next = self._apply_ode(X, dfun, coupling, local_coupling, stimulus)
X_next += self.noise.gfun(X) * self.noise.generate(X.shape)
self.clamp_state(X_next)
if self.state_variable_boundaries is not None:
self.bound_state(X_next)
if self.clamped_state_variable_values is not None:
self.clamp_state(X_next)
return X_next

class VODE(SciPyODE, Integrator):
Expand Down
16 changes: 16 additions & 0 deletions tvb/simulator/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Model(core.Type):
_nvar = None
number_of_modes = 1
cvar = None
state_variable_boundaries = None

def _build_observer(self):
template = ("def observe(state):\n"
Expand All @@ -85,6 +86,21 @@ def configure(self):
super(Model, self).configure()
self.update_derived_parameters()
self._build_observer()
# Make sure that if there are any state variable boundaries, ...
if isinstance(self.state_variable_boundaries, dict):
for sv, bounds in self. state_variable_boundaries.items():
try:
# ...the boundaries correspond to model's state variables,
self.state_variables.index(sv)
except:
# TODO: Add the correct type of error and error message
raise
# and for every two sided constraint, the left boundary is lower than the right one
if bounds[0] is not None and bounds[1] is not None:
assert bounds[0] <= bounds[1]
elif self.state_variable_boundaries is not None:
# TODO: Add here a warning or, even, error?
self.state_variable_boundaries = None

@property
def nvar(self):
Expand Down
25 changes: 13 additions & 12 deletions tvb/simulator/models/wong_wang.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,9 @@
@guvectorize([(float64[:],)*11], '(n),(m)' + ',()'*8 + '->(n)', nopython=True)
def _numba_dfun(S, c, a, b, d, g, ts, w, j, io, dx):
"Gufunc for reduced Wong-Wang model equations."

if S[0] < 0.0:
dx[0] = 0.0 - S[0]
elif S[0] > 1.0:
dx[0] = 1.0 - S[0]
else:
x = w[0]*j[0]*S[0] + io[0] + j[0]*c[0]
h = (a[0]*x - b[0]) / (1 - numpy.exp(-d[0]*(a[0]*x - b[0])))
dx[0] = - (S[0] / ts[0]) + (1.0 - S[0]) * h * g[0]
x = w[0]*j[0]*S[0] + io[0] + j[0]*c[0]
h = (a[0]*x - b[0]) / (1 - numpy.exp(-d[0]*(a[0]*x - b[0])))
dx[0] = - (S[0] / ts[0]) + (1.0 - S[0]) * h * g[0]


class ReducedWongWang(ModelNumbaDfun):
Expand Down Expand Up @@ -144,13 +138,21 @@ class ReducedWongWang(ModelNumbaDfun):
order=9
)

# Used for phase-plane axis ranges and to bound random initial() conditions.
state_variable_boundaries = basic.Dict(
label="State Variable boundaries [lo, hi]",
default={"S": numpy.array([0.0, 1.0])},
doc="""The values for each state-variable should be set to encompass
the boundaries of the dynamic range of that state-variable. Set None for one-sided boundaries""",
order=10)

variables_of_interest = basic.Enumerate(
label="Variables watched by Monitors",
options=["S"],
default=["S"],
select_multiple=True,
doc="""default state variables to be monitored""",
order=10)
order=11)

state_variables = ['S']
_nvar = 1
Expand All @@ -172,8 +174,7 @@ def _numpy_dfun(self, state_variables, coupling, local_coupling=0.0):
"""
S = state_variables[0, :]
S[S<0] = 0.
S[S>1] = 1.

c_0 = coupling[0, :]


Expand Down
37 changes: 15 additions & 22 deletions tvb/simulator/models/wong_wang_exc_io_inh_i.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,17 @@ def _numba_dfun(S, c, ae, be, de, ge, te, wp, we, jn, ai, bi, di, gi, ti, wi, ji

cc = g[0]*jn[0]*c[0]

if S[0] < 0.0:
S_e = 0.0 # - S[0] # TODO: clarify the boundary to be reflective or saturated!!!
elif S[0] > 1.0:
S_e = 1.0 # - S[0] # TODO: clarify the boundary to be reflective or saturated!!!
else:
S_e = S[0]

if S[1] < 0.0:
S_i = 0.0 # - S[1] TODO: clarify the boundary to be reflective or saturated!!!
elif S[1] > 1.0:
S_i = 1.0 # - S[1] TODO: clarify the boundary to be reflective or saturated!!!
else:
S_i = S[1]

jnSe = jn[0]*S_e
x = wp[0]*jnSe - ji[0]*S_i + we[0]*io[0] + cc
jnSe = jn[0]*S[0]

x = wp[0]*jnSe - ji[0]*S[1] + we[0]*io[0] + cc
x = ae[0]*x - be[0]
h = x / (1 - numpy.exp(-de[0]*x))
dx[0] = - (S_e / te[0]) + (1.0 - S_e) * h * ge[0]
dx[0] = - (S[0] / te[0]) + (1.0 - S[0]) * h * ge[0]

x = jnSe - S_i + wi[0]*io[0] + l[0]*cc
x = jnSe - S[1] + wi[0]*io[0] + l[0]*cc
x = ai[0]*x - bi[0]
h = x / (1 - numpy.exp(-di[0]*x))
dx[1] = - (S_i / ti[0]) + h * gi[0]
dx[1] = - (S[1] / ti[0]) + h * gi[0]


class ReducedWongWangExcIOInhI(ModelNumbaDfun):
Expand Down Expand Up @@ -234,13 +221,21 @@ class ReducedWongWangExcIOInhI(ModelNumbaDfun):
order=22
)

# Used for phase-plane axis ranges and to bound random initial() conditions.
state_variable_boundaries = basic.Dict(
label="State Variable boundaries [lo, hi]",
default={"S_e": numpy.array([0.0, 1.0]), "S_i": numpy.array([0.0, 1.0])},
doc="""The values for each state-variable should be set to encompass
the boundaries of the dynamic range of that state-variable. Set None for one-sided boundaries""",
order=23)

variables_of_interest = basic.Enumerate(
label="Variables watched by Monitors",
options=['S_e', 'S_i'],
default=['S_e', 'S_i'],
select_multiple=True,
doc="""default state variables to be monitored""",
order=23)
order=24)

state_variables = ['S_e', 'S_i']
_nvar = 2
Expand All @@ -266,8 +261,6 @@ def _numpy_dfun(self, state_variables, coupling, local_coupling=0.0):
"""
S = state_variables[:, :]
S[S < 0] = 0.
S[S > 1] = 1.

c_0 = coupling[0, :]

Expand Down
14 changes: 14 additions & 0 deletions tvb/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ def preconfigure(self):
self.coupling.configure()
self.model.configure()
self.integrator.configure()
if isinstance(self.model.state_variable_constraint, dict):
indices = []
boundaries = []
for sv, sv_bounds in self.model.state_variable_constraint.items():
indices.append(self.model.state_variables.index(sv))
boundaries.append(sv_bounds)
sort_inds = numpy.argsort(indices)
self.integrator.constraint_state_variable_indices = numpy.array(indices)[sort_inds]
self.integrator.constraint_state_variable_boundaries = numpy.array(boundaries)[sort_inds]
else:
self.integrator.constraint_state_variable_indices = None
self.integrator.constraint_state_variable_boundaries = None
# monitors needs to be a list or tuple, even if there is only one...
if not isinstance(self.monitors, (list, tuple)):
self.monitors = [self.monitors]
Expand Down Expand Up @@ -465,6 +477,8 @@ def _configure_history(self, initial_conditions):
history[:ic_shape[0], :, :, :] = initial_conditions
history = numpy.roll(history, shift, axis=0)
self.current_step += ic_shape[0] - 1
if self.integrator.state_variable_boundaries is not None:
self.integrator.bound_state(numpy.swapaxes(history, 0, 1))
LOG.info('Final initial history shape is %r', history.shape)
# create initial state from history
self.current_state = history[self.current_step % self.horizon].copy()
Expand Down
Loading

0 comments on commit 53d6512

Please sign in to comment.