systems.base.numerical_integration.DiffraxIntegrator

systems.base.numerical_integration.DiffraxIntegrator(
    system,
    dt=None,
    step_mode=StepMode.ADAPTIVE,
    backend='jax',
    solver='tsit5',
    adjoint='recursive_checkpoint',
    **options,
)

JAX-based ODE integrator using the Diffrax library.

Supports adaptive and fixed-step integration with explicit, implicit (Diffrax 0.8.0+), IMEX (Diffrax 0.8.0+), and special-purpose solvers.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| system | SymbolicDynamicalSystem | Continuous-time system to integrate (controlled or autonomous) | required |
| dt | Optional[float] | Time step size | None |
| step_mode | StepMode | FIXED or ADAPTIVE stepping mode | StepMode.ADAPTIVE |
| backend | str | Must be 'jax' for this integrator | 'jax' |
| solver | str | Solver name; see _solver_map for the available options | 'tsit5' |
| adjoint | str | Adjoint method for backpropagation: 'recursive_checkpoint', 'direct', or 'implicit' | 'recursive_checkpoint' |
| **options | | Additional options, including rtol, atol, and max_steps | {} |

Available Solvers (Diffrax 0.7.0)

Explicit Runge-Kutta:

  • tsit5: Tsitouras 5(4), recommended for most problems
  • dopri5: Dormand-Prince 5(4)
  • dopri8: Dormand-Prince 8(7), high accuracy
  • bosh3: Bogacki-Shampine 3(2)
  • euler: Forward Euler (1st order)
  • midpoint: Explicit midpoint (2nd order)
  • heun: Heun's method (2nd order)
  • ralston: Ralston's method (2nd order)
  • reversible_heun: Reversible Heun (2nd order, reversible)
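The numbers in names such as 5(4) or 3(2) describe an embedded pair: each step produces a higher-order solution and a lower-order one, and their difference is an error estimate that adaptive mode compares against rtol and atol to accept or reject the step and choose the next step size. The pure-Python sketch below illustrates the idea for a scalar ODE with the Bogacki-Shampine 3(2) tableau (bosh3). The controller constants (safety factor 0.9, step growth clipped to [0.2, 5]) are common textbook choices assumed here for illustration; Diffrax's own step-size controllers differ in detail.

```python
import math
from typing import Callable, Tuple


def bosh3_step(f: Callable[[float, float], float], t: float, y: float,
               h: float) -> Tuple[float, float]:
    """One Bogacki-Shampine 3(2) step: returns (3rd-order solution, error estimate)."""
    k1 = f(t, y)
    k2 = f(t + 0.5 * h, y + 0.5 * h * k1)
    k3 = f(t + 0.75 * h, y + 0.75 * h * k2)
    y_high = y + h * (2.0 / 9.0 * k1 + 1.0 / 3.0 * k2 + 4.0 / 9.0 * k3)
    k4 = f(t + h, y_high)
    y_low = y + h * (7.0 / 24.0 * k1 + 0.25 * k2 + 1.0 / 3.0 * k3 + 0.125 * k4)
    return y_high, y_high - y_low


def integrate_adaptive(f, t0: float, t1: float, y0: float,
                       rtol: float = 1e-6, atol: float = 1e-9, h: float = 1e-2):
    """Adaptive integration of a scalar ODE with a simple error-based step controller."""
    t, y = t0, y0
    accepted = rejected = 0
    while t < t1:
        last = h >= t1 - t
        h_step = t1 - t if last else h  # never step past t1
        y_new, err = bosh3_step(f, t, y, h_step)
        # Mixed absolute/relative tolerance, in the spirit of the rtol/atol options.
        ratio = abs(err) / (atol + rtol * max(abs(y), abs(y_new)))
        if ratio <= 1.0:
            t = t1 if last else t + h_step
            y = y_new
            accepted += 1
        else:
            rejected += 1
        # The error estimate scales like h**3, hence the 1/3 exponent.
        factor = 5.0 if ratio == 0.0 else 0.9 * ratio ** (-1.0 / 3.0)
        h = h_step * min(5.0, max(0.2, factor))
    return y, accepted, rejected
```

For example, integrate_adaptive(lambda t, y: -y, 0.0, 1.0, 1.0) returns a value close to math.exp(-1.0) together with the accepted and rejected step counts.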

Additional Solvers (Diffrax 0.8.0+)

Implicit methods (for stiff systems):

  • implicit_euler: Backward Euler (1st order, A-stable)
  • kvaerno3: Kvaerno 3(2) ESDIRK
  • kvaerno4: Kvaerno 4(3) ESDIRK
  • kvaerno5: Kvaerno 5(4) ESDIRK
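Why implicit methods help with stiff systems can be seen on the scalar test equation dx/dt = λx with λ = -1000. Explicit Euler multiplies the state by (1 + hλ) each step and diverges once |1 + hλ| > 1, whereas backward Euler solves x_next = x + hλ·x_next, giving the factor 1/(1 - hλ), which has magnitude below 1 for every h > 0 when λ < 0. This pure-Python sketch uses the closed-form solve available for a linear problem; general implicit solvers such as kvaerno5 instead solve a (generally nonlinear) system at each stage with an iterative root finder.

```python
def explicit_euler(lam: float, x0: float, h: float, n: int) -> float:
    x = x0
    for _ in range(n):
        x = x + h * lam * x  # x_next = (1 + h*lam) * x
    return x


def implicit_euler(lam: float, x0: float, h: float, n: int) -> float:
    x = x0
    for _ in range(n):
        # Solve x_next = x + h*lam*x_next for x_next (closed form for a linear ODE).
        x = x / (1.0 - h * lam)
    return x


lam, h, n = -1000.0, 0.01, 50
print(explicit_euler(lam, 1.0, h, n))  # magnitude grows like 9**50
print(implicit_euler(lam, 1.0, h, n))  # decays like 11**-50
```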

IMEX methods (for semi-stiff systems):

  • sil3: 3rd order IMEX
  • kencarp3: Kennedy-Carpenter IMEX 3
  • kencarp4: Kennedy-Carpenter IMEX 4
  • kencarp5: Kennedy-Carpenter IMEX 5

Special methods:

  • semi_implicit_euler: Semi-implicit Euler (symplectic)
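Symplectic methods such as semi-implicit Euler suit Hamiltonian systems because their energy error stays bounded over long runs instead of drifting. The pure-Python sketch below compares it with explicit Euler on the harmonic oscillator q' = v, v' = -q with energy E = (q² + v²)/2. The variant shown updates v first and then uses the new v to update q; for this linear problem it exactly conserves q² + v² - h·q·v, which keeps E in a narrow band, while explicit Euler multiplies q² + v² by (1 + h²) every step.

```python
def energy(q: float, v: float) -> float:
    return 0.5 * (q * q + v * v)


def semi_implicit_euler(q: float, v: float, h: float, n: int):
    for _ in range(n):
        v = v - h * q  # velocity update uses the current position
        q = q + h * v  # position update uses the *new* velocity
    return q, v


def explicit_euler_oscillator(q: float, v: float, h: float, n: int):
    for _ in range(n):
        q, v = q + h * v, v - h * q  # both updates use the old state
    return q, v


h = 0.1
print(energy(*semi_implicit_euler(1.0, 0.0, h, 10_000)))      # stays near 0.5
print(energy(*explicit_euler_oscillator(1.0, 0.0, h, 1_000)))  # about 0.5 * 1.01**1000
```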

Notes

  • Implicit and IMEX solvers require Diffrax 0.8.0 or later
  • The integrator automatically detects which solvers the installed Diffrax version provides
  • Use integrator._solver_map.keys() to see available solvers
  • Autonomous systems (nu=0) are supported: pass u=None to step(), or a u_func that returns None to integrate()
  • Backward time integration (t_span with tf < t0) is not supported

Examples

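>>> import jax.numpy as jnp
>>> from systems.base.numerical_integration import DiffraxIntegrator
>>>
>>> # Controlled system (system: a SymbolicDynamicalSystem with nx=2, nu=1)
>>> integrator = DiffraxIntegrator(system, backend='jax', solver='tsit5')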
>>> result = integrator.integrate(
...     x0=jnp.array([1.0, 0.0]),
...     u_func=lambda t, x: jnp.array([0.5]),
...     t_span=(0.0, 10.0)
... )
>>>
>>> # Autonomous system
>>> integrator = DiffraxIntegrator(autonomous_system, backend='jax')
>>> result = integrator.integrate(
...     x0=jnp.array([1.0, 0.0]),
...     u_func=lambda t, x: None,
...     t_span=(0.0, 10.0)
... )

Attributes

| Name | Description |
|------|-------------|
| name | Name of the integrator |

Methods

| Name | Description |
|------|-------------|
| integrate | Integrate over time interval with control policy. |
| integrate_imex | Integrate IMEX system with separate explicit and implicit parts. |
| integrate_with_gradient | Integrate and compute gradients w.r.t. initial conditions. |
| jit_compile_step | Return a JIT-compiled version of the step function. |
| step | Take one integration step: x(t) → x(t + dt). |
| vectorized_integrate | Vectorized integration over batch of initial conditions. |
| vectorized_step | Vectorized step over batch of states and controls. |

integrate

systems.base.numerical_integration.DiffraxIntegrator.integrate(
    x0,
    u_func,
    t_span,
    t_eval=None,
    dense_output=False,
)

Integrate over time interval with control policy.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x0 | StateVector | Initial state, shape (nx,) | required |
| u_func | Callable[[float, StateVector], Optional[ControlVector]] | Control policy (t, x) → u; return None for autonomous systems | required |
| t_span | TimeSpan | Integration interval (t_start, t_end); t_end must be greater than t_start (backward integration is not supported) | required |
| t_eval | Optional[TimePoints] | Specific times at which to store the solution (must be increasing) | None |
| dense_output | bool | If True, return a dense interpolated solution | False |

Returns

| Type | Description |
|------|-------------|
| IntegrationResult | TypedDict containing the keys listed below |

  • t: time points, shape (T,)
  • x: state trajectory, shape (T, nx)
  • success: whether integration succeeded
  • message: status message
  • nfev: number of function evaluations
  • nsteps: number of steps taken
  • integration_time: computation time in seconds
  • solver: integrator name
  • njev: number of Jacobian evaluations (if applicable)

Raises

| Type | Description |
|------|-------------|
| ValueError | If t_span has t_end < t_start (backward integration is not supported) |
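The time-input contract above (forward-only t_span, increasing t_eval) can be checked up front. The helper below is a hypothetical pure-Python sketch of such validation; its name and error messages are illustrative, not this integrator's code, and reading "increasing" as strictly increasing is an assumption.

```python
from typing import Optional, Sequence, Tuple


def validate_time_inputs(t_span: Tuple[float, float],
                         t_eval: Optional[Sequence[float]] = None) -> None:
    """Hypothetical checks mirroring the documented contract of integrate()."""
    t_start, t_end = t_span
    if t_end < t_start:
        raise ValueError(
            f"Backward integration is not supported: t_span={t_span}"
        )
    if t_eval is not None:
        # Assumption: "increasing" is read as strictly increasing.
        if any(b <= a for a, b in zip(t_eval, t_eval[1:])):
            raise ValueError("t_eval must be increasing")
```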

Examples

>>> # Controlled system
>>> result = integrator.integrate(
...     x0=jnp.array([1.0, 0.0]),
...     u_func=lambda t, x: jnp.array([0.5]),
...     t_span=(0.0, 10.0)
... )
>>>
>>> # Autonomous system (integrator built for a system with nu=0)
>>> result = integrator.integrate(
...     x0=jnp.array([1.0, 0.0]),
...     u_func=lambda t, x: None,
...     t_span=(0.0, 10.0)
... )
>>>
>>> # Access results
>>> t = result["t"]
>>> x_traj = result["x"]
>>> print(f"Success: {result['success']}")
>>> print(f"Steps: {result['nsteps']}, Function evals: {result['nfev']}")
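Conceptually, a control policy turns the controlled system into a closed-loop ODE dx/dt = f(x, u_func(t, x)) that a solver can integrate directly. The pure-Python sketch below illustrates this with a hypothetical double-integrator model f and a fixed-step classical RK4 scheme; neither is taken from this integrator, which uses the system's own dynamics and a Diffrax solver such as tsit5.

```python
from typing import Callable, List, Optional

Vec = List[float]


def f(x: Vec, u: Optional[Vec]) -> Vec:
    """Hypothetical double integrator: x = [position, velocity], u = [force]."""
    force = u[0] if u is not None else 0.0  # autonomous case: no input
    return [x[1], force]


def closed_loop(u_func: Callable[[float, Vec], Optional[Vec]]):
    """Build the closed-loop right-hand side (t, x) -> f(x, u_func(t, x))."""
    return lambda t, x: f(x, u_func(t, x))


def rk4(rhs, x0: Vec, t0: float, t1: float, n: int) -> Vec:
    """Classical fixed-step RK4 for a vector state stored as a list."""
    def axpy(x: Vec, k: Vec, s: float) -> Vec:
        return [xi + s * ki for xi, ki in zip(x, k)]

    h = (t1 - t0) / n
    t, x = t0, list(x0)
    for _ in range(n):
        k1 = rhs(t, x)
        k2 = rhs(t + h / 2, axpy(x, k1, h / 2))
        k3 = rhs(t + h / 2, axpy(x, k2, h / 2))
        k4 = rhs(t + h, axpy(x, k3, h))
        x = [xi + h / 6 * (a + 2 * b + 2 * c + d)
             for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]
        t += h
    return x


# Constant force u = 0.5 from rest; the exact solution is x(t) = [0.25 t^2, 0.5 t].
x_final = rk4(closed_loop(lambda t, x: [0.5]), [0.0, 0.0], 0.0, 10.0, 100)
```

Because the closed-loop vector field is built once, the same solver code serves controlled and autonomous cases; a policy returning None simply drops the input term.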

Notes

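Backward time integration (t_end < t_start) is not supported. To get the forward trajectory ordered from t_end back to t_start, integrate forward and reverse the output:

>>> # Instead of integrator.integrate(x0, u_func, (10.0, 0.0))
>>> result = integrator.integrate(x0, u_func, (0.0, 10.0))
>>> t_reversed = jnp.flip(result["t"])
>>> x_reversed = jnp.flip(result["x"], axis=0)

This only reorders a solution that starts from x0 at t_start; it does not solve backward from a state prescribed at t_end.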
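Solving backward from a state known at t_end can be done with the substitution s = t_end - t, which turns dx/dt = f(t, x) into the forward problem dx/ds = -f(t_end - s, x) on s in [0, t_end - t_start]. The pure-Python sketch below uses fixed-step RK4 for brevity (an assumption; any forward solver, including this integrator, could play that role). On dx/dt = x with x(1) = e, the backward answer is x(0) = 1.

```python
import math
from typing import Callable


def rk4_scalar(rhs: Callable[[float, float], float], x0: float,
               s0: float, s1: float, n: int) -> float:
    h = (s1 - s0) / n
    s, x = s0, x0
    for _ in range(n):
        k1 = rhs(s, x)
        k2 = rhs(s + h / 2, x + h / 2 * k1)
        k3 = rhs(s + h / 2, x + h / 2 * k2)
        k4 = rhs(s + h, x + h * k3)
        x += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        s += h
    return x


def integrate_backward(f: Callable[[float, float], float], x_end: float,
                       t_start: float, t_end: float, n: int = 100) -> float:
    """Return x(t_start) given x(t_end) by integrating the reversed system forward in s."""
    reversed_rhs = lambda s, x: -f(t_end - s, x)  # dx/ds = -f(t_end - s, x)
    return rk4_scalar(reversed_rhs, x_end, 0.0, t_end - t_start, n)


x_start = integrate_backward(lambda t, x: x, math.e, 0.0, 1.0)  # close to 1.0
```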

integrate_imex

systems.base.numerical_integration.DiffraxIntegrator.integrate_imex(
    x0,
    explicit_func,
    implicit_func,
    t_span,
    t_eval=None,
)

Integrate IMEX system with separate explicit and implicit parts.

For systems of the form: dx/dt = f_explicit(t, x) + f_implicit(t, x)

where f_explicit is non-stiff and f_implicit is stiff.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x0 | StateVector | Initial state | required |
| explicit_func | Callable[[float, StateVector], StateVector] | Non-stiff part: (t, x) -> dx/dt_explicit | required |
| implicit_func | Callable[[float, StateVector], StateVector] | Stiff part: (t, x) -> dx/dt_implicit | required |
| t_span | TimeSpan | Integration interval (t0, tf); must have tf > t0 | required |
| t_eval | Optional[TimePoints] | Evaluation times | None |

Returns

| Type | Description |
|------|-------------|
| IntegrationResult | Integration result from the IMEX solver |

Raises

| Type | Description |
|------|-------------|
| ValueError | If the configured solver is not an IMEX solver, IMEX solvers are not available (Diffrax older than 0.8.0), or tf < t0 |

Notes

IMEX solvers require Diffrax 0.8.0 or later

Examples

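>>> # An IMEX solver is required (Diffrax 0.8.0+)
>>> integrator = DiffraxIntegrator(system, solver='kencarp4')
>>>
>>> # Stiff-nonstiff splitting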
>>> def explicit_part(t, x):
...     # Non-stiff dynamics
...     return jnp.array([x[1], 0.0])
>>>
>>> def implicit_part(t, x):
...     # Stiff dynamics
...     return jnp.array([0.0, -100*x[1]])
>>>
>>> result = integrator.integrate_imex(
...     x0=jnp.array([1.0, 0.0]),
...     explicit_func=explicit_part,
...     implicit_func=implicit_part,
...     t_span=(0.0, 10.0)
... )
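The simplest IMEX scheme, first-order IMEX Euler, treats the two parts differently within one step: x_next = x + h·f_explicit(t, x) + h·f_implicit(t + h, x_next). The pure-Python sketch below assumes a linear stiff part f_implicit(t, x) = λx, so the implicit equation has the closed-form solution x_next = (x + h·f_explicit(t, x)) / (1 - hλ). Diffrax's IMEX solvers such as kencarp4 are higher order and also handle nonlinear implicit parts, so this only illustrates the splitting.

```python
from typing import Callable


def imex_euler(f_explicit: Callable[[float, float], float], lam: float,
               x0: float, t0: float, h: float, n: int) -> float:
    """IMEX Euler for dx/dt = f_explicit(t, x) + lam * x (stiff part assumed linear)."""
    t, x = t0, x0
    for _ in range(n):
        # Explicit part at the old state; linear stiff part solved implicitly.
        x = (x + h * f_explicit(t, x)) / (1.0 - h * lam)
        t += h
    return x


# Constant forcing 1.0 plus stiff decay (lam = -1000): steady state x* = -1/lam = 0.001.
# With h*lam = -100, explicit Euler would diverge; IMEX Euler settles to x*.
x_final = imex_euler(lambda t, x: 1.0, -1000.0, 1.0, 0.0, 0.1, 100)
```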

integrate_with_gradient

systems.base.numerical_integration.DiffraxIntegrator.integrate_with_gradient(
    x0,
    u_func,
    t_span,
    loss_fn,
    t_eval=None,
)

Integrate and compute gradients w.r.t. initial conditions.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x0 | StateVector | Initial state, shape (nx,) | required |
| u_func | Callable[[float, StateVector], Optional[ControlVector]] | Control policy (t, x) → u | required |
| t_span | TimeSpan | Integration interval (t_start, t_end) | required |
| loss_fn | Callable[[IntegrationResult], float] | Loss function on the integration result | required |
| t_eval | Optional[TimePoints] | Evaluation times | None |

Returns

| Name | Type | Description |
|------|------|-------------|
| loss | float | Loss value |
| grad | StateVector | Gradient of the loss w.r.t. x0 |

Examples

>>> # Define loss (e.g., final state error)
>>> def loss_fn(result):
...     x_final = result["x"][-1]
...     x_target = jnp.array([1.0, 0.0])
...     return jnp.sum((x_final - x_target)**2)
>>>
>>> # Compute loss and gradient
>>> loss, grad = integrator.integrate_with_gradient(
...     x0=jnp.array([0.0, 0.0]),
...     u_func=lambda t, x: jnp.zeros(1),
...     t_span=(0.0, 10.0),
...     loss_fn=loss_fn
... )
>>> print(f"Loss: {loss:.4f}")
>>> print(f"Gradient: {grad}")
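This integrator presumably obtains the gradient by differentiating through the Diffrax solve with JAX, using the configured adjoint. An independent way to see what the number means, and to sanity-check it on small problems, is forward sensitivity analysis: S = ∂x(t)/∂x0 satisfies dS/dt = (∂f/∂x)·S with S(t0) = 1 (the identity matrix for vector states), and the chain rule gives dloss/dx0 = dloss/dx(T)·S(T). The scalar pure-Python sketch below applies this to dx/dt = -2x with a squared final-state error, integrating the augmented state (x, S) with fixed-step RK4; it is not how this integrator computes gradients.

```python
import math


def rhs(x: float, s: float):
    """Augmented system: dx/dt = -2x and sensitivity dS/dt = (df/dx) * S = -2S."""
    return -2.0 * x, -2.0 * s


def final_state_and_sensitivity(x0: float, t_end: float, n: int = 200):
    h = t_end / n
    x, s = x0, 1.0  # S(0) = dx(0)/dx0 = 1
    for _ in range(n):
        k1 = rhs(x, s)
        k2 = rhs(x + h / 2 * k1[0], s + h / 2 * k1[1])
        k3 = rhs(x + h / 2 * k2[0], s + h / 2 * k2[1])
        k4 = rhs(x + h * k3[0], s + h * k3[1])
        x += h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        s += h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
    return x, s


x_target = 1.0
x_T, S_T = final_state_and_sensitivity(0.5, 1.0)
loss = (x_T - x_target) ** 2
grad = 2.0 * (x_T - x_target) * S_T  # chain rule: dloss/dx(T) * dx(T)/dx0

# Analytic check: x(T) = x0 * exp(-2T), so dloss/dx0 = 2 (x(T) - 1) exp(-2T).
grad_exact = 2.0 * (0.5 * math.exp(-2.0) - 1.0) * math.exp(-2.0)
```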

jit_compile_step

systems.base.numerical_integration.DiffraxIntegrator.jit_compile_step()

Return a JIT-compiled version of the step function.

Returns

| Name | Type | Description |
|------|------|-------------|
| jitted_step | callable | JIT-compiled step function with signature (x: StateVector, u: Optional[ControlVector], dt: float) -> StateVector |

Notes

To stay fast, the JIT-compiled version does not track statistics. Use the regular step() method if you need statistics tracking.

Examples

>>> # Create JIT-compiled step
>>> jit_step = integrator.jit_compile_step()
>>>
>>> # Use for fast repeated stepping
>>> x = jnp.array([1.0, 0.0])
>>> u = jnp.array([0.5])
>>> for _ in range(1000):
...     x = jit_step(x, u, 0.01)

step

systems.base.numerical_integration.DiffraxIntegrator.step(x, u=None, dt=None)

Take one integration step: x(t) → x(t + dt).

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | StateVector | Current state, shape (nx,) or (batch, nx) | required |
| u | Optional[ControlVector] | Control input, shape (nu,) or (batch, nu), or None for autonomous systems | None |
| dt | Optional[ScalarLike] | Step size; uses self.dt if None | None |

Returns

| Type | Description |
|------|-------------|
| StateVector | Next state x(t + dt) |

Examples

>>> # Controlled system
>>> x_next = integrator.step(
...     x=jnp.array([1.0, 0.0]),
...     u=jnp.array([0.5])
... )
>>>
>>> # Autonomous system
>>> x_next = integrator.step(
...     x=jnp.array([1.0, 0.0]),
...     u=None
... )
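A single step advances the state by dt; the (x, u, dt) signature suggests the control input is held fixed over the step (a zero-order hold), which is assumed here. The pure-Python sketch below shows one such step with Heun's method, one of the listed 2nd-order solvers, for a hypothetical damped point-mass model; in practice the integrator uses its configured solver and the system's dynamics.

```python
from typing import Callable, List, Optional

Vec = List[float]


def heun_step(f: Callable[[Vec, Optional[Vec]], Vec], x: Vec,
              u: Optional[Vec], dt: float) -> Vec:
    """One Heun (explicit trapezoidal) step with u held constant over [t, t + dt]."""
    k1 = f(x, u)
    x_pred = [xi + dt * ki for xi, ki in zip(x, k1)]  # Euler predictor
    k2 = f(x_pred, u)                                   # same u: zero-order hold
    return [xi + dt / 2 * (a + b) for xi, a, b in zip(x, k1, k2)]


def damped_point_mass(x: Vec, u: Optional[Vec]) -> Vec:
    """Hypothetical model: x = [position, velocity], unit damping, force u[0]."""
    force = u[0] if u is not None else 0.0
    return [x[1], -x[1] + force]


x_next = heun_step(damped_point_mass, [1.0, 0.0], [0.5], 0.01)  # controlled
x_auto = heun_step(damped_point_mass, [1.0, 0.0], None, 0.01)   # autonomous
```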

vectorized_integrate

systems.base.numerical_integration.DiffraxIntegrator.vectorized_integrate(
    x0_batch,
    u_func,
    t_span,
    t_eval=None,
)

Vectorized integration over batch of initial conditions.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x0_batch | StateVector | Batched initial states, shape (batch, nx) | required |
| u_func | Callable[[float, StateVector], Optional[ControlVector]] | Control policy (t, x) → u | required |
| t_span | TimeSpan | Integration interval | required |
| t_eval | Optional[TimePoints] | Evaluation times | None |

Returns

| Type | Description |
|------|-------------|
| list[IntegrationResult] | Integration results, one per initial condition |

Examples

>>> # Monte Carlo simulation with 100 initial conditions
>>> import jax
>>> x0_batch = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
>>> results = integrator.vectorized_integrate(
...     x0_batch,
...     lambda t, x: jnp.zeros(1),
...     (0.0, 10.0)
... )
>>> print(f"Success rate: {sum(r['success'] for r in results) / 100}")

vectorized_step

systems.base.numerical_integration.DiffraxIntegrator.vectorized_step(
    x_batch,
    u_batch=None,
    dt=None,
)

Vectorized step over batch of states and controls.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x_batch | StateVector | Batched states, shape (batch, nx) | required |
| u_batch | Optional[ControlVector] | Batched controls, shape (batch, nu), or None for autonomous systems | None |
| dt | Optional[ScalarLike] | Step size | None |

Returns

| Type | Description |
|------|-------------|
| StateVector | Next states, shape (batch, nx) |

Examples

>>> # Batch of 100 states
>>> import jax
>>> x_batch = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
>>> u_batch = jnp.zeros((100, 1))
>>> x_next_batch = integrator.vectorized_step(x_batch, u_batch)
>>> print(x_next_batch.shape)  # (100, 2)