from typing import Optional, Union, Type
import numpy as np
import sympy as sp
from scipy.integrate import solve_ivp, OdeSolver
from pararealml.initial_value_problem import InitialValueProblem
from pararealml.operator import Operator, discretize_time_domain
from pararealml.solution import Solution
class ODEOperator(Operator):
    """
    An ordinary differential equation solver using the SciPy library.
    """

    def __init__(
            self,
            method: Union[str, Type[OdeSolver]],
            d_t: float,
            first_step: Optional[float] = None,
            max_step: float = np.inf,
            atol: float = 1e-6,
            rtol: float = 1e-3):
        """
        :param method: the ODE solver to use; forwarded unchanged as the
            method argument of scipy.integrate.solve_ivp (a solver name or
            an OdeSolver subclass)
        :param d_t: the temporal step size to use
        :param first_step: the step size to use for the first time integration
            step
        :param max_step: the maximum allowed time integration step size
        :param atol: the absolute tolerance to use to manage local error
            estimates by controlling the time integration step size
        :param rtol: the relative tolerance to use to manage local error
            estimates by controlling the time integration step size
        """
        # NOTE(review): the second base-class argument is passed as None;
        # presumably because an ODE has no spatial mesh -- confirm against
        # Operator.__init__.
        super(ODEOperator, self).__init__(d_t, None)
        self._method = method
        self._first_step = first_step
        self._max_step = max_step
        self._atol = atol
        self._rtol = rtol

    def solve(
            self,
            ivp: InitialValueProblem,
            parallel_enabled: bool = True) -> Solution:
        """
        Solves the initial value problem with scipy.integrate.solve_ivp.

        The time domain of the IVP is discretized using this operator's step
        size, and the solution is reported at every discretized time point
        after the first one.

        :param ivp: the initial value problem to solve; its differential
            equation must have an x_dimension of 0 (i.e. be an ODE)
        :param parallel_enabled: not used by this operator (presumably kept
            for compatibility with the Operator.solve interface)
        :return: the solution evaluated at the discretized time points
            following the initial time point
        :raises ValueError: if the differential equation is not an ODE, or
            if the SciPy solver reports failure
        """
        diff_eq = ivp.constrained_problem.differential_equation
        if diff_eq.x_dimension != 0:
            raise ValueError('initial value problem must be an ODE')

        t = discretize_time_domain(ivp.t_interval, self._d_t)
        # Integrate over the span of the discretized time points rather than
        # ivp.t_interval itself, so that every requested evaluation point
        # (t[1:]) lies within the integration span.
        adjusted_t_interval = (t[0], t[-1])

        # Compile the symbolic right-hand side into a NumPy-backed callable
        # of (t, y).
        sym = diff_eq.symbols
        rhs = diff_eq.symbolic_equation_system.rhs
        rhs_lambda = sp.lambdify([sym.t, sym.y], rhs, 'numpy')

        def d_y_over_d_t(_t: float, _y: np.ndarray) -> np.ndarray:
            # lambdify of a list of expressions returns a list; SciPy expects
            # an array-like derivative.
            return np.asarray(rhs_lambda(_t, _y))

        result = solve_ivp(
            d_y_over_d_t,
            adjusted_t_interval,
            ivp.initial_condition.discrete_y_0(),
            method=self._method,
            t_eval=t[1:],
            dense_output=False,
            vectorized=False,
            first_step=self._first_step,
            max_step=self._max_step,
            atol=self._atol,
            rtol=self._rtol)

        if not result.success:
            # Build a single message string; passing several positional
            # arguments to ValueError would render as a tuple repr.
            raise ValueError(
                'error solving initial value problem, '
                f'status code: {result.status}, '
                f'message: {result.message}')

        # result.y has one row per state component and one column per
        # evaluation time; transpose to one row per time point.
        y = np.ascontiguousarray(result.y.T)
        return Solution(ivp, t[1:], y, d_t=self._d_t)