Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions desolver/backend/common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import numpy

__all__ = [
'linear_algebra_exceptions',
'e',
'euler_gamma',
'pi',
]

linear_algebra_exceptions = [numpy.linalg.LinAlgError]

# Constants
e = 2.7182818284590452353602874713526624977572470936999595749669676277240766303535475945713821785251664274
euler_gamma = 0.5772156649015328606065120900824024310421593359399235988057672348848677267776646709369470632917467495
Expand Down
1 change: 1 addition & 0 deletions desolver/backend/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import autoray

linear_algebra_exceptions.append(torch._C._LinAlgError)

def __solve_linear_system(A, b, sparse=False):
__A = A
Expand Down
84 changes: 77 additions & 7 deletions desolver/differential_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from desolver import utilities as deutil

import numpy as np
import inspect
from scipy.integrate._ivp.ivp import OdeResult

CubicHermiteInterp = deutil.interpolation.CubicHermiteInterp
root_finder = deutil.optimizer.brentsrootvec
Expand All @@ -16,7 +18,8 @@
__all__ = [
'DiffRHS',
'rhs_prettifier',
'OdeSystem'
'OdeSystem',
'solve_ivp'
]

StateTuple = collections.namedtuple('StateTuple', ['t', 'y', 'event'])
Expand Down Expand Up @@ -539,7 +542,7 @@ def __init__(self, equ_rhs, y0, t=(0, 1), dense_output=False, dt=1.0, rtol=None,
self.__allocate_soln_space(self.__alloc_space_steps(self.tf))
self.__fix_dt_dir(self.tf, self.t0)
self.__events = []
self.initialise_integrator(preserve_states=False)
self.initialise_integrator(preserve_states=False)

@property
def sol(self):
Expand Down Expand Up @@ -978,7 +981,7 @@ def integrate(self, t=None, callback=None, eta=False, events=None):
behaviour of the system is highly unreliable. This could be due to numerical issues.

"""
if t:
if t is not None:
tf = t
else:
tf = self.tf
Expand Down Expand Up @@ -1023,11 +1026,15 @@ def integrate(self, t=None, callback=None, eta=False, events=None):
end_int = False
self.__allocate_soln_space(total_steps)
try:
while (implicit_integration or self.dt != 0 and D.ar_numpy.abs(tf - self.__t[self.counter]) >= D.tol_epsilon(self.__y[self.counter].dtype)) and not end_int:
while (implicit_integration or (self.dt != 0 and D.ar_numpy.abs(tf - self.__t[self.counter]) >= D.tol_epsilon(self.__y[self.counter].dtype))) and not end_int:
if not implicit_integration and D.ar_numpy.abs(self.dt + self.__t[self.counter]) > D.ar_numpy.abs(tf):
self.dt = (tf - self.__t[self.counter])
self.dt, (dTime, dState) = self.integrator(self.equ_rhs, self.__t[self.counter], self.__y[self.counter],
self.constants, timestep=self.dt)
is_final_step = True
dt = (tf - self.__t[self.counter])
else:
is_final_step = False
dt = self.dt
new_dt, (dTime, dState) = self.integrator(self.equ_rhs, self.__t[self.counter], self.__y[self.counter],
self.constants, timestep=dt)

if self.counter + 1 >= len(self.__y):
total_steps = self.__alloc_space_steps(tf - dTime) + 1
Expand Down Expand Up @@ -1092,6 +1099,9 @@ def integrate(self, t=None, callback=None, eta=False, events=None):
self.__sol.remove_interpolant(0)

steps += 1

if not is_final_step:
self.dt = new_dt

for i in callback:
i(self)
Expand Down Expand Up @@ -1127,6 +1137,7 @@ def __repr__(self):
return "\n".join([
"""{:>10}: {:<128}""".format("message", self.integration_status),
"""{:>10}: {:<128}""".format("nfev", str(self.nfev)),
"""{:>10}: {:<128}""".format("njev", str(self.njev)),
"""{:>10}: {:<128}""".format("sol", str(self.sol)),
"""{:>10}: {:<128}""".format("t0", str(self.t0)),
"""{:>10}: {:<128}""".format("tf", str(self.tf)),
Expand All @@ -1144,6 +1155,7 @@ def _repr_markdown_(self):
{:>10}: {:<128}
{:>10}: {:<128}
{:>10}: {:<128}
{:>10}: {:<128}
{:>10}:
```
{}
Expand All @@ -1155,6 +1167,7 @@ def _repr_markdown_(self):
""".format(
"message", self.integration_status,
"nfev", str(self.nfev),
"njev", str(self.njev),
"sol", str(self.sol),
"t0", str(self.t0),
"tf", str(self.tf),
Expand Down Expand Up @@ -1217,3 +1230,60 @@ def __getitem__(self, index):

def __len__(self):
return self.counter + 1


def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False,
events=None, vectorized=False, args=None, **options):
"""
Drop-in replacement for `scipy.integrate.solve_ivp`, provides a functional
interface to the `desolver` integration routines.
"""

constants = None
if args is not None:
fn = fun
while isinstance(fn, DiffRHS):
fn = fn.rhs
fn_args_kwargs = inspect.getfullargspec(fn)
constants = {key:value for key,value in zip(fn_args_kwargs[0][2:], args)}

max_step = options.get("max_step", np.inf)
min_step = options.get("min_step", 0.0)

initial_dt = options.get('first_step', 1.0)
initial_dt = D.ar_numpy.minimum(initial_dt, max_step)
initial_dt = D.ar_numpy.maximum(initial_dt, min_step)

ode_system = OdeSystem(equ_rhs=fun, y0=y0, t=t_span, dense_output=dense_output, dt=initial_dt,
atol=options.get('atol', None), rtol=options.get('rtol', None), constants=constants)

ode_system.method = method
callbacks = list(options.get("callbacks", []))
if "max_step" in options or "min_step" in options:
def __step_cb(ode_sys):
ode_sys.dt = D.ar_numpy.clip(ode_sys.dt, min=min_step, max=max_step)
callbacks.append(__step_cb)

integration_options = dict(callback=callbacks, events=events, eta=options.get("show_prog_bar", False))
if t_eval is None:
ode_system.integrate(**integration_options)
t_res = ode_system.t
y_res = D.ar_numpy.transpose(ode_system.y, axes=[*range(1, len(ode_system.y.shape)), 0])
else:
t_eval = D.ar_numpy.sort(t_eval)
if t_eval[0] < t_span[0] or t_eval[-1] > t_span[1]:
raise ValueError(f"Expected `t_eval` to be in the range [{t_span[0]}, {t_span[1]}]")
t_res = []
y_res = []
for t in t_eval:
ode_system.integrate(t=t, **integration_options)
t_res.append(ode_system[-1].t)
y_res.append(ode_system[-1].y)
t_res = D.ar_numpy.stack(t_res, axis=0)
y_res = D.ar_numpy.stack(y_res, axis=-1)

return OdeResult(t=t_res, y=y_res, sol=ode_system.sol, t_events=ode_system.events,
y_events=ode_system.events, nfev=ode_system.nfev, njev=ode_system.njev,
status=ode_system.integration_status, message=ode_system.integration_status,
success=ode_system.success, ode_system=ode_system)

11 changes: 5 additions & 6 deletions desolver/integrators/integrator_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,12 @@ def update_timestep(self, ignore_custom_adaptation=False):
dState = self.solver_dict['dState']
order = self.solver_dict['order']
if "system_scaling" in self.solver_dict:
system_scaling = 0.8 * self.solver_dict["system_scaling"]
system_scaling = 0.2 * D.ar_numpy.maximum(D.ar_numpy.abs(initial_state), D.ar_numpy.abs(dState / timestep))
self.solver_dict["system_scaling"] = 0.8 * self.solver_dict["system_scaling"] + 0.2 * D.ar_numpy.maximum(D.ar_numpy.abs(initial_state), D.ar_numpy.abs(dState / timestep))
else:
system_scaling = D.ar_numpy.maximum(D.ar_numpy.abs(initial_state), D.ar_numpy.abs(dState / timestep))
self.solver_dict["system_scaling"] = system_scaling
total_error_tolerance = (atol + rtol * system_scaling)
epsilon_current = D.ar_numpy.reciprocal(D.ar_numpy.linalg.norm(diff / total_error_tolerance))
self.solver_dict["system_scaling"] = D.ar_numpy.maximum(D.ar_numpy.abs(initial_state), D.ar_numpy.abs(dState / timestep))
total_error_tolerance = (atol + rtol * self.solver_dict["system_scaling"])
with D.numpy.errstate(divide='ignore'):
epsilon_current = D.ar_numpy.reciprocal(D.ar_numpy.linalg.norm(diff / total_error_tolerance))
if "epsilon_last" in self.solver_dict:
epsilon_last = self.solver_dict["epsilon_last"]
else:
Expand Down
Loading