systems.base.numerical_integration.DiffraxIntegrator
systems.base.numerical_integration.DiffraxIntegrator(
system,
dt=None,
step_mode=StepMode.ADAPTIVE,
backend='jax',
solver='tsit5',
adjoint='recursive_checkpoint',
**options,
)
JAX-based ODE integrator using the Diffrax library.
Supports adaptive and fixed-step integration with various solvers including explicit, implicit (0.8.0+), IMEX (0.8.0+), and special methods.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| system | SymbolicDynamicalSystem | Continuous-time system to integrate (controlled or autonomous) | required |
| dt | Optional[float] | Time step size | None |
| step_mode | StepMode | FIXED or ADAPTIVE stepping mode | StepMode.ADAPTIVE |
| backend | str | Must be ‘jax’ for this integrator | 'jax' |
| solver | str | Solver name. See _solver_map for available options. Default: ‘tsit5’ | 'tsit5' |
| adjoint | str | Adjoint method for backpropagation. Options: ‘recursive_checkpoint’, ‘direct’, ‘implicit’. Default: ‘recursive_checkpoint’ | 'recursive_checkpoint' |
| **options | | Additional options including rtol, atol, max_steps | {} |
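As a rough illustration of what tolerances such as rtol and atol usually control in adaptive stepping, the plain-Python sketch below runs an embedded Heun-Euler 2(1) pair on a scalar ODE. It is not how DiffraxIntegrator or Diffrax implement step-size control; the helper name `integrate_adaptive`, the safety factor 0.9, and the clamping bounds 0.2 and 5.0 are assumptions chosen for the example.

```python
import math


def integrate_adaptive(f, t0, tf, x0, dt0=0.1, rtol=1e-6, atol=1e-9, max_steps=100_000):
    """Illustrative adaptive Heun-Euler 2(1) integrator for a scalar ODE (hypothetical helper)."""
    t, x, dt = t0, x0, dt0
    attempts = 0  # counts accepted and rejected step attempts
    while t < tf:
        if attempts >= max_steps:
            raise RuntimeError("max_steps exceeded")
        attempts += 1
        last = dt >= tf - t  # clip the final step so it lands exactly on tf
        if last:
            dt = tf - t
        k1 = f(t, x)
        k2 = f(t + dt, x + dt * k1)
        x_low = x + dt * k1                # forward Euler (order 1)
        x_high = x + 0.5 * dt * (k1 + k2)  # Heun (order 2)
        # Scaled error estimate: a value <= 1 means the step meets the tolerances.
        scale = atol + rtol * max(abs(x), abs(x_high))
        err = abs(x_high - x_low) / scale
        if err <= 1.0:
            t = tf if last else t + dt
            x = x_high
        # The error estimate scales like dt**2, hence the square root in the update.
        factor = 0.9 * (1.0 / max(err, 1e-12)) ** 0.5
        dt *= min(5.0, max(0.2, factor))
    return x, attempts


# dx/dt = -x with x(0) = 1 has the exact solution exp(-t).
x_final, attempts = integrate_adaptive(lambda t, x: -x, 0.0, 1.0, 1.0)
print(abs(x_final - math.exp(-1.0)), attempts)
```

Tightening rtol shrinks the accepted step size and raises the number of attempts, which is why a max_steps cap is typically paired with the tolerances.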
Available Solvers (Diffrax 0.7.0)
Explicit Runge-Kutta:
- tsit5: Tsitouras 5(4) - recommended for most problems
- dopri5: Dormand-Prince 5(4)
- dopri8: Dormand-Prince 8(7) - high accuracy
- bosh3: Bogacki-Shampine 3(2)
- euler: Forward Euler (1st order)
- midpoint: Explicit midpoint (2nd order)
- heun: Heun’s method (2nd order)
- ralston: Ralston’s method (2nd order)
- reversible_heun: Reversible Heun (2nd order, reversible)
Additional Solvers (Diffrax 0.8.0+)
Implicit methods (for stiff systems):
- implicit_euler: Backward Euler (1st order, A-stable)
- kvaerno3: Kvaerno 3(2) ESDIRK
- kvaerno4: Kvaerno 4(3) ESDIRK
- kvaerno5: Kvaerno 5(4) ESDIRK
IMEX methods (for semi-stiff systems):
- sil3: 3rd order IMEX
- kencarp3: Kennedy-Carpenter IMEX 3
- kencarp4: Kennedy-Carpenter IMEX 4
- kencarp5: Kennedy-Carpenter IMEX 5
Special methods:
- semi_implicit_euler: Semi-implicit Euler (symplectic)
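The distinction between explicit and implicit solvers matters most for stiff systems. The following plain-Python sketch, which is independent of DiffraxIntegrator, compares forward Euler with backward Euler on the stiff test equation dx/dt = -lam * x. The helper names and the values lam = 100 and dt = 0.05 are assumptions picked so that dt lies outside forward Euler's stability limit (dt * lam <= 2).

```python
def forward_euler_step(x, lam, dt):
    # Explicit update for dx/dt = -lam * x, using the slope at the current state.
    return x + dt * (-lam * x)


def backward_euler_step(x, lam, dt):
    # Implicit update x_next = x + dt * (-lam * x_next), solved in closed form.
    return x / (1.0 + lam * dt)


lam, dt, n_steps = 100.0, 0.05, 50
x_explicit = 1.0
x_implicit = 1.0
for _ in range(n_steps):
    x_explicit = forward_euler_step(x_explicit, lam, dt)
    x_implicit = backward_euler_step(x_implicit, lam, dt)

# The explicit iterate is multiplied by (1 - 5) = -4 each step and blows up,
# while the implicit iterate is multiplied by 1/6 and decays like the true solution.
print(x_explicit, x_implicit)
```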
Notes
- Implicit and IMEX solvers require Diffrax 0.8.0 or later
- The integrator will automatically detect which solvers are available
- Use integrator._solver_map.keys() to see available solvers
- Supports autonomous systems (nu=0) by passing u=None
- Backward time integration (t_span with tf < t0) is not supported
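One plausible way to detect which solvers are available is to probe the library for each solver class and keep only those that exist. The sketch below shows that idea; it is a guess at the mechanism, not the integrator's actual code. The helper `build_solver_map` and the candidate table are hypothetical, and a `SimpleNamespace` stand-in replaces the real `diffrax` module so the example runs without Diffrax installed.

```python
from types import SimpleNamespace

# Assumed mapping from the string keys used in this documentation to Diffrax class names.
CANDIDATE_SOLVERS = {
    "tsit5": "Tsit5",
    "dopri5": "Dopri5",
    "kvaerno3": "Kvaerno3",
    "kencarp4": "KenCarp4",
}


def build_solver_map(module, candidates):
    """Return {key: class} for every candidate that the module actually exposes."""
    return {
        key: getattr(module, attr)
        for key, attr in candidates.items()
        if hasattr(module, attr)
    }


# Stand-in for a library build that only ships explicit solvers (hypothetical).
explicit_only_library = SimpleNamespace(Tsit5=object, Dopri5=object)

solver_map = build_solver_map(explicit_only_library, CANDIDATE_SOLVERS)
print(sorted(solver_map))
```

With Diffrax installed, passing the imported `diffrax` module in place of the stand-in would yield the solvers that the installed version provides.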
Examples
>>> # Controlled system
>>> integrator = DiffraxIntegrator(system, backend='jax', solver='tsit5')
>>> result = integrator.integrate(
... x0=jnp.array([1.0, 0.0]),
... u_func=lambda t, x: jnp.array([0.5]),
... t_span=(0.0, 10.0)
... )
>>>
>>> # Autonomous system
>>> integrator = DiffraxIntegrator(autonomous_system, backend='jax')
>>> result = integrator.integrate(
... x0=jnp.array([1.0, 0.0]),
... u_func=lambda t, x: None,
... t_span=(0.0, 10.0)
... )
Attributes
| Name | Description |
|---|---|
| name | Return the name of the integrator. |
Methods
| Name | Description |
|---|---|
| integrate | Integrate over time interval with control policy. |
| integrate_imex | Integrate IMEX system with separate explicit and implicit parts. |
| integrate_with_gradient | Integrate and compute gradients w.r.t. initial conditions. |
| jit_compile_step | Return a JIT-compiled version of the step function. |
| step | Take one integration step: x(t) → x(t + dt). |
| vectorized_integrate | Vectorized integration over batch of initial conditions. |
| vectorized_step | Vectorized step over batch of states and controls. |
integrate
systems.base.numerical_integration.DiffraxIntegrator.integrate(
x0,
u_func,
t_span,
t_eval=None,
dense_output=False,
)
Integrate over time interval with control policy.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| x0 | StateVector | Initial state (nx,) | required |
| u_func | Callable[[float, StateVector], Optional[ControlVector]] | Control policy: (t, x) → u (or None for autonomous systems) | required |
| t_span | TimeSpan | Integration interval (t_start, t_end). Note: t_end must be > t_start (backward integration is not supported) | required |
| t_eval | Optional[TimePoints] | Specific times at which to store solution (must be increasing) | None |
| dense_output | bool | If True, return dense interpolated solution | False |
Returns
| Name | Type | Description |
|---|---|---|
| | IntegrationResult | TypedDict with keys: t (time points, shape (T,)); x (state trajectory, shape (T, nx)); success (whether integration succeeded); message (status message); nfev (number of function evaluations); nsteps (number of steps taken); integration_time (computation time in seconds); solver (integrator name); njev (number of Jacobian evaluations, if applicable) |
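For orientation, a result of this shape could be declared roughly as follows. The field names come from the table above; the class name `IntegrationResultSketch`, the concrete field types, and the choice of `total=False` to make njev optional are assumptions, since the real IntegrationResult definition is not shown here.

```python
from typing import List, TypedDict


class IntegrationResultSketch(TypedDict, total=False):
    """Hypothetical stand-in mirroring the documented keys (types are assumed)."""
    t: List[float]            # time points, length T
    x: List[List[float]]      # state trajectory, T rows of nx values
    success: bool
    message: str
    nfev: int
    nsteps: int
    integration_time: float   # seconds
    solver: str
    njev: int                 # only present when Jacobians are evaluated


result: IntegrationResultSketch = {
    "t": [0.0, 0.5, 1.0],
    "x": [[1.0, 0.0], [0.9, -0.2], [0.7, -0.4]],
    "success": True,
    "message": "ok",
    "nfev": 18,
    "nsteps": 3,
    "integration_time": 0.002,
    "solver": "tsit5",
}

x_final = result["x"][-1]          # last row of the trajectory
has_jacobian_count = "njev" in result
```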
Raises
| Name | Type | Description |
|---|---|---|
| | ValueError | If t_span has tf < t0 (backward integration not supported) |
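The documented constraints on t_span and t_eval can be pictured as a small up-front check. The helper below is hypothetical and only mirrors the rules stated above. Two details are assumptions: that "increasing" means strictly increasing, and that t_eval must lie inside t_span.

```python
def validate_time_args(t_span, t_eval=None):
    """Hypothetical pre-check mirroring the documented t_span / t_eval rules."""
    t0, tf = t_span
    if tf < t0:
        raise ValueError("backward integration (tf < t0) is not supported")
    if t_eval is not None:
        points = list(t_eval)
        # Assumption: "increasing" is read as strictly increasing.
        if any(b <= a for a, b in zip(points, points[1:])):
            raise ValueError("t_eval must be increasing")
        # Assumption: evaluation times must fall within the integration interval.
        if points and (points[0] < t0 or points[-1] > tf):
            raise ValueError("t_eval must lie inside t_span")
    return t0, tf


checked = validate_time_args((0.0, 10.0), t_eval=[0.0, 5.0, 10.0])

try:
    validate_time_args((10.0, 0.0))
    backward_rejected = False
except ValueError:
    backward_rejected = True
```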
Examples
>>> # Controlled system
>>> result = integrator.integrate(
... x0=jnp.array([1.0, 0.0]),
... u_func=lambda t, x: jnp.array([0.5]),
... t_span=(0.0, 10.0)
... )
>>>
>>> # Autonomous system
>>> result = integrator.integrate(
... x0=jnp.array([1.0, 0.0]),
... u_func=lambda t, x: None,
... t_span=(0.0, 10.0)
... )
>>>
>>> # Access results
>>> t = result["t"]
>>> x_traj = result["x"]
>>> print(f"Success: {result['success']}")
>>> print(f"Steps: {result['nsteps']}, Function evals: {result['nfev']}")
Notes
Backward time integration (tf < t0) is not supported. If you need reverse-time integration, integrate forward and reverse the output:
>>> # Instead of integrate(x0, u_func, (10.0, 0.0))
>>> result = integrator.integrate(x0, u_func, (0.0, 10.0))
>>> t_reversed = jnp.flip(result["t"])
>>> x_reversed = jnp.flip(result["x"], axis=0)
integrate_imex
systems.base.numerical_integration.DiffraxIntegrator.integrate_imex(
x0,
explicit_func,
implicit_func,
t_span,
t_eval=None,
)
Integrate IMEX system with separate explicit and implicit parts.
For systems of the form: dx/dt = f_explicit(t, x) + f_implicit(t, x)
where f_explicit is non-stiff and f_implicit is stiff.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| x0 | StateVector | Initial state | required |
| explicit_func | Callable[[float, StateVector], StateVector] | Non-stiff part: (t, x) -> dx/dt_explicit | required |
| implicit_func | Callable[[float, StateVector], StateVector] | Stiff part: (t, x) -> dx/dt_implicit | required |
| t_span | TimeSpan | Integration interval (must have tf > t0) | required |
| t_eval | Optional[TimePoints] | Evaluation times | None |
Returns
| Name | Type | Description |
|---|---|---|
| | IntegrationResult | Integration result with IMEX solver |
Raises
| Name | Type | Description |
|---|---|---|
| | ValueError | If the configured solver is not an IMEX solver, if IMEX solvers are not available, or if tf < t0 |
Notes
IMEX solvers require Diffrax 0.8.0 or later
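To make the explicit/implicit splitting concrete, here is a first-order IMEX Euler step written in plain Python for a two-component system with explicit part [v, 0] and implicit part [0, -100 * v]. This is a deliberately simple scheme, not the sil3 or kencarp methods that Diffrax provides. The helper name `imex_euler_step` and the step size are assumptions, and the implicit solve is done in closed form because the stiff part is linear.

```python
def imex_euler_step(p, v, dt, stiffness=100.0):
    """One IMEX Euler step for dp/dt = v (explicit), dv/dt = -stiffness * v (implicit)."""
    # Explicit part: evaluated at the current state.
    p_next = p + dt * v
    # Implicit part: v_next = v - dt * stiffness * v_next, solved for v_next.
    v_next = v / (1.0 + dt * stiffness)
    return p_next, v_next


p, v, dt = 1.0, 1.0, 0.05
for _ in range(200):
    p, v = imex_euler_step(p, v, dt)

# dt * stiffness = 5 is beyond forward Euler's stability limit of 2,
# yet the stiff component decays and the trajectory stays bounded.
print(p, v)
```

Treating only the stiff term implicitly keeps the per-step solve cheap while avoiding the step-size restriction that a fully explicit method would face.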
Examples
>>> # Stiff-nonstiff splitting
>>> def explicit_part(t, x):
... # Non-stiff dynamics
... return jnp.array([x[1], 0.0])
>>>
>>> def implicit_part(t, x):
... # Stiff dynamics
... return jnp.array([0.0, -100*x[1]])
>>>
>>> result = integrator.integrate_imex(
... x0=jnp.array([1.0, 0.0]),
... explicit_func=explicit_part,
... implicit_func=implicit_part,
... t_span=(0.0, 10.0)
... )
integrate_with_gradient
systems.base.numerical_integration.DiffraxIntegrator.integrate_with_gradient(
x0,
u_func,
t_span,
loss_fn,
t_eval=None,
)
Integrate and compute gradients w.r.t. initial conditions.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| x0 | StateVector | Initial state (nx,) | required |
| u_func | Callable[[float, StateVector], Optional[ControlVector]] | Control policy: (t, x) → u | required |
| t_span | TimeSpan | Integration interval (t_start, t_end) | required |
| loss_fn | Callable[[IntegrationResult], float] | Loss function on integration result | required |
| t_eval | Optional[TimePoints] | Evaluation times | None |
Returns
| Name | Type | Description |
|---|---|---|
| loss | float | Loss value |
| grad | StateVector | Gradient of loss w.r.t. x0 |
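A gradient with respect to x0 can be sanity-checked against a finite-difference estimate. The sketch below does this in plain Python for dx/dt = -x, where the exact gradient of the loss is known in closed form. It does not call integrate_with_gradient; the fixed-step RK4 helper, the target value 0.5, and the perturbation size h are assumptions for illustration.

```python
import math


def rk4_final_state(x0, t_end, n_steps=100):
    """Integrate dx/dt = -x with fixed-step RK4 and return x(t_end)."""
    def f(x):
        return -x

    dt = t_end / n_steps
    x = x0
    for _ in range(n_steps):
        k1 = f(x)
        k2 = f(x + 0.5 * dt * k1)
        k3 = f(x + 0.5 * dt * k2)
        k4 = f(x + dt * k3)
        x += dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
    return x


def loss(x0, target=0.5, t_end=1.0):
    # Squared error of the final state, analogous to the loss_fn example in this section.
    return (rk4_final_state(x0, t_end) - target) ** 2


x0, h = 2.0, 1e-5
# Central finite difference of the loss with respect to the initial condition.
grad_fd = (loss(x0 + h) - loss(x0 - h)) / (2.0 * h)
# Exact gradient: x(1) = x0 * exp(-1), so d(loss)/d(x0) = 2 * (x(1) - target) * exp(-1).
grad_exact = 2.0 * (x0 * math.exp(-1.0) - 0.5) * math.exp(-1.0)
print(grad_fd, grad_exact)
```

Comparing the returned grad against such an estimate on a small problem is a cheap way to catch a loss function that is accidentally not differentiable.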
Examples
>>> # Define loss (e.g., final state error)
>>> def loss_fn(result):
... x_final = result["x"][-1]
... x_target = jnp.array([1.0, 0.0])
... return jnp.sum((x_final - x_target)**2)
>>>
>>> # Compute loss and gradient
>>> loss, grad = integrator.integrate_with_gradient(
... x0=jnp.array([0.0, 0.0]),
... u_func=lambda t, x: jnp.zeros(1),
... t_span=(0.0, 10.0),
... loss_fn=loss_fn
... )
>>> print(f"Loss: {loss:.4f}")
>>> print(f"Gradient: {grad}")
jit_compile_step
systems.base.numerical_integration.DiffraxIntegrator.jit_compile_step()
Return a JIT-compiled version of the step function.
Returns
| Name | Type | Description |
|---|---|---|
| jitted_step | callable | JIT-compiled step function with signature: (x: StateVector, u: Optional[ControlVector], dt: float) -> StateVector |
Notes
To maximize performance, the JIT-compiled version does not track statistics. Use the regular step() method if you need statistics tracking.
Examples
>>> # Create JIT-compiled step
>>> jit_step = integrator.jit_compile_step()
>>>
>>> # Use for fast repeated stepping
>>> x = jnp.array([1.0, 0.0])
>>> u = jnp.array([0.5])
>>> for _ in range(1000):
... x = jit_step(x, u, 0.01)
step
systems.base.numerical_integration.DiffraxIntegrator.step(x, u=None, dt=None)
Take one integration step: x(t) → x(t + dt).
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| x | StateVector | Current state (nx,) or (batch, nx) | required |
| u | Optional[ControlVector] | Control input (nu,) or (batch, nu), or None for autonomous systems | None |
| dt | Optional[ScalarLike] | Step size (uses self.dt if None) | None |
Returns
| Name | Type | Description |
|---|---|---|
| | StateVector | Next state x(t + dt) |
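The step contract (optional control, dt falling back to the value given at construction) can be illustrated with a tiny stand-in class. `MidpointStepper` is hypothetical and uses the explicit midpoint rule on a scalar state; holding the control constant over the step is an assumption, and the real step() delegates to the configured Diffrax solver instead.

```python
class MidpointStepper:
    """Hypothetical stand-in exposing a step(x, u, dt) method on a scalar state."""

    def __init__(self, dynamics, dt=None):
        self.dynamics = dynamics  # callable (x, u) -> dx/dt
        self.dt = dt

    def step(self, x, u=None, dt=None):
        dt = self.dt if dt is None else dt  # fall back to the constructor value
        if dt is None:
            raise ValueError("dt must be given here or at construction")
        k1 = self.dynamics(x, u)
        # Slope at the midpoint; the control is held constant over the step.
        k2 = self.dynamics(x + 0.5 * dt * k1, u)
        return x + dt * k2


def dynamics(x, u):
    # dx/dt = -x + u, with u=None treated as an autonomous system.
    return -x + (0.0 if u is None else u)


stepper = MidpointStepper(dynamics, dt=0.1)
x_controlled = stepper.step(1.0, u=0.5)   # uses the stored dt
x_autonomous = stepper.step(1.0)          # no control input
x_half_step = stepper.step(1.0, dt=0.05)  # per-call dt overrides the stored one
```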
Examples
>>> # Controlled system
>>> x_next = integrator.step(
... x=jnp.array([1.0, 0.0]),
... u=jnp.array([0.5])
... )
>>>
>>> # Autonomous system
>>> x_next = integrator.step(
... x=jnp.array([1.0, 0.0]),
... u=None
... )
vectorized_integrate
systems.base.numerical_integration.DiffraxIntegrator.vectorized_integrate(
x0_batch,
u_func,
t_span,
t_eval=None,
)
Vectorized integration over batch of initial conditions.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| x0_batch | StateVector | Batched initial states (batch, nx) | required |
| u_func | Callable[[float, StateVector], Optional[ControlVector]] | Control policy: (t, x) → u | required |
| t_span | TimeSpan | Integration interval | required |
| t_eval | Optional[TimePoints] | Evaluation times | None |
Returns
| Name | Type | Description |
|---|---|---|
| | list of IntegrationResult | Integration results for each initial condition |
Examples
>>> # Monte Carlo simulation with 100 initial conditions
>>> x0_batch = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
>>> results = integrator.vectorized_integrate(
... x0_batch,
... lambda t, x: jnp.zeros(1),
... (0.0, 10.0)
... )
>>> print(f"Success rate: {sum(r['success'] for r in results) / 100}")
vectorized_step
systems.base.numerical_integration.DiffraxIntegrator.vectorized_step(
x_batch,
u_batch=None,
dt=None,
)
Vectorized step over batch of states and controls.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| x_batch | StateVector | Batched states (batch, nx) | required |
| u_batch | Optional[ControlVector] | Batched controls (batch, nu) or None for autonomous | None |
| dt | Optional[ScalarLike] | Step size | None |
Returns
| Name | Type | Description |
|---|---|---|
| | StateVector | Next states (batch, nx) |
Examples
>>> # Batch of 100 states
>>> x_batch = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
>>> u_batch = jnp.zeros((100, 1))
>>> x_next_batch = integrator.vectorized_step(x_batch, u_batch)
>>> print(x_next_batch.shape) # (100, 2)