From 0ee3e29c912d9af54d9bdea4418e14806fff758c Mon Sep 17 00:00:00 2001 From: Ekin Ozturk Date: Thu, 6 Mar 2025 00:42:40 +0000 Subject: [PATCH 1/4] Added new `solve_ivp` interface to match scipy's interface, additionally included further performance improvements --- desolver/differential_system.py | 52 +++++++++- desolver/integrators/integrator_template.py | 11 +-- desolver/integrators/integrator_types.py | 103 ++++++++++++-------- desolver/integrators/utilities.py | 18 ++-- desolver/tests/test_differential_system.py | 61 ++++++++++-- desolver/tests/test_event_detection.py | 1 - desolver/utilities/optimizer.py | 26 +++-- 7 files changed, 199 insertions(+), 73 deletions(-) diff --git a/desolver/differential_system.py b/desolver/differential_system.py index 7fbb539..5a367d9 100644 --- a/desolver/differential_system.py +++ b/desolver/differential_system.py @@ -8,6 +8,8 @@ from desolver import utilities as deutil import numpy as np +import inspect +from scipy.integrate._ivp.ivp import OdeResult CubicHermiteInterp = deutil.interpolation.CubicHermiteInterp root_finder = deutil.optimizer.brentsrootvec @@ -16,7 +18,8 @@ __all__ = [ 'DiffRHS', 'rhs_prettifier', - 'OdeSystem' + 'OdeSystem', + 'solve_ivp' ] StateTuple = collections.namedtuple('StateTuple', ['t', 'y', 'event']) @@ -539,7 +542,7 @@ def __init__(self, equ_rhs, y0, t=(0, 1), dense_output=False, dt=1.0, rtol=None, self.__allocate_soln_space(self.__alloc_space_steps(self.tf)) self.__fix_dt_dir(self.tf, self.t0) self.__events = [] - self.initialise_integrator(preserve_states=False) + self.initialise_integrator(preserve_states=False) @property def sol(self): @@ -1217,3 +1220,48 @@ def __getitem__(self, index): def __len__(self): return self.counter + 1 + + +def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False, + events=None, vectorized=False, args=None, **options): + """ + Drop-in replacement for `scipy.integrate.solve_ivp`, provides a functional + interface to the `desolver` integration routines. + """ + + constants = None + if args is not None: + fn = fun + while isinstance(fn, DiffRHS): + fn = fn.rhs + fn_args_kwargs = inspect.getfullargspec(fn) + constants = {key:value for key,value in zip(fn_args_kwargs[0][2:], args)} + + ode_system = OdeSystem(equ_rhs=fun, y0=y0, t=t_span, dense_output=dense_output, dt=options.get('first_step', 1.0), + atol=options.get('atol', None), rtol=options.get('rtol', None), constants=constants) + + ode_system.method = method + callbacks = list(options.get("callbacks", [])) + if "max_step" in options or "min_step" in options: + max_step = options.get("max_step", np.inf) + min_step = options.get("min_step", 0.0) + callbacks.append(lambda ode_sys: D.ar_numpy.clip(ode_sys, min=min_step, max=max_step)) + + integration_options = dict(callback=callbacks, events=events, eta=options.get("show_prog_bar", False)) + if t_eval is None: + ode_system.integrate(**integration_options) + else: + t_eval = D.ar_numpy.sort(t_eval) + if t_eval[0] < t_span[0] or t_eval[-1] > t_span[1]: + raise ValueError(f"Expected `t_eval` to be in the range [{t_span[0]}, {t_span[1]}]") + for t in t_eval: + ode_system.integrate(t=t, **integration_options) + if t_span[1] > t_eval[-1]: + ode_system.integrate(t=t_span[1], **integration_options) + + yres = ode_system.y + return OdeResult(t=ode_system.t, y=D.ar_numpy.transpose(yres, axes=[*range(1, len(yres.shape)), 0]), sol=ode_system.sol, t_events=ode_system.events, + y_events=ode_system.events, nfev=ode_system.nfev, njev=ode_system.njev, + status=ode_system.integration_status, message=ode_system.integration_status, + success=ode_system.success, ode_system=ode_system) + \ No newline at end of file diff --git a/desolver/integrators/integrator_template.py b/desolver/integrators/integrator_template.py index c609791..f00d617 100644 --- a/desolver/integrators/integrator_template.py +++ b/desolver/integrators/integrator_template.py @@ -55,13 +55,12 @@ def update_timestep(self, ignore_custom_adaptation=False): dState = self.solver_dict['dState'] order = self.solver_dict['order'] if "system_scaling" in self.solver_dict: - system_scaling = 0.8 * self.solver_dict["system_scaling"] - system_scaling = 0.2 * D.ar_numpy.maximum(D.ar_numpy.abs(initial_state), D.ar_numpy.abs(dState / timestep)) + self.solver_dict["system_scaling"] = 0.8 * self.solver_dict["system_scaling"] + 0.2 * D.ar_numpy.maximum(D.ar_numpy.abs(initial_state), D.ar_numpy.abs(dState / timestep)) else: - system_scaling = D.ar_numpy.maximum(D.ar_numpy.abs(initial_state), D.ar_numpy.abs(dState / timestep)) - self.solver_dict["system_scaling"] = system_scaling - total_error_tolerance = (atol + rtol * system_scaling) - epsilon_current = D.ar_numpy.reciprocal(D.ar_numpy.linalg.norm(diff / total_error_tolerance)) + self.solver_dict["system_scaling"] = D.ar_numpy.maximum(D.ar_numpy.abs(initial_state), D.ar_numpy.abs(dState / timestep)) + total_error_tolerance = (atol + rtol * self.solver_dict["system_scaling"]) + with D.numpy.errstate(divide='ignore'): + epsilon_current = D.ar_numpy.reciprocal(D.ar_numpy.linalg.norm(diff / total_error_tolerance)) if "epsilon_last" in self.solver_dict: epsilon_last = self.solver_dict["epsilon_last"] else: diff --git a/desolver/integrators/integrator_types.py b/desolver/integrators/integrator_types.py index 183f53a..289abe2 100644 --- a/desolver/integrators/integrator_types.py +++ b/desolver/integrators/integrator_types.py @@ -4,7 +4,7 @@ from desolver import exception_types from desolver.integrators import utilities as integrator_utilities from desolver.integrators import components -import math +from desolver.utilities.optimizer import broyden_update_jac import abc __all__ = [ @@ -133,6 +133,7 @@ def __init__(self, sys_dim, dtype, rtol=None, atol=None, device=None): self._explicit = all([(self.tableau_intermediate[col, col + 1:] == 0.0).all() for col in range(self.tableau_intermediate.shape[0])]) self._explicit_stages = [col for col in range(self.stages) if D.ar_numpy.all(self.tableau_intermediate[col, col + 1:] == 0.0)] self._implicit_stages = [col for col in range(self.stages) if D.ar_numpy.any(self.tableau_intermediate[col, col + 1:] != 0.0)] + self._requires_high_precision = False solver_dict_preserved = dict(safety_factor=0.8, order=self.order, atol=self.atol, rtol=self.rtol, redo_count=0) self.solver_dict = dict() @@ -145,11 +146,12 @@ def __init__(self, sys_dim, dtype, rtol=None, atol=None, device=None): )) if not self._explicit: solver_dict_preserved.update(dict( - tau0=2, tau1=2, niter0=0, niter1=0 + tau0=2, tau1=2, niter0=0, niter1=0, newton_prec0=0.0, newton_prec1=0.0 )) self.solver_dict.update(solver_dict_preserved) self.adaptation_fn = integrator_utilities.implicit_aware_update_timestep self.__jac_eye = None + self.__rhs_jac = None self.solver_dict_keep_keys = set(solver_dict_preserved.keys()) def __call__(self, rhs, initial_time, initial_state, constants, timestep): @@ -158,13 +160,24 @@ def __call__(self, rhs, initial_time, initial_state, constants, timestep): self.initial_time = D.ar_numpy.copy(initial_time) self.initial_rhs = None - if self.is_fsal and self.final_rhs is not None: - self.stage_values[...,0] = self.final_rhs + if self.final_rhs is not None: self.initial_rhs = self.final_rhs - self.final_rhs = None + if self.is_fsal: + self.stage_values[...,0] = self.final_rhs + else: + self.initial_rhs = rhs(initial_time, initial_state, **constants) - self.step(rhs=rhs, initial_time=initial_time, initial_state=initial_state, - constants=constants, timestep=timestep) + if self.is_implicit and self.__rhs_jac is None: + self.__rhs_jac = rhs.jac(initial_time, initial_state, **constants) + + try: + timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, + timestep) + except (D.numpy.linalg.LinAlgError, ValueError): + self._requires_high_precision = True + timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, + timestep) + self._requires_high_precision = False if self.is_adaptive or self.is_implicit: self.solver_dict['redo_count'] = 0 @@ -176,15 +189,26 @@ def __call__(self, rhs, initial_time, initial_state, constants, timestep): self.solver_dict['rtol'] = self.rtol self.solver_dict['dState'] = self.dState timestep, redo_step = self.update_timestep() + if self.is_implicit and not self.solver_dict.get("newton_iteration_success"): + redo_step = True + timestep = timestep * 0.8 if redo_step: for _ in range(64): self.solver_dict['redo_count'] += 1 - timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, - timestep) + try: + timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, + timestep) + except (D.numpy.linalg.LinAlgError, ValueError): + self._requires_high_precision = True + timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, + timestep) self.solver_dict['diff'] = timestep * self.get_error_estimate() self.solver_dict['timestep'] = self.dTime self.solver_dict['dState'] = self.dState timestep, redo_step = self.update_timestep() + if self.is_implicit and not self.solver_dict.get("newton_iteration_success"): + redo_step = True + timestep = timestep * 0.8 if not redo_step: break if redo_step: @@ -192,13 +216,9 @@ def __call__(self, rhs, initial_time, initial_state, constants, timestep): "Failed to integrate system from {} to {} ".format(self.dTime, self.dTime + timestep) + "to the tolerances required: rtol={}, atol={}".format(self.rtol, self.atol) ) - - if self.initial_rhs is None: - self.initial_rhs = rhs(initial_time, initial_state, **constants) - - if self.final_rhs is None: - self.final_rhs = rhs(initial_time + self.dTime, initial_state + self.dState, **constants) - + + self._requires_high_precision = False + return timestep, (self.dTime, self.dState) @@ -213,22 +233,23 @@ def algebraic_system(self, next_state, rhs, initial_time, initial_state, timeste return __states def algebraic_system_jacobian(self, next_state, rhs, initial_time, initial_state, timestep, constants): - __aux_states = D.ar_numpy.reshape(next_state, self.stage_values.shape) + if self._requires_high_precision: + __aux_states = D.ar_numpy.reshape(next_state, self.stage_values.shape) __step = self.numel if self.__jac_eye is None: self.__jac_eye = D.ar_numpy.eye(self.tableau_intermediate.shape[0] * __step, **self.array_constructor_kwargs) self.__jac = D.ar_numpy.copy(self.__jac_eye) D.ar_numpy.copyto(self.__jac, self.__jac_eye) - __rhs_jac = D.ar_numpy.stack([ - rhs.jac(initial_time + tbl[0] * timestep, - initial_state + timestep * D.ar_numpy.sum(tbl[1:] * __aux_states, axis=-1), - **constants) - for tbl in self.tableau_intermediate - ]) for idx in range(0, self.__jac.shape[0], __step): + if self._requires_high_precision: + tbl = self.tableau_intermediate[idx // __step] + jac_block = rhs.jac(initial_time + tbl[0] * timestep, + initial_state + timestep * D.ar_numpy.sum(tbl[1:] * __aux_states, axis=-1), + **constants).reshape(__step, __step) + else: + jac_block = self.__rhs_jac.reshape(__step, __step) for jdx in range(0, self.__jac.shape[1], __step): - self.__jac[idx:idx + __step, jdx:jdx + __step] -= timestep * self.tableau_intermediate[ - idx // __step, 1 + jdx // __step] * __rhs_jac[idx // __step].reshape(__step, __step) + self.__jac[idx:idx + __step, jdx:jdx + __step] -= timestep * self.tableau_intermediate[idx // __step, 1 + jdx // __step] * jac_block __jac = self.__jac if self.__jac.shape[0] == 1 and self.__jac.shape[1] == 1: __jac = D.ar_numpy.reshape(__jac, tuple()) @@ -247,39 +268,39 @@ def step(self, rhs, initial_time, initial_state, constants, timestep): constants ) - if D.ar_numpy.sum(self.stage_values[...,0]) == 0.0: - self.initial_rhs = self.stage_values[...,0] - else: - self.initial_rhs = None - if self.is_implicit: initial_guess = self.stage_values - - desired_tol = D.ar_numpy.max(D.ar_numpy.abs(self.atol * 1e-1 + D.ar_numpy.max(D.ar_numpy.abs(self.rtol * 1e-1 * initial_state)))) - aux_root, (success, num_iter, _, _, prec) = \ + if self.__rhs_jac is None: + self.__rhs_jac = rhs.jac(initial_time, initial_state, **constants) + desired_tol = D.ar_numpy.max(D.ar_numpy.abs(self.atol + D.ar_numpy.max(D.ar_numpy.abs(self.rtol * initial_state)))) * 0.5 + aux_root, (self.solver_dict["newton_iteration_success"], num_iter, _, _, prec) = \ utilities.optimizer.nonlinear_roots( self.algebraic_system, initial_guess, jac=self.algebraic_system_jacobian, verbose=False, - tol=desired_tol, maxiter=8, + tol=desired_tol, maxiter=32, additional_args=(rhs, initial_time, initial_state, timestep, constants)) - - if not success and prec > desired_tol: - raise exception_types.FailedToMeetTolerances( - "Step size too large, cannot solve system to the " - "tolerances required: achieved = {}, desired = {}, iter = {}".format(prec, desired_tol, num_iter)) + + self.solver_dict["newton_iteration_success"] = self.solver_dict["newton_iteration_success"] and prec < desired_tol + if not self.solver_dict["newton_iteration_success"]: + self.__rhs_jac = None self.solver_dict.update(dict( tau0=self.solver_dict['tau1'], tau1=timestep, - niter0=self.solver_dict['niter1'], niter1=num_iter + niter0=self.solver_dict['niter1'], niter1=num_iter, + newton_prec0=self.solver_dict['newton_prec1'], newton_prec1=prec )) self.stage_values = D.ar_numpy.reshape(aux_root, self.stage_values.shape) + self.dTime = D.ar_numpy.copy(timestep) if self.is_fsal and self.is_explicit: self.dState = intermediate_dstate self.final_rhs = intermediate_rhs else: self.dState = timestep * D.ar_numpy.sum(self.stage_values * self.tableau_final[0, 1:], axis=-1) - self.dTime = D.ar_numpy.copy(timestep) + self.final_rhs = rhs(initial_time + self.dTime, initial_state + self.dState, **constants) + + if self.is_implicit and self.__rhs_jac is not None: + self.__rhs_jac = broyden_update_jac(self.__rhs_jac.reshape(self.numel, self.numel), self.dState.reshape(self.numel, 1), (self.final_rhs - self.initial_rhs).reshape(self.numel, 1)).reshape(self.__rhs_jac.shape) return timestep, (self.dTime, self.dState) diff --git a/desolver/integrators/utilities.py b/desolver/integrators/utilities.py index 715ba3e..698c797 100644 --- a/desolver/integrators/utilities.py +++ b/desolver/integrators/utilities.py @@ -16,12 +16,18 @@ def implicit_aware_update_timestep(integrator: TableauIntegrator): dCTk = dnCTk / ddCTk else: dCTk = D.ar_numpy.zeros_like(integrator.solver_dict['timestep']) - tau2 = timestep * D.ar_numpy.exp(-safety_factor * dCTk) + tau2 = D.ar_numpy.exp(-safety_factor * dCTk) else: - tau2 = timestep - if tau2 < timestep_from_error: - return tau2, redo_step - else: - return timestep_from_error, redo_step + tau2 = D.numpy.inf + total_error_tolerance = integrator.solver_dict['atol'] + integrator.solver_dict['rtol'] + with D.numpy.errstate(divide='ignore'): + epsilon_current = D.ar_numpy.reciprocal(D.ar_numpy.linalg.norm(integrator.solver_dict['newton_prec1'] / total_error_tolerance)) + tau3 = D.ar_numpy.where(epsilon_current > 0.0, epsilon_current ** (1.0 / integrator.solver_dict['order']), 1.0) + if integrator.solver_dict['newton_prec0'] > 0.0: + with D.numpy.errstate(divide='ignore'): + epsilon_last = D.ar_numpy.reciprocal(D.ar_numpy.linalg.norm(integrator.solver_dict['newton_prec0'] / total_error_tolerance)) + tau3 = tau3*D.ar_numpy.where(epsilon_last > 0.0, epsilon_last ** (1.0 / integrator.solver_dict['order']), 1.0) + tau = (1 + D.ar_numpy.arctan(D.ar_numpy.minimum(tau2, tau3) - 1)) + return D.ar_numpy.minimum(tau * timestep, timestep_from_error), redo_step else: return timestep_from_error, redo_step diff --git a/desolver/tests/test_differential_system.py b/desolver/tests/test_differential_system.py index e8cff30..8206631 100644 --- a/desolver/tests/test_differential_system.py +++ b/desolver/tests/test_differential_system.py @@ -147,20 +147,26 @@ def test_integration_and_representation_with_jac(dtype_var, backend_var, integra if a.integrator.order <= 4: pytest.skip(f"{a.integrator}'s order is too low") + + if D.ar_numpy.finfo(dtype_var).eps > 64: + tol = a.atol = a.rtol = 1e-12 + test_tol = (tol*32)**0.5 + else: + test_tol = D.tol_epsilon(dtype_var) ** 0.5 - a.integrate() + a.integrate(eta=True) assert (a.integration_status == "Integration completed successfully.") print(str(a)) print(repr(a)) try: - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= D.tol_epsilon(dtype_var) ** 0.5) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= test_tol) for i in a: - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= test_tol) assert (len(a.y) == len(a)) assert (len(a.t) == len(a)) @@ -173,18 +179,23 @@ def test_integration_and_representation_with_jac(dtype_var, backend_var, integra assert (a.integration_status == "Integration has not been run.") a.equ_rhs.unhook_jacobian_call() + if D.ar_numpy.finfo(dtype_var).eps > 64: + tol = a.atol = a.rtol = 1e-12 + test_tol = (tol*32)**0.5 + else: + test_tol = D.tol_epsilon(dtype_var) ** 0.5 a.integrate() assert (a.integration_status == "Integration completed successfully.") print(str(a)) print(repr(a)) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= D.tol_epsilon(dtype_var) ** 0.5) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= test_tol) for i in a: - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= test_tol) assert (len(a.y) == len(a)) assert (len(a.t) == len(a)) @@ -660,3 +671,35 @@ def jac(t, state, **kwargs): assert (D.ar_numpy.allclose(rhs_matrix @ x0, wrapped_rhs_with_jac(0.0, x0), D.epsilon(np.float64)**0.5)) assert (D.ar_numpy.allclose(rhs_matrix, wrapped_rhs_with_jac.jac(0.0, x0), D.epsilon(np.float64)**0.5)) assert (jac_called) + + +@pytest.mark.parametrize('integrator', [(de.integrators.RK45CKSolver, 'RK45'), (de.integrators.RadauIIA19, 'Radau'), (de.integrators.RK8713MSolver, 'LSODA')]) +def test_solve_ivp_parity(integrator): + from scipy.integrate import solve_ivp + + de_mat = np.array([[0.0, 1.0], [-1.0, 0.0]], dtype=np.float64) + + def fun(t, state): + t = np.atleast_1d(t) + return de_mat @ state + np.concatenate([np.zeros_like(t), t], axis=0) - 0.001*state**2 + + t_span = [0.0, 10.0] + y0 = np.array([0.0, 1.0], dtype=np.float64) + atol = rtol = 1e-10 + + desolver_res = de.solve_ivp(fun, t_span=t_span, y0=y0, atol=atol, rtol=rtol, method=integrator[0]) + scipy_res = solve_ivp(fun, t_span=t_span, y0=y0, atol=atol, rtol=rtol, method=integrator[1]) + + print(desolver_res) + print(scipy_res) + test_tol = 1e-6 + + print(scipy_res.t[0] - desolver_res.t[0]) + assert np.allclose(scipy_res.t[0], desolver_res.t[0], test_tol, test_tol) + print(scipy_res.t[-1] - desolver_res.t[-1]) + assert np.allclose(scipy_res.t[-1], desolver_res.t[-1], test_tol, test_tol) + print(scipy_res.y[...,0] - desolver_res.y[...,0]) + assert np.allclose(scipy_res.y[...,0], desolver_res.y[...,0], test_tol, test_tol) + print(scipy_res.y[...,-1] - desolver_res.y[...,-1]) + assert np.allclose(scipy_res.y[...,-1], desolver_res.y[...,-1], test_tol, test_tol) + \ No newline at end of file diff --git a/desolver/tests/test_event_detection.py b/desolver/tests/test_event_detection.py index 8365d6c..98a703e 100644 --- a/desolver/tests/test_event_detection.py +++ b/desolver/tests/test_event_detection.py @@ -2,7 +2,6 @@ import desolver.backend as D import numpy as np import pytest -from copy import deepcopy from desolver.tests import common diff --git a/desolver/utilities/optimizer.py b/desolver/utilities/optimizer.py index 2457881..ca6e9b8 100644 --- a/desolver/utilities/optimizer.py +++ b/desolver/utilities/optimizer.py @@ -1,6 +1,8 @@ import numpy +import warnings import numpy as np import scipy.optimize +import scipy.linalg try: import torch torch_available = True @@ -431,9 +433,11 @@ def newtontrustregion(f, x0, jac=None, tol=None, verbose=False, maxiter=200, jac tol = D.tol_epsilon(x0.dtype) xshape = D.ar_numpy.shape(x0) if len(xshape) == 0: - f_vec = lambda x: D.ar_numpy.atleast_1d(f(x[0])) + def f_vec(x): + return D.ar_numpy.atleast_1d(f(x[0])) if jac is not None: - jac_vec = lambda x: D.ar_numpy.atleast_2d(jac(x[0])) + def jac_vec(x): + return D.ar_numpy.atleast_2d(jac(x[0])) else: jac_vec = None res = newtontrustregion(f_vec, D.ar_numpy.atleast_1d(x0), jac_vec, tol=tol, verbose=verbose, @@ -500,7 +504,6 @@ def fun_jac(x): fun_jac = transform_to_bounded_jac(fun_jac, *var_bounds) x = transform_to_bounded_x(x, *var_bounds) - w_relax = 0.5 F0 = fun(x) Jf0 = fun_jac(x) F1, Jf1 = D.ar_numpy.copy(F0), D.ar_numpy.copy(Jf0) @@ -529,7 +532,9 @@ def fun_jac(x): sparse = (1.0 - D.ar_numpy.sum(D.ar_numpy.abs(Jf1) > 0) / (xdim * fdim)) <= 0.7 P = Jf1 diagP = D.ar_numpy.diag(trust_region * D.ar_numpy.diag(P)) - dx = D.ar_numpy.reshape(D.ar_numpy.solve_linear_system(Jinv @ (P + diagP), -Jinv @ F1, sparse=sparse), (xdim, 1)) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=scipy.linalg.LinAlgWarning) + dx = D.ar_numpy.reshape(D.ar_numpy.solve_linear_system(Jinv @ (P + diagP), -Jinv @ F1, sparse=sparse), (xdim, 1)) no_progress = True F0 = F1 Fn0 = Fn1 @@ -650,7 +655,9 @@ def fun_jac(x): Fn0 = D.ar_numpy.linalg.norm(F0) print(f"[hybrj-{iteration}]: tr = {D.ar_numpy.to_numpy(trust_region)}, x = {D.ar_numpy.to_numpy(x)}, f = {D.ar_numpy.to_numpy(F1)}, ||dx|| = {D.ar_numpy.to_numpy(dxn)}, ||F|| = {D.ar_numpy.to_numpy(Fn0)}, ||dF|| = {D.ar_numpy.to_numpy(df)}") Jt_mul_F = J0.mT @ F0 - dx_gn = -D.ar_numpy.solve_linear_system(J0.mT @ J0, Jt_mul_F) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=scipy.linalg.LinAlgWarning) + dx_gn = -D.ar_numpy.solve_linear_system(J0.mT @ J0, Jt_mul_F) dx_sd = -Jt_mul_F tparam = -dx_sd.mT @ Jt_mul_F / D.ar_numpy.linalg.norm(J0 @ dx_sd) ** 2 xtol = tol * (xdim + D.ar_numpy.linalg.norm(x)) @@ -710,9 +717,11 @@ def nonlinear_roots(f, x0, jac=None, tol=None, verbose=False, maxiter=200, use_s tol = D.tol_epsilon(x0.dtype) xshape = D.ar_numpy.shape(x0) if len(xshape) == 0: - f_vec = lambda x: D.ar_numpy.atleast_1d(f(x[0])) + def f_vec(x): + return D.ar_numpy.atleast_1d(f(x[0])) if jac is not None: - jac_vec = lambda x: D.ar_numpy.atleast_2d(jac(x[0])) + def jac_vec(x): + return D.ar_numpy.atleast_2d(jac(x[0])) else: jac_vec = None res = nonlinear_roots(f_vec, D.ar_numpy.atleast_1d(x0), jac_vec, tol=tol, verbose=verbose, maxiter=maxiter) @@ -743,7 +752,8 @@ def fun(x): else: __fun_jac = utilities.JacobianWrapper(fun, atol=tol, rtol=tol, flat=True) else: - __fun_jac = lambda x: jac(x, *additional_args, **additional_kwargs) + def __fun_jac(x): + return jac(x, *additional_args, **additional_kwargs) jac_shape = D.ar_numpy.shape(__fun_jac(x0)) jacdim = 1 From 8a6916d23f3b587354710038a34f49ecf24dbeb2 Mon Sep 17 00:00:00 2001 From: Ekin Ozturk Date: Thu, 6 Mar 2025 17:54:39 +0000 Subject: [PATCH 2/4] Fixed several bugs relating to timestepping, added proper filtering of unimportant warnings, added tests for reverse integration and skipped very slow tests (>2min) --- desolver/backend/common.py | 3 + desolver/backend/torch_backend.py | 1 + desolver/differential_system.py | 5 +- desolver/integrators/integrator_types.py | 44 +++-- desolver/integrators/utilities.py | 33 ++-- desolver/tests/common.py | 2 +- desolver/tests/test_differential_system.py | 188 +++++++++++++++++---- desolver/tests/test_event_detection.py | 4 +- desolver/utilities/optimizer.py | 15 +- desolver/utilities/utilities.py | 28 +-- 10 files changed, 242 insertions(+), 81 deletions(-) diff --git a/desolver/backend/common.py b/desolver/backend/common.py index 5e1760d..aa86ba5 100644 --- a/desolver/backend/common.py +++ b/desolver/backend/common.py @@ -1,11 +1,14 @@ import numpy __all__ = [ + 'linear_algebra_exceptions', 'e', 'euler_gamma', 'pi', ] +linear_algebra_exceptions = [numpy.linalg.LinAlgError] + # Constants e = 2.7182818284590452353602874713526624977572470936999595749669676277240766303535475945713821785251664274 euler_gamma = 0.5772156649015328606065120900824024310421593359399235988057672348848677267776646709369470632917467495 diff --git a/desolver/backend/torch_backend.py b/desolver/backend/torch_backend.py index 4ea41e7..6e5ff33 100644 --- a/desolver/backend/torch_backend.py +++ b/desolver/backend/torch_backend.py @@ -3,6 +3,7 @@ import torch import autoray +linear_algebra_exceptions.append(torch._C._LinAlgError) def __solve_linear_system(A, b, sparse=False): __A = A diff --git a/desolver/differential_system.py b/desolver/differential_system.py index 5a367d9..3b79599 100644 --- a/desolver/differential_system.py +++ b/desolver/differential_system.py @@ -1026,7 +1026,7 @@ def integrate(self, t=None, callback=None, eta=False, events=None): end_int = False self.__allocate_soln_space(total_steps) try: - while (implicit_integration or self.dt != 0 and D.ar_numpy.abs(tf - self.__t[self.counter]) >= D.tol_epsilon(self.__y[self.counter].dtype)) and not end_int: + while (implicit_integration or (self.dt != 0 and D.ar_numpy.abs(tf - self.__t[self.counter]) >= D.tol_epsilon(self.__y[self.counter].dtype))) and not end_int: if not implicit_integration and D.ar_numpy.abs(self.dt + self.__t[self.counter]) > D.ar_numpy.abs(tf): self.dt = (tf - self.__t[self.counter]) self.dt, (dTime, dState) = self.integrator(self.equ_rhs, self.__t[self.counter], self.__y[self.counter], @@ -1130,6 +1130,7 @@ def __repr__(self): return "\n".join([ """{:>10}: {:<128}""".format("message", self.integration_status), """{:>10}: {:<128}""".format("nfev", str(self.nfev)), + """{:>10}: {:<128}""".format("njev", str(self.njev)), """{:>10}: {:<128}""".format("sol", str(self.sol)), """{:>10}: {:<128}""".format("t0", str(self.t0)), """{:>10}: {:<128}""".format("tf", str(self.tf)), @@ -1147,6 +1148,7 @@ def _repr_markdown_(self): {:>10}: {:<128} {:>10}: {:<128} {:>10}: {:<128} +{:>10}: {:<128} {:>10}: ``` {} @@ -1158,6 +1160,7 @@ def _repr_markdown_(self): """.format( "message", self.integration_status, "nfev", str(self.nfev), + "njev", str(self.njev), "sol", str(self.sol), "t0", str(self.t0), "tf", str(self.tf), diff --git a/desolver/integrators/integrator_types.py b/desolver/integrators/integrator_types.py index 289abe2..c870522 100644 --- a/desolver/integrators/integrator_types.py +++ b/desolver/integrators/integrator_types.py @@ -5,6 +5,8 @@ from desolver.integrators import utilities as integrator_utilities from desolver.integrators import components from desolver.utilities.optimizer import broyden_update_jac + +import warnings import abc __all__ = [ @@ -24,8 +26,6 @@ def __init__(self, sys_dim, dtype, rtol=None, atol=None, device=None): self.numel = 1 for i in self.dim: self.numel *= int(i) - self.rtol = rtol if rtol is not None else 32 * D.epsilon(dtype) - self.atol = atol if atol is not None else 32 * D.epsilon(dtype) self.dtype = dtype self.device = device self.array_constructor_kwargs = dict(dtype=self.dtype) @@ -34,6 +34,8 @@ def __init__(self, sys_dim, dtype, rtol=None, atol=None, device=None): if self.array_constructor_kwargs['like'] == 'torch': self.array_constructor_kwargs['device'] = self.device # ---- # + self.rtol = D.ar_numpy.ones((1,), **self.array_constructor_kwargs)[0]*rtol if rtol is not None else 32 * D.epsilon(dtype) + self.atol = D.ar_numpy.ones((1,), **self.array_constructor_kwargs)[0]*atol if atol is not None else 32 * D.epsilon(dtype) self.dState = D.ar_numpy.zeros(self.dim, **self.array_constructor_kwargs) self.dTime = D.ar_numpy.zeros(tuple(), **self.array_constructor_kwargs) self.tableau_intermediate = D.ar_numpy.asarray(self.__class__.tableau_intermediate, **self.array_constructor_kwargs) @@ -141,12 +143,15 @@ def __init__(self, sys_dim, dtype, rtol=None, atol=None, device=None): self.solver_dict.update(dict( initial_state=self.stage_values[...,0], diff=D.ar_numpy.zeros(sys_dim, **self.array_constructor_kwargs), - timestep=1.0, - dState=self.stage_values[...,0] + timestep=D.ar_numpy.ones((1,), **self.array_constructor_kwargs)[0], + dState=self.stage_values[...,0], + num_step_retries=64 )) if not self._explicit: solver_dict_preserved.update(dict( - tau0=2, tau1=2, niter0=0, niter1=0, newton_prec0=0.0, newton_prec1=0.0 + tau0=D.ar_numpy.ones((1,), **self.array_constructor_kwargs)[0], tau1=D.ar_numpy.ones((1,), **self.array_constructor_kwargs)[0], niter0=0, niter1=0, + newton_prec0=D.ar_numpy.zeros((1,), **self.array_constructor_kwargs)[0], newton_prec1=D.ar_numpy.zeros((1,), **self.array_constructor_kwargs)[0], + newton_iterations=32 )) self.solver_dict.update(solver_dict_preserved) self.adaptation_fn = integrator_utilities.implicit_aware_update_timestep @@ -170,13 +175,14 @@ def __call__(self, rhs, initial_time, initial_state, constants, timestep): if self.is_implicit and self.__rhs_jac is None: self.__rhs_jac = rhs.jac(initial_time, initial_state, **constants) + current_timestep = timestep try: timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, - timestep) - except (D.numpy.linalg.LinAlgError, ValueError): + current_timestep) + except (*D.linear_algebra_exceptions, ValueError): self._requires_high_precision = True timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, - timestep) + current_timestep) self._requires_high_precision = False if self.is_adaptive or self.is_implicit: @@ -193,15 +199,15 @@ def __call__(self, rhs, initial_time, initial_state, constants, timestep): redo_step = True timestep = timestep * 0.8 if redo_step: - for _ in range(64): + for _ in range(self.solver_dict.get("num_step_retries", 64)): self.solver_dict['redo_count'] += 1 try: timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, - timestep) - except (D.numpy.linalg.LinAlgError, ValueError): + D.ar_numpy.minimum(timestep, current_timestep)) + except (*D.linear_algebra_exceptions, ValueError): self._requires_high_precision = True timestep, (self.dTime, self.dState) = self.step(rhs, initial_time, initial_state, constants, - timestep) + D.ar_numpy.minimum(timestep, current_timestep)) self.solver_dict['diff'] = timestep * self.get_error_estimate() self.solver_dict['timestep'] = self.dTime self.solver_dict['dState'] = self.dState @@ -249,7 +255,10 @@ def algebraic_system_jacobian(self, next_state, rhs, initial_time, initial_state else: jac_block = self.__rhs_jac.reshape(__step, __step) for jdx in range(0, self.__jac.shape[1], __step): - self.__jac[idx:idx + __step, jdx:jdx + __step] -= timestep * self.tableau_intermediate[idx // __step, 1 + jdx // __step] * jac_block + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in matmul") + warnings.filterwarnings("ignore", category=RuntimeWarning, message="overflow encountered in subtract") + self.__jac[idx:idx + __step, jdx:jdx + __step] -= timestep * self.tableau_intermediate[idx // __step, 1 + jdx // __step] * jac_block __jac = self.__jac if self.__jac.shape[0] == 1 and self.__jac.shape[1] == 1: __jac = D.ar_numpy.reshape(__jac, tuple()) @@ -277,9 +286,8 @@ def step(self, rhs, initial_time, initial_state, constants, timestep): utilities.optimizer.nonlinear_roots( self.algebraic_system, initial_guess, jac=self.algebraic_system_jacobian, verbose=False, - tol=desired_tol, maxiter=32, + tol=desired_tol, maxiter=self.solver_dict.get("newton_iterations", 32), additional_args=(rhs, initial_time, initial_state, timestep, constants)) - self.solver_dict["newton_iteration_success"] = self.solver_dict["newton_iteration_success"] and prec < desired_tol if not self.solver_dict["newton_iteration_success"]: self.__rhs_jac = None @@ -300,7 +308,11 @@ def step(self, rhs, initial_time, initial_state, constants, timestep): self.final_rhs = rhs(initial_time + self.dTime, initial_state + self.dState, **constants) if self.is_implicit and self.__rhs_jac is not None: - self.__rhs_jac = broyden_update_jac(self.__rhs_jac.reshape(self.numel, self.numel), self.dState.reshape(self.numel, 1), (self.final_rhs - self.initial_rhs).reshape(self.numel, 1)).reshape(self.__rhs_jac.shape) + self.__rhs_jac = broyden_update_jac( + self.__rhs_jac.reshape(self.numel, self.numel), + self.dState.reshape(self.numel, 1), + (self.final_rhs - self.initial_rhs).reshape(self.numel, 1) + ).reshape(self.__rhs_jac.shape) return timestep, (self.dTime, self.dState) diff --git a/desolver/integrators/utilities.py b/desolver/integrators/utilities.py index 698c797..1d75f45 100644 --- a/desolver/integrators/utilities.py +++ b/desolver/integrators/utilities.py @@ -4,30 +4,41 @@ def implicit_aware_update_timestep(integrator: TableauIntegrator): timestep = integrator.solver_dict['timestep'] - safety_factor = integrator.solver_dict['safety_factor'] timestep_from_error, redo_step = integrator.update_timestep(ignore_custom_adaptation=True) if "niter0" in integrator.solver_dict.keys(): + # Adjust the timestep according to the computational cost of + # solving the nonlinear system at each timestep if integrator.solver_dict['niter0'] != 0 and integrator.solver_dict['niter1'] != 0: Tk0, CTk0 = D.ar_numpy.log(integrator.solver_dict['tau0']), math.log(integrator.solver_dict['niter0']) Tk1, CTk1 = D.ar_numpy.log(integrator.solver_dict['tau1']), math.log(integrator.solver_dict['niter1']) - dnCTk = D.ar_numpy.asarray(CTk1 - CTk0, **integrator.array_constructor_kwargs) - ddCTk = D.ar_numpy.asarray(Tk1 - Tk0, **integrator.array_constructor_kwargs) + dnCTk = CTk1 - CTk0 + ddCTk = Tk1 - Tk0 if ddCTk > 0: dCTk = dnCTk / ddCTk else: dCTk = D.ar_numpy.zeros_like(integrator.solver_dict['timestep']) - tau2 = D.ar_numpy.exp(-safety_factor * dCTk) + tau2 = D.ar_numpy.exp(-dCTk) else: - tau2 = D.numpy.inf + tau2 = None + # ---- # + # Adjust the timestep according to the precision achieved by the + # nonlinear system solver at each timestep total_error_tolerance = integrator.solver_dict['atol'] + integrator.solver_dict['rtol'] - with D.numpy.errstate(divide='ignore'): - epsilon_current = D.ar_numpy.reciprocal(D.ar_numpy.linalg.norm(integrator.solver_dict['newton_prec1'] / total_error_tolerance)) - tau3 = D.ar_numpy.where(epsilon_current > 0.0, epsilon_current ** (1.0 / integrator.solver_dict['order']), 1.0) + tau3 = D.ar_numpy.ones_like(integrator.solver_dict['timestep']) + if integrator.solver_dict['newton_prec1'] > 0.0: + with D.numpy.errstate(divide='ignore'): + epsilon_current = total_error_tolerance / integrator.solver_dict['newton_prec1'] + tau3 = tau3*D.ar_numpy.where(D.ar_numpy.isfinite(epsilon_current), epsilon_current, 1.0) if integrator.solver_dict['newton_prec0'] > 0.0: with D.numpy.errstate(divide='ignore'): - epsilon_last = D.ar_numpy.reciprocal(D.ar_numpy.linalg.norm(integrator.solver_dict['newton_prec0'] / total_error_tolerance)) - tau3 = tau3*D.ar_numpy.where(epsilon_last > 0.0, epsilon_last ** (1.0 / integrator.solver_dict['order']), 1.0) - tau = (1 + D.ar_numpy.arctan(D.ar_numpy.minimum(tau2, tau3) - 1)) + epsilon_last = total_error_tolerance / integrator.solver_dict['newton_prec0'] + tau3 = tau3*D.ar_numpy.where(D.ar_numpy.isfinite(epsilon_last), epsilon_last, 1.0) + # ---- # + if tau2 is None: + tau = tau3 + else: + tau = D.ar_numpy.minimum(tau2, tau3) + tau = (1 + D.ar_numpy.arctan(tau - 1)) return D.ar_numpy.minimum(tau * timestep, timestep_from_error), redo_step else: return timestep_from_error, redo_step diff --git a/desolver/tests/common.py b/desolver/tests/common.py index 4ab7580..fd6fc1d 100644 --- a/desolver/tests/common.py +++ b/desolver/tests/common.py @@ -66,7 +66,7 @@ def analytic_soln(t, initial_conditions): integrator = a.method else: a.method = integrator - dt = D.tol_epsilon(dtype_var)**(0.75/(2+a.integrator.order))/(2*D.pi) + dt = D.tol_epsilon(dtype_var)**(1.0/(2+a.integrator.order))/(2*D.pi) a.dt = dt return de_mat, rhs, analytic_soln, y_init, dt, a diff --git a/desolver/tests/test_differential_system.py b/desolver/tests/test_differential_system.py index 8206631..27577ce 100644 --- a/desolver/tests/test_differential_system.py +++ b/desolver/tests/test_differential_system.py @@ -88,8 +88,7 @@ def rhs(t, state, **kwargs): @common.integrator_param def test_integration_and_representation_no_jac(dtype_var, backend_var, integrator): - if integrator in [de.integrators.RadauIIA5] and dtype_var in ["longdouble", "float64"]: - pytest.skip("This test is too slow for the precision required") + print() dtype_var = D.autoray.to_backend_dtype(dtype_var, like=backend_var) if backend_var == 'torch': import torch @@ -101,22 +100,34 @@ def test_integration_and_representation_no_jac(dtype_var, backend_var, integrato assert (a.integration_status == "Integration has not been run.") - if a.integrator.order <= 4: - pytest.skip(f"{a.integrator}'s order is too low") + if a.integrator.is_implicit and D.ar_numpy.finfo(dtype_var).bits < 32: + pytest.skip(f"{a.integrator} is unstable for {D.ar_numpy.finfo(dtype_var).bits}-bit precision") + elif a.integrator.order <= 6 and D.ar_numpy.finfo(dtype_var).bits > 32: + pytest.skip(f"{a.integrator}'s order is too low for {D.ar_numpy.finfo(dtype_var).bits}-bit precision") + elif a.integrator.is_implicit and D.ar_numpy.finfo(dtype_var).bits > 64: + pytest.skip(f"{a.integrator}'s is too slow for {D.ar_numpy.finfo(dtype_var).bits}-bit precision") - a.integrate() + if D.ar_numpy.finfo(dtype_var).eps > 64: + tol = a.atol = a.rtol = 1e-12 + test_tol = (tol*32)**0.5 + else: + test_tol = D.tol_epsilon(dtype_var) ** 0.5 + if a.integrator.order <= 6: + test_tol = 128 * test_tol + + a.integrate(eta=True) assert (a.integration_status == "Integration completed successfully.") print(str(a)) print(repr(a)) try: - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= D.tol_epsilon(dtype_var) ** 0.5) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= test_tol) for i in a: - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= test_tol) assert (len(a.y) == len(a)) assert (len(a.t) == len(a)) @@ -132,8 +143,7 @@ def test_integration_and_representation_no_jac(dtype_var, backend_var, integrato @common.implicit_integrator_param def test_integration_and_representation_with_jac(dtype_var, backend_var, integrator): - if integrator in [de.integrators.RadauIIA5] and dtype_var in ["longdouble", "float64"]: - pytest.skip("This test is too slow for the precision required") + print() dtype_var = D.autoray.to_backend_dtype(dtype_var, like=backend_var) if backend_var == 'torch': import torch @@ -144,15 +154,21 @@ def test_integration_and_representation_with_jac(dtype_var, backend_var, integra a.tf = D.pi/4 assert (a.integration_status == "Integration has not been run.") - - if a.integrator.order <= 4: - pytest.skip(f"{a.integrator}'s order is too low") - + + if a.integrator.is_implicit and D.ar_numpy.finfo(dtype_var).bits < 32: + pytest.skip(f"{a.integrator} is unstable for {D.ar_numpy.finfo(dtype_var).bits}-bit precision") + elif a.integrator.order <= 6 and D.ar_numpy.finfo(dtype_var).bits > 32: + pytest.skip(f"{a.integrator}'s order is too low for {D.ar_numpy.finfo(dtype_var).bits}-bit precision") + elif a.integrator.is_implicit and D.ar_numpy.finfo(dtype_var).bits > 64: + pytest.skip(f"{a.integrator}'s is too slow for {D.ar_numpy.finfo(dtype_var).bits}-bit precision") + if D.ar_numpy.finfo(dtype_var).eps > 64: tol = a.atol = a.rtol = 1e-12 test_tol = (tol*32)**0.5 else: test_tol = D.tol_epsilon(dtype_var) ** 0.5 + if a.integrator.order <= 6: + test_tol = 128 * test_tol a.integrate(eta=True) @@ -174,33 +190,41 @@ def test_integration_and_representation_with_jac(dtype_var, backend_var, integra if backend_var == 'torch': # Test rehooking of jac through autodiff - (de_mat, rhs, analytic_soln, y_init, dt, a) = common.set_up_basic_system(dtype_var, backend_var, integrator=integrator, hook_jacobian=True) + (de_mat, rhs, analytic_soln, y_init, dt, a_torch) = common.set_up_basic_system(dtype_var, backend_var, integrator=integrator, hook_jacobian=True) + a_torch.tf = D.pi/4 - assert (a.integration_status == "Integration has not been run.") + assert (a_torch.integration_status == "Integration has not been run.") - a.equ_rhs.unhook_jacobian_call() + a_torch.equ_rhs.unhook_jacobian_call() + + for i in a: + assert (D.ar_numpy.max(D.ar_numpy.abs(a_torch.equ_rhs.jac(i.t, i.y) - a.equ_rhs.jac(i.t, i.y))) <= test_tol) + if D.ar_numpy.finfo(dtype_var).eps > 64: - tol = a.atol = a.rtol = 1e-12 + tol = a_torch.atol = a_torch.rtol = 1e-12 test_tol = (tol*32)**0.5 else: test_tol = D.tol_epsilon(dtype_var) ** 0.5 - a.integrate() + if a_torch.integrator.order <= 4: + test_tol = 128 * test_tol + + a_torch.integrate(eta=True) - assert (a.integration_status == "Integration completed successfully.") + assert (a_torch.integration_status == "Integration completed successfully.") - print(str(a)) - print(repr(a)) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= test_tol) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= test_tol) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= test_tol) + print(str(a_torch)) + print(repr(a_torch)) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a_torch.sol(a_torch.t[0])) - D.ar_numpy.to_numpy(y_init))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a_torch.sol(a_torch.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a_torch.t[-1], y_init)))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a_torch.sol(a_torch.t).T) - D.ar_numpy.to_numpy(analytic_soln(a_torch.t, y_init)))) <= test_tol) - for i in a: + for i in a_torch: assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= test_tol) - assert (len(a.y) == len(a)) - assert (len(a.t) == len(a)) - assert (a.nfev > 0) - assert (a.njev > 0) + assert (len(a_torch.y) == len(a_torch)) + assert (len(a_torch.t) == len(a_torch)) + assert (a_torch.nfev > 0) + assert (a_torch.njev > 0) except AssertionError as e: if backend_var == 'torch' and D.ar_numpy.finfo(dtype_var).bits < 32: pytest.xfail(f"Low precision {dtype_var} can fail some of the tests: {e}") @@ -212,6 +236,7 @@ def test_integration_and_representation_with_jac(dtype_var, backend_var, integra @common.basic_explicit_integrator_param def test_integration_with_richardson(dtype_var, backend_var, integrator): + print() dtype_var = D.autoray.to_backend_dtype(dtype_var, like=backend_var) if backend_var == 'torch': import torch @@ -224,6 +249,7 @@ def test_integration_with_richardson(dtype_var, backend_var, integrator): a.method = de.integrators.generate_richardson_integrator(a.method, richardson_iter=4) + test_tol = D.tol_epsilon(dtype_var) ** 0.5 a.integrate() assert (a.integration_status == "Integration completed successfully.") @@ -231,12 +257,12 @@ def test_integration_with_richardson(dtype_var, backend_var, integrator): print(str(a)) print(repr(a)) try: - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= D.tol_epsilon(dtype_var) ** 0.5) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= test_tol) for i in a: - assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= D.tol_epsilon(dtype_var) ** 0.5) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= test_tol) assert (len(a.y) == len(a)) assert (len(a.t) == len(a)) @@ -250,6 +276,7 @@ def test_integration_with_richardson(dtype_var, backend_var, integrator): def test_integration_and_nearest_float_no_dense_output(dtype_var, backend_var, device_var): + print() dtype_var = D.autoray.to_backend_dtype(dtype_var, like=backend_var) if backend_var == 'torch': import torch @@ -285,6 +312,7 @@ def rhs(t, state, k, **kwargs): def test_integration_reset(dtype_var, backend_var, device_var): + print() dtype_var = D.autoray.to_backend_dtype(dtype_var, like=backend_var) if backend_var == 'torch': import torch @@ -331,6 +359,7 @@ def rhs(t, state, k, **kwargs): def test_integration_long_duration(dtype_var, backend_var): + print() dtype_var = D.autoray.to_backend_dtype(dtype_var, like=backend_var) if backend_var == 'torch': import torch @@ -352,7 +381,7 @@ def rhs(t, state, k, **kwargs): assert (a.integration_status == "Integration has not been run.") - a.integrate() + a.integrate(eta=True) assert (a.sol is None) @@ -512,6 +541,92 @@ def callback(ode_sys): assert(callback_called) +def test_backward_integration(dtype_var, backend_var): + print() + dtype_var = D.autoray.to_backend_dtype(dtype_var, like=backend_var) + if backend_var == 'torch': + import torch + torch.set_printoptions(precision=17) + torch.autograd.set_detect_anomaly(True) + + (de_mat, rhs, analytic_soln, y_init, dt, a) = common.set_up_basic_system(dtype_var, backend_var) + a.tf = -2*D.pi + + assert (a.integration_status == "Integration has not been run.") + + if D.ar_numpy.finfo(dtype_var).eps > 64: + tol = a.atol = a.rtol = 1e-12 + test_tol = (tol*32)**0.5 + else: + test_tol = D.tol_epsilon(dtype_var) ** 0.5 + if a.integrator.order <= 6: + test_tol = 128 * test_tol + + a.integrate(eta=True) + + assert (a.integration_status == "Integration completed successfully.") + + print(str(a)) + print(repr(a)) + try: + assert (a.t[-1] < a.t[0]) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[0])) - D.ar_numpy.to_numpy(y_init))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t[-1])) - D.ar_numpy.to_numpy(analytic_soln(a.t[-1], y_init)))) <= test_tol) + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(a.sol(a.t).T) - D.ar_numpy.to_numpy(analytic_soln(a.t, y_init)))) <= test_tol) + + for i in a: + assert (D.ar_numpy.max(D.ar_numpy.abs(D.ar_numpy.to_numpy(i.y) - D.ar_numpy.to_numpy(analytic_soln(i.t, y_init)))) <= test_tol) + + assert (len(a.y) == len(a)) + assert (len(a.t) == len(a)) + assert (a.success) + except AssertionError as e: + if backend_var == 'torch' and D.ar_numpy.finfo(dtype_var).bits < 32: + pytest.xfail(f"Low precision {dtype_var} can fail some of the tests: {e}") + elif backend_var == 'numpy' and D.ar_numpy.finfo(dtype_var).bits < 32: + pytest.xfail(f"Low precision {dtype_var} can fail some of the tests: {e}") + else: + raise e + + +@common.basic_integrator_param +@pytest.mark.parametrize("datatype", ["float32", "float64"]) +def test_mixed_environment(integrator, datatype): + torch = pytest.importorskip("torch") + + np_oscillator_mat = np.array([[0.0, 1.0], [-1.0, 0.0]], dtype=D.autoray.to_backend_dtype(datatype, like='numpy')) + torch_oscillator_mat = torch.tensor(np_oscillator_mat, dtype=D.autoray.to_backend_dtype(datatype, like='torch')) + + def np_rhs(t, state): + t = np.atleast_1d(t) + return np_oscillator_mat @ state + np.concatenate([np.zeros_like(t), t], axis=0) - 0.001*state**2 + + def torch_rhs(t, state): + t = torch.atleast_1d(t) + return torch_oscillator_mat @ state + torch.cat([torch.zeros_like(t), t], dim=0) - 0.001*state**2 + + t_span = [0.0, 10.0] + np_y0 = np.array([0.0, 1.0], dtype=D.autoray.to_backend_dtype(datatype, like='numpy')) + torch_y0 = torch.tensor(np_y0, dtype=D.autoray.to_backend_dtype(datatype, like='torch')) + atol = rtol = 512*D.tol_epsilon(D.autoray.to_backend_dtype(datatype, like='numpy'))**0.5 + + ode_sys_numpy = de.OdeSystem(np_rhs, y0=np_y0, dense_output=False, t=t_span, dt=0.001, atol=atol, rtol=rtol) + ode_sys_numpy.set_kick_vars([False, True]) + ode_sys_numpy.method = integrator + ode_sys_torch = de.OdeSystem(torch_rhs, y0=torch_y0, dense_output=False, t=t_span, dt=0.001, atol=atol, rtol=rtol) + ode_sys_numpy.set_kick_vars([False, True]) + ode_sys_torch.method = integrator + + ode_sys_numpy.integrate(eta=True) + print(repr(ode_sys_numpy)) + + ode_sys_torch.integrate(eta=True) + print(repr(ode_sys_torch)) + + assert np.allclose(ode_sys_numpy.y[0], ode_sys_torch.y[0].numpy(), rtol, atol) + assert np.allclose(ode_sys_numpy.y[-1], ode_sys_torch.y[-1].numpy(), rtol, atol) + + def test_keyboard_interrupt_caught(dtype_var, backend_var): dtype_var = D.autoray.to_backend_dtype(dtype_var, like=backend_var) if backend_var == 'torch': @@ -675,6 +790,7 @@ def jac(t, state, **kwargs): @pytest.mark.parametrize('integrator', [(de.integrators.RK45CKSolver, 'RK45'), (de.integrators.RadauIIA19, 'Radau'), (de.integrators.RK8713MSolver, 'LSODA')]) def test_solve_ivp_parity(integrator): + print() from scipy.integrate import solve_ivp de_mat = np.array([[0.0, 1.0], [-1.0, 0.0]], dtype=np.float64) diff --git a/desolver/tests/test_event_detection.py b/desolver/tests/test_event_detection.py index 98a703e..3ae3bdd 100644 --- a/desolver/tests/test_event_detection.py +++ b/desolver/tests/test_event_detection.py @@ -228,6 +228,8 @@ def stationary_event(t, y, dy, **kwargs): @common.basic_integrator_param @common.dense_output_param def test_event_detection_indefinite_integration(dtype_var, backend_var, integrator, dense_output): + if dtype_var in ["float64", "longdouble"] and integrator == de.integrators.RadauIIA5: + pytest.skip("Too slow") if "float16" in dtype_var: pytest.skip("Event detection with 'float16' types are unreliable due to imprecision") @@ -254,7 +256,7 @@ def stationary_event(t, y, dy, **kwargs): a.method = integrator with de.utilities.BlockTimer(section_label="Integrator Tests") as sttimer: - a.integrate(eta=False, events=stationary_event) + a.integrate(eta=True, events=stationary_event) assert (a.integration_status == "Integration terminated upon finding a triggered event.") diff --git a/desolver/utilities/optimizer.py b/desolver/utilities/optimizer.py index ca6e9b8..52e125c 100644 --- a/desolver/utilities/optimizer.py +++ b/desolver/utilities/optimizer.py @@ -3,6 +3,7 @@ import numpy as np import scipy.optimize import scipy.linalg +import scipy.sparse.linalg try: import torch torch_available = True @@ -415,7 +416,9 @@ def iterative_inverse_7th(A, Ainv0, maxiter=10): def broyden_update_jac(B, dx, df, Binv=None): y_ex = B @ dx y_is = df - kI = (y_is - y_ex) / D.ar_numpy.sum(y_ex.mT @ y_ex) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="overflow encountered in matmul") + kI = (y_is - y_ex) / D.ar_numpy.sum(y_ex.mT @ y_ex) B_new = D.ar_numpy.reshape((1 + kI * B * dx) * B, (df.shape[0], dx.shape[0])) if Binv is not None: Binv_new = Binv + ((dx - Binv @ y_is) / (y_is.mT @ y_is)) @ y_is.mT @@ -516,8 +519,10 @@ def fun_jac(x): f64_type = D.autoray.to_backend_dtype('float64', like=inferred_backend) Jinv = D.ar_numpy.astype(D.ar_numpy.linalg.inv(D.ar_numpy.astype(Jf1, f64_type)), Jf1.dtype) - if D.ar_numpy.linalg.norm(Jinv @ Jf1 - I) < 0.5: - Jinv = iterative_inverse_7th(Jf1, Jinv, maxiter=3) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in matmul") + if D.ar_numpy.linalg.norm(Jinv @ Jf1 - I) < 0.5: + Jinv = iterative_inverse_7th(Jf1, Jinv, maxiter=3) trust_region = 5.0 if initial_trust_region is None else initial_trust_region iteration = 0 fail_iter = 0 @@ -533,7 +538,9 @@ def fun_jac(x): P = Jf1 diagP = D.ar_numpy.diag(trust_region * D.ar_numpy.diag(P)) with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in matmul") warnings.filterwarnings("ignore", category=scipy.linalg.LinAlgWarning) + warnings.filterwarnings("ignore", category=scipy.sparse.linalg.MatrixRankWarning) dx = D.ar_numpy.reshape(D.ar_numpy.solve_linear_system(Jinv @ (P + diagP), -Jinv @ F1, sparse=sparse), (xdim, 1)) no_progress = True F0 = F1 @@ -656,7 +663,9 @@ def fun_jac(x): print(f"[hybrj-{iteration}]: tr = {D.ar_numpy.to_numpy(trust_region)}, x = {D.ar_numpy.to_numpy(x)}, f = {D.ar_numpy.to_numpy(F1)}, ||dx|| = {D.ar_numpy.to_numpy(dxn)}, ||F|| = {D.ar_numpy.to_numpy(Fn0)}, ||dF|| = {D.ar_numpy.to_numpy(df)}") Jt_mul_F = J0.mT @ F0 with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in matmul") warnings.filterwarnings("ignore", category=scipy.linalg.LinAlgWarning) + warnings.filterwarnings("ignore", category=scipy.sparse.linalg.MatrixRankWarning) dx_gn = -D.ar_numpy.solve_linear_system(J0.mT @ J0, Jt_mul_F) dx_sd = -Jt_mul_F tparam = -dx_sd.mT @ Jt_mul_F / D.ar_numpy.linalg.norm(J0 @ dx_sd) ** 2 diff --git a/desolver/utilities/utilities.py b/desolver/utilities/utilities.py index 8859882..46d9e32 100644 --- a/desolver/utilities/utilities.py +++ b/desolver/utilities/utilities.py @@ -116,9 +116,11 @@ def estimate(self, y, *args, dy=None, **kwargs): def richardson(self, y, *args, dy=0.5, factor=4.0, **kwargs): A = [[self.estimate(y, dy=dy * (factor ** -m), *args, **kwargs)] for m in range(self.richardson_iter)] denom = factor ** self.base_order - for m in range(1, self.richardson_iter): - for n in range(1, m): - A[m].append(A[m][n - 1] + (A[m][n - 1] - A[m - 1][n - 1]) / (denom ** n - 1)) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='overflow encountered', category=RuntimeWarning) + for m in range(1, self.richardson_iter): + for n in range(1, m): + A[m].append(A[m][n - 1] + (A[m][n - 1] - A[m - 1][n - 1]) / (denom ** n - 1)) return A[-1][-1] def adaptive_richardson(self, y, *args, dy=0.5, factor=4, **kwargs): @@ -128,15 +130,17 @@ def adaptive_richardson(self, y, *args, dy=0.5, factor=4, **kwargs): factor = 1.0 * factor denom = factor ** self.base_order prev_error = numpy.inf - for m in range(1, self.richardson_iter): - A.append([self.estimate(y, *args, dy=dy * (factor ** (-m)), **kwargs)]) - for n in range(1, m + 1): - A[m].append(A[m][n - 1] + (A[m][n - 1] - A[m - 1][n - 1]) / (denom ** n - 1)) - if m >= 3: - prev_error, t_conv = self.check_converged(A[m][m], A[m][m] - A[m - 1][m - 1], prev_error) - if t_conv: - self.order = self.base_order + m - break + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='overflow encountered', category=RuntimeWarning) + for m in range(1, self.richardson_iter): + A.append([self.estimate(y, *args, dy=dy * (factor ** (-m)), **kwargs)]) + for n in range(1, m + 1): + A[m].append(A[m][n - 1] + (A[m][n - 1] - A[m - 1][n - 1]) / (denom ** n - 1)) + if m >= 3: + prev_error, t_conv = self.check_converged(A[m][m], A[m][m] - A[m - 1][m - 1], prev_error) + if t_conv: + self.order = self.base_order + m + break return A[-2][-1] def check_converged(self, initial_state, diff, prev_error): From e90a6daf71e06755b35f79a501e1515eb03f5751 Mon Sep 17 00:00:00 2001 From: Ekin Ozturk Date: Sun, 9 Mar 2025 20:25:20 +0000 Subject: [PATCH 3/4] Fixed failing tests and improved coverage of `solve_ivp` --- desolver/differential_system.py | 32 ++++++++++++++------ desolver/integrators/utilities.py | 4 +-- desolver/tests/test_differential_system.py | 35 ++++++++++++++++++++++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/desolver/differential_system.py b/desolver/differential_system.py index 3b79599..0c2904d 100644 --- a/desolver/differential_system.py +++ b/desolver/differential_system.py @@ -981,7 +981,7 @@ def integrate(self, t=None, callback=None, eta=False, events=None): behaviour of the system is highly unreliable. This could be due to numerical issues. """ - if t: + if t is not None: tf = t else: tf = self.tf @@ -1028,9 +1028,13 @@ def integrate(self, t=None, callback=None, eta=False, events=None): try: while (implicit_integration or (self.dt != 0 and D.ar_numpy.abs(tf - self.__t[self.counter]) >= D.tol_epsilon(self.__y[self.counter].dtype))) and not end_int: if not implicit_integration and D.ar_numpy.abs(self.dt + self.__t[self.counter]) > D.ar_numpy.abs(tf): - self.dt = (tf - self.__t[self.counter]) - self.dt, (dTime, dState) = self.integrator(self.equ_rhs, self.__t[self.counter], self.__y[self.counter], - self.constants, timestep=self.dt) + is_final_step = True + dt = (tf - self.__t[self.counter]) + else: + is_final_step = False + dt = self.dt + new_dt, (dTime, dState) = self.integrator(self.equ_rhs, self.__t[self.counter], self.__y[self.counter], + self.constants, timestep=dt) if self.counter + 1 >= len(self.__y): total_steps = self.__alloc_space_steps(tf - dTime) + 1 @@ -1108,6 +1112,9 @@ def integrate(self, t=None, callback=None, eta=False, events=None): tqdm_progress_bar.desc = "{:>10.2f} | {:.2f} | {:<10.2e}".format(self.__t[self.counter], tf, self.dt).ljust(8) tqdm_progress_bar.update() + + if not is_final_step: + self.dt = new_dt except KeyboardInterrupt as e: self.__int_status = e @@ -1248,22 +1255,29 @@ def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False, if "max_step" in options or "min_step" in options: max_step = options.get("max_step", np.inf) min_step = options.get("min_step", 0.0) - callbacks.append(lambda ode_sys: D.ar_numpy.clip(ode_sys, min=min_step, max=max_step)) + def __step_cb(ode_sys): + ode_sys.dt = D.ar_numpy.clip(ode_sys.dt, min=min_step, max=max_step) + callbacks.insert(0, __step_cb) integration_options = dict(callback=callbacks, events=events, eta=options.get("show_prog_bar", False)) if t_eval is None: ode_system.integrate(**integration_options) + t_res = ode_system.t + y_res = D.ar_numpy.transpose(ode_system.y, axes=[*range(1, len(ode_system.y.shape)), 0]) else: t_eval = D.ar_numpy.sort(t_eval) if t_eval[0] < t_span[0] or t_eval[-1] > t_span[1]: raise ValueError(f"Expected `t_eval` to be in the range [{t_span[0]}, {t_span[1]}]") + t_res = [] + y_res = [] for t in t_eval: ode_system.integrate(t=t, **integration_options) - if t_span[1] > t_eval[-1]: - ode_system.integrate(t=t_span[1], **integration_options) + t_res.append(ode_system[-1].t) + y_res.append(ode_system[-1].y) + t_res = D.ar_numpy.stack(t_res, axis=0) + y_res = D.ar_numpy.stack(y_res, axis=-1) - yres = ode_system.y - return OdeResult(t=ode_system.t, y=D.ar_numpy.transpose(yres, axes=[*range(1, len(yres.shape)), 0]), sol=ode_system.sol, t_events=ode_system.events, + return OdeResult(t=t_res, y=y_res, sol=ode_system.sol, t_events=ode_system.events, y_events=ode_system.events, nfev=ode_system.nfev, njev=ode_system.njev, status=ode_system.integration_status, message=ode_system.integration_status, success=ode_system.success, ode_system=ode_system) diff --git a/desolver/integrators/utilities.py b/desolver/integrators/utilities.py index 1d75f45..2c5e737 100644 --- a/desolver/integrators/utilities.py +++ b/desolver/integrators/utilities.py @@ -38,7 +38,7 @@ def implicit_aware_update_timestep(integrator: TableauIntegrator): tau = tau3 else: tau = D.ar_numpy.minimum(tau2, tau3) - tau = (1 + D.ar_numpy.arctan(tau - 1)) - return D.ar_numpy.minimum(tau * timestep, timestep_from_error), redo_step + tau = (1 + 0.1*D.ar_numpy.arctan((tau - 1)/0.1)) + return tau * timestep_from_error, redo_step else: return timestep_from_error, redo_step diff --git a/desolver/tests/test_differential_system.py b/desolver/tests/test_differential_system.py index 27577ce..4c4b0d4 100644 --- a/desolver/tests/test_differential_system.py +++ b/desolver/tests/test_differential_system.py @@ -818,4 +818,39 @@ def fun(t, state): assert np.allclose(scipy_res.y[...,0], desolver_res.y[...,0], test_tol, test_tol) print(scipy_res.y[...,-1] - desolver_res.y[...,-1]) assert np.allclose(scipy_res.y[...,-1], desolver_res.y[...,-1], test_tol, test_tol) + + t_eval = np.linspace(*t_span, 32) + + desolver_res = de.solve_ivp(fun, t_span=t_span, t_eval=t_eval, y0=y0, atol=atol, rtol=rtol, method=integrator[0]) + scipy_res = solve_ivp(fun, t_span=t_span, t_eval=t_eval, y0=y0, atol=atol, rtol=rtol, method=integrator[1]) + + print(desolver_res) + print(scipy_res) + test_tol = 1e-6 + + print(scipy_res.t - desolver_res.t) + assert np.allclose(scipy_res.t, desolver_res.t, test_tol, test_tol) + print(scipy_res.y - desolver_res.y) + assert np.allclose(scipy_res.y, desolver_res.y, test_tol, test_tol) + + def fun(t, state, k, m): + de_mat = np.array([[0.0, 1.0], [-k/m, 0.0]], dtype=np.float64) + t = np.atleast_1d(t) + return de_mat @ state + np.concatenate([np.zeros_like(t), t], axis=0) - 0.001*state**2 + + desolver_res = de.solve_ivp(fun, t_span=t_span, y0=y0, atol=atol, rtol=rtol, method=integrator[0], args=(4.0, 0.1)) + scipy_res = solve_ivp(fun, t_span=t_span, y0=y0, atol=atol, rtol=rtol, method=integrator[1], args=(4.0, 0.1)) + + print(desolver_res) + print(scipy_res) + test_tol = 1e-6 + + print(scipy_res.t[0] - desolver_res.t[0]) + assert np.allclose(scipy_res.t[0], desolver_res.t[0], test_tol, test_tol) + print(scipy_res.t[-1] - desolver_res.t[-1]) + assert np.allclose(scipy_res.t[-1], desolver_res.t[-1], test_tol, test_tol) + print(scipy_res.y[...,0] - desolver_res.y[...,0]) + assert np.allclose(scipy_res.y[...,0], desolver_res.y[...,0], test_tol, test_tol) + print(scipy_res.y[...,-1] - desolver_res.y[...,-1]) + assert np.allclose(scipy_res.y[...,-1], desolver_res.y[...,-1], test_tol, test_tol) \ No newline at end of file From 7e5c0fe799edaa61609d888215cbef3c0f429b43 Mon Sep 17 00:00:00 2001 From: Ekin Ozturk Date: Sun, 9 Mar 2025 21:11:31 +0000 Subject: [PATCH 4/4] Improved test coverage --- desolver/differential_system.py | 19 ++++++++++++------- desolver/tests/test_differential_system.py | 15 ++++++++++++++- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/desolver/differential_system.py b/desolver/differential_system.py index 0c2904d..39f4742 100644 --- a/desolver/differential_system.py +++ b/desolver/differential_system.py @@ -1099,6 +1099,9 @@ def integrate(self, t=None, callback=None, eta=False, events=None): self.__sol.remove_interpolant(0) steps += 1 + + if not is_final_step: + self.dt = new_dt for i in callback: i(self) @@ -1112,9 +1115,6 @@ def integrate(self, t=None, callback=None, eta=False, events=None): tqdm_progress_bar.desc = "{:>10.2f} | {:.2f} | {:<10.2e}".format(self.__t[self.counter], tf, self.dt).ljust(8) tqdm_progress_bar.update() - - if not is_final_step: - self.dt = new_dt except KeyboardInterrupt as e: self.__int_status = e @@ -1247,17 +1247,22 @@ def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False, fn_args_kwargs = inspect.getfullargspec(fn) constants = {key:value for key,value in zip(fn_args_kwargs[0][2:], args)} - ode_system = OdeSystem(equ_rhs=fun, y0=y0, t=t_span, dense_output=dense_output, dt=options.get('first_step', 1.0), + max_step = options.get("max_step", np.inf) + min_step = options.get("min_step", 0.0) + + initial_dt = options.get('first_step', 1.0) + initial_dt = D.ar_numpy.minimum(initial_dt, max_step) + initial_dt = D.ar_numpy.maximum(initial_dt, min_step) + + ode_system = OdeSystem(equ_rhs=fun, y0=y0, t=t_span, dense_output=dense_output, dt=initial_dt, atol=options.get('atol', None), rtol=options.get('rtol', None), constants=constants) ode_system.method = method callbacks = list(options.get("callbacks", [])) if "max_step" in options or "min_step" in options: - max_step = options.get("max_step", np.inf) - min_step = options.get("min_step", 0.0) def __step_cb(ode_sys): ode_sys.dt = D.ar_numpy.clip(ode_sys.dt, min=min_step, max=max_step) - callbacks.insert(0, __step_cb) + callbacks.append(__step_cb) integration_options = dict(callback=callbacks, events=events, eta=options.get("show_prog_bar", False)) if t_eval is None: diff --git a/desolver/tests/test_differential_system.py b/desolver/tests/test_differential_system.py index 4c4b0d4..e75316b 100644 --- a/desolver/tests/test_differential_system.py +++ b/desolver/tests/test_differential_system.py @@ -853,4 +853,17 @@ def fun(t, state, k, m): assert np.allclose(scipy_res.y[...,0], desolver_res.y[...,0], test_tol, test_tol) print(scipy_res.y[...,-1] - desolver_res.y[...,-1]) assert np.allclose(scipy_res.y[...,-1], desolver_res.y[...,-1], test_tol, test_tol) - \ No newline at end of file + + desolver_res = de.solve_ivp(fun, t_span=t_span, y0=y0, atol=atol, rtol=rtol, min_step=1e-2, method=integrator[0], args=(4.0, 0.1)) + assert np.diff(desolver_res.t)[:-1].min() >= 1e-2 - 1e-8 + + desolver_res = de.solve_ivp(fun, t_span=t_span, y0=y0, atol=atol, rtol=rtol, max_step=1e-2, method=integrator[0], args=(4.0, 0.1)) + assert np.diff(desolver_res.t)[:-1].max() <= 1e-2 + 1e-8 + + with pytest.raises(ValueError): + t_eval = np.array([-1.0, 0.0, 10.0]) + desolver_res = de.solve_ivp(fun, t_span=t_span, y0=y0, atol=atol, rtol=rtol, t_eval=t_eval, method=integrator[0], args=(4.0, 0.1)) + + with pytest.raises(ValueError): + t_eval = np.array([0.0, 10.0, 11.0]) + desolver_res = de.solve_ivp(fun, t_span=t_span, y0=y0, atol=atol, rtol=rtol, t_eval=t_eval, method=integrator[0], args=(4.0, 0.1)) \ No newline at end of file