systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator(
    sde_system,
    dt=None,
    step_mode=StepMode.FIXED,
    backend='jax',
    solver='Euler',
    sde_type=None,
    convergence_type=ConvergenceType.STRONG,
    seed=None,
    levy_area='none',
    **options,
)

JAX-based SDE integrator using the Diffrax library.

Provides high-performance SDE integration with automatic differentiation, JIT compilation, GPU support, and custom noise support via JAX.

Parameters

Name Type Description Default
sde_system StochasticDynamicalSystem SDE system to integrate (controlled or autonomous) required
dt Optional[float] Time step size (required for fixed step, initial guess for adaptive) None
step_mode StepMode FIXED or ADAPTIVE stepping mode StepMode.FIXED
backend str Must be ‘jax’ for this integrator 'jax'
solver str Diffrax SDE solver name. Options: ‘Euler’, ‘EulerHeun’, ‘Heun’, ‘ItoMilstein’, ‘SEA’, ‘SHARK’, ‘SRA1’, ‘ReversibleHeun’. 'Euler'
sde_type Optional[SDEType] SDE interpretation (None = use system’s type) None
convergence_type ConvergenceType Strong or weak convergence ConvergenceType.STRONG
seed Optional[int] Random seed for reproducibility None
levy_area str Lévy area approximation: ‘none’, ‘space-time’, or ‘full’. Required for Milstein-type solvers. 'none'
**options Additional options: {}
  • rtol : float (default: 1e-3) - Relative tolerance (adaptive only)
  • atol : float (default: 1e-6) - Absolute tolerance (adaptive only)
  • max_steps : int (default: 10000) - Maximum steps
  • adjoint : str (‘recursive_checkpoint’, ‘direct’, ‘implicit’)
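
To make the role of seed concrete, here is a minimal sketch in plain Python of how a seeded generator yields reproducible Brownian increments. The helper brownian_increments is hypothetical and uses the standard library's random module as a stand-in; the integrator itself presumably relies on JAX random keys, which are not reproduced here.

```python
import math
import random

def brownian_increments(seed, n_steps, dt):
    """Draw n_steps increments dW ~ N(0, dt) from a seeded generator."""
    rng = random.Random(seed)
    # Standard deviation sqrt(dt), so the variance of each increment is dt.
    return [rng.gauss(0.0, math.sqrt(dt)) for _ in range(n_steps)]

dt = 0.01
path_a = brownian_increments(seed=123, n_steps=1000, dt=dt)
path_b = brownian_increments(seed=123, n_steps=1000, dt=dt)
path_c = brownian_increments(seed=124, n_steps=1000, dt=dt)

same_seed_matches = path_a == path_b       # identical noise, identical runs
different_seed_differs = path_a != path_c  # a new seed gives a new path
sample_variance = sum(dW * dW for dW in path_a) / len(path_a)  # close to dt
```

The same scaling applies when supplying increments by hand: a realistic dW for step size dt has magnitude on the order of sqrt(dt).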

Raises

Name Type Description
ValueError If backend is not ‘jax’
ImportError If JAX or Diffrax not installed

Notes

  • Backend must be ‘jax’ (Diffrax is JAX-only)
  • Supports custom Brownian increments via the dW parameter of step
  • JIT compilation happens on the first call, so that call may be slow
  • For gradient-based optimization, use adjoint=‘recursive_checkpoint’
  • Milstein methods require a levy_area approximation
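
As background for the note on Milstein methods, the sketch below compares one Euler-Maruyama step with one Itô-Milstein step on geometric Brownian motion. It is written in plain Python rather than through DiffraxSDEIntegrator, and the function names and parameter values are illustrative assumptions, not part of this API. With a single Wiener process the Milstein correction only needs dW**2 - dt; with several non-commuting noise channels the analogous cross terms are iterated integrals, which is where a Lévy area approximation enters.

```python
import math

def euler_maruyama_step(x, mu, sigma, dt, dW):
    """One Euler-Maruyama step for dX = mu*X dt + sigma*X dW."""
    return x + mu * x * dt + sigma * x * dW

def milstein_step(x, mu, sigma, dt, dW):
    """One Ito-Milstein step for the same scalar SDE.

    Adds 0.5 * g(x) * g'(x) * (dW**2 - dt) with g(x) = sigma * x.
    """
    correction = 0.5 * sigma**2 * x * (dW**2 - dt)
    return euler_maruyama_step(x, mu, sigma, dt, dW) + correction

def exact_gbm_step(x, mu, sigma, dt, dW):
    """Exact solution of geometric Brownian motion over one increment."""
    return x * math.exp((mu - 0.5 * sigma**2) * dt + sigma * dW)

# Illustrative values; dW is fixed by hand so the comparison is deterministic.
x0, mu, sigma, dt, dW = 1.0, 0.1, 0.5, 0.01, 0.2

exact = exact_gbm_step(x0, mu, sigma, dt, dW)
euler = euler_maruyama_step(x0, mu, sigma, dt, dW)
milstein = milstein_step(x0, mu, sigma, dt, dW)

err_euler = abs(euler - exact)
err_milstein = abs(milstein - exact)
print(f"Euler error:    {err_euler:.2e}")
print(f"Milstein error: {err_milstein:.2e}")
```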

Examples

>>> # Basic usage
>>> integrator = DiffraxSDEIntegrator(
...     sde_system,
...     dt=0.01,
...     solver='Euler'
... )
>>>
>>> # Deterministic testing with zero noise
>>> x_next = integrator.step(x, u, dt=0.01, dW=jnp.zeros(nw))
>>>
>>> # Custom noise pattern
>>> x_next = integrator.step(x, u, dt=0.01, dW=jnp.array([0.5]))
>>>
>>> # Optimized for additive noise
>>> integrator = DiffraxSDEIntegrator(
...     additive_noise_system,
...     dt=0.001,
...     solver='SEA'  # Specialized for additive noise
... )

Attributes

Name Description
name Return integrator name.

Methods

Name Description
get_solver_info Get information about a specific solver.
integrate Integrate SDE over time interval.
integrate_with_gradient Integrate and compute gradients w.r.t. initial conditions.
jit_compile JIT-compile the integration for faster execution.
list_solvers List available Diffrax SDE solvers by category.
recommend_solver Recommend Diffrax solver based on problem characteristics.
step Take one SDE integration step.
to_device Move computations to specified device.
vectorized_integrate Vectorized integration over batch of initial conditions.

get_solver_info

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.get_solver_info(
    solver,
)

Get information about a specific solver.

Parameters

Name Type Description Default
solver str Solver name required

Returns

Name Type Description
Dict[str, Any] Solver properties including required levy_area type

integrate

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.integrate(
    x0,
    u_func,
    t_span,
    t_eval=None,
    dense_output=False,
    brownian_path=None,
)

Integrate SDE over time interval.

Parameters

Name Type Description Default
x0 StateVector Initial state (nx,) required
u_func Callable Control policy: (t, x) -> u (or None for autonomous) required
t_span TimeSpan Integration interval (t_start, t_end) required
t_eval Optional[TimePoints] Specific times at which to save solution None
dense_output bool If True, enable dense output (not supported for SDEs in Diffrax) False
brownian_path Optional[CustomBrownianPath] Custom Brownian path object (CustomBrownianPath or None). If provided, the integrator uses this deterministic custom noise instead of random noise. Must be compatible with Diffrax’s AbstractPath interface. None

Returns

Name Type Description
SDEIntegrationResult Integration result with trajectory

Examples

>>> # Autonomous SDE with random noise
>>> result = integrator.integrate(
...     x0=jnp.array([1.0]),
...     u_func=lambda t, x: None,
...     t_span=(0.0, 10.0)
... )
>>>
>>> # With custom Brownian path (deterministic)
>>> from cdesym.systems.base.numerical_integration.stochastic.custom_brownian import (
...     CustomBrownianPath
... )
>>> dW = jnp.zeros(nw)  # Zero noise
>>> brownian = CustomBrownianPath(0.0, 10.0, dW)
>>> result = integrator.integrate(
...     x0=jnp.array([1.0]),
...     u_func=lambda t, x: None,
...     t_span=(0.0, 10.0),
...     brownian_path=brownian
... )
>>>
>>> # Controlled SDE
>>> K = jnp.array([[1.0, 2.0]])
>>> result = integrator.integrate(
...     x0=jnp.array([1.0, 0.0]),
...     u_func=lambda t, x: -K @ x,
...     t_span=(0.0, 10.0)
... )
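
To show what integrating over t_span with a control policy u_func(t, x) amounts to, here is a minimal fixed-step Euler-Maruyama loop in plain Python. It is a sketch of the general technique for a scalar system, not the Diffrax-based implementation; integrate_em, f, g, and the gain K are hypothetical names chosen for illustration.

```python
import math
import random

def integrate_em(f, g, x0, u_func, t_span, dt, rng=None):
    """Fixed-step Euler-Maruyama for scalar dX = f(x, u) dt + g(x, u) dW.

    Passing rng=None uses zero noise, which recovers the explicit Euler
    solution of the drift-only ODE.
    """
    t_start, t_end = t_span
    n_steps = round((t_end - t_start) / dt)
    ts, xs = [t_start], [x0]
    for k in range(n_steps):
        t, x = ts[-1], xs[-1]
        u = u_func(t, x)  # feedback policy evaluated at the current state
        dW = rng.gauss(0.0, math.sqrt(dt)) if rng is not None else 0.0
        xs.append(x + f(x, u) * dt + g(x, u) * dW)
        ts.append(t_start + (k + 1) * dt)
    return ts, xs

# Unstable drift x' = x + u, stabilized by the feedback u = -K * x.
K = 3.0
f = lambda x, u: x + u
g = lambda x, u: 0.2           # additive noise of constant intensity
u_func = lambda t, x: -K * x

ts, xs_noisy = integrate_em(f, g, 1.0, u_func, (0.0, 1.0), 0.01,
                            rng=random.Random(0))
_, xs_det = integrate_em(f, g, 1.0, u_func, (0.0, 1.0), 0.01)
```

With zero noise each step multiplies the state by 1 - 0.02, so xs_det decays geometrically, while xs_noisy fluctuates around that trend.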

integrate_with_gradient

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.integrate_with_gradient(
    x0,
    u_func,
    t_span,
    loss_fn,
    t_eval=None,
)

Integrate and compute gradients w.r.t. initial conditions.

Parameters

Name Type Description Default
x0 ArrayLike Initial state required
u_func Callable Control policy required
t_span Tuple[float, float] Time span required
loss_fn Callable Loss function taking IntegrationResult required
t_eval Optional[ArrayLike] Evaluation times None

Returns

Name Type Description
tuple (loss_value, gradient_wrt_x0)

Examples

>>> def loss_fn(result):
...     return jnp.sum(result.x[-1]**2)
>>>
>>> loss, grad = integrator.integrate_with_gradient(
...     x0, u_func, t_span, loss_fn
... )
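
The idea behind a gradient with respect to x0 can be sketched without JAX: once the noise path is held fixed, the final state is an ordinary deterministic function of x0. The example below propagates the derivative by hand (forward sensitivity) through an Euler-Maruyama loop and checks it against a central finite difference. The SDE and all helper names are assumptions made for illustration; integrate_with_gradient presumably relies on JAX automatic differentiation instead, which is not reproduced here.

```python
import math
import random

def simulate_with_sensitivity(x0, dWs, dt):
    """Euler-Maruyama for dX = -X**3 dt + 0.5*X dW, tracking dX_n/dx0."""
    x, s = x0, 1.0
    for dW in dWs:
        # Derivative of the update map with respect to x, using the old x.
        s = s * (1.0 - 3.0 * x**2 * dt + 0.5 * dW)
        x = x + (-x**3) * dt + 0.5 * x * dW
    return x, s

def loss_and_grad(x0, dWs, dt):
    """Loss x_T**2 and its gradient with respect to x0 (chain rule)."""
    x_T, s_T = simulate_with_sensitivity(x0, dWs, dt)
    return x_T**2, 2.0 * x_T * s_T

# Freeze one noise path so the loss is a deterministic function of x0.
rng = random.Random(42)
dt = 0.01
dWs = [rng.gauss(0.0, math.sqrt(dt)) for _ in range(200)]

loss, grad = loss_and_grad(1.0, dWs, dt)

# Central finite difference on the same frozen path as a sanity check.
eps = 1e-6
fd_grad = (loss_and_grad(1.0 + eps, dWs, dt)[0]
           - loss_and_grad(1.0 - eps, dWs, dt)[0]) / (2.0 * eps)
```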

jit_compile

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.jit_compile()

JIT-compile the integration for faster execution.

The first call is slow because it triggers compilation; subsequent calls are fast.

Returns

Name Type Description
Callable JIT-compiled integration function

Examples

>>> jit_integrate = integrator.jit_compile()
>>> result = jit_integrate(x0, u_func, t_span)  # Fast!

list_solvers

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.list_solvers(
)

List available Diffrax SDE solvers by category.

Returns

Name Type Description
Dict[str, List[str]] Solvers organized by category

Examples

>>> solvers = DiffraxSDEIntegrator.list_solvers()
>>> print(solvers['additive_noise'])
['SEA', 'SHARK', 'SRA1']

recommend_solver

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.recommend_solver(
    noise_type,
    accuracy='medium',
    has_derivatives=False,
)

Recommend Diffrax solver based on problem characteristics.

Parameters

Name Type Description Default
noise_type str ‘additive’, ‘diagonal’, ‘scalar’, or ‘general’ required
accuracy str ‘low’, ‘medium’, or ‘high’ 'medium'
has_derivatives bool Whether diffusion derivatives are available False

Returns

Name Type Description
str Recommended solver name

Examples

>>> solver = DiffraxSDEIntegrator.recommend_solver(
...     noise_type='additive',
...     accuracy='high'
... )
>>> print(solver)
SHARK

step

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.step(
    x,
    u=None,
    dt=None,
    dW=None,
)

Take one SDE integration step.

Supports custom Brownian increments via the dW parameter.

Parameters

Name Type Description Default
x StateVector Current state (nx,) or (batch, nx) required
u Optional[ControlVector] Control input (nu,) or (batch, nu), or None for autonomous None
dt Optional[ScalarLike] Step size (uses self.dt if None) None
dW Optional[NoiseVector] Brownian increments of shape (nw,), where nw is the number of Wiener processes. If provided, this deterministic noise is used instead of random noise. Providing dW enables: None
  • Deterministic testing: dW=jnp.zeros(nw)
  • Custom noise patterns: dW=jnp.array([0.5, -0.3])
  • Quasi-Monte Carlo methods
  • Antithetic variates for variance reduction

Returns

Name Type Description
StateVector Next state x(t + dt)

Notes

  • The single-step interface is less efficient than full integration
  • Custom noise (dW) is fully supported by the JAX/Diffrax backend
  • The same dW always gives the same result (deterministic)
  • This is the recommended backend when custom noise is needed

Examples

>>> # Random noise (default)
>>> x_next = integrator.step(x, u, dt=0.01)
>>>
>>> # Zero noise (deterministic dynamics only)
>>> x_next = integrator.step(x, u, dt=0.01, dW=jnp.zeros(nw))
>>>
>>> # Custom noise pattern
>>> x_next = integrator.step(x, u, dt=0.01, dW=jnp.array([0.5]))
>>>
>>> # Verify determinism
>>> x1 = integrator.step(x, u, dt=0.01, dW=jnp.zeros(1))
>>> x2 = integrator.step(x, u, dt=0.01, dW=jnp.zeros(1))
>>> assert jnp.allclose(x1, x2)  # Passes!
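
The uses of a custom dW listed above can be illustrated with a hand-written Euler-Maruyama step in plain Python. This is a sketch under assumed dynamics; em_step and its coefficients a, b, and sigma are hypothetical and stand in for integrator.step. It shows a zero-noise step, a repeatable step for a fixed dW, and an antithetic pair built from +dW and -dW.

```python
import math
import random

def em_step(x, u, dt, dW, a=-1.0, b=1.0, sigma=0.3):
    """One Euler-Maruyama step for dX = (a*X + b*u) dt + sigma*X dW."""
    return x + (a * x + b * u) * dt + sigma * x * dW

x, u, dt = 2.0, 0.5, 0.01

# Zero noise: only the drift acts, so the result is fully deterministic.
x_det = em_step(x, u, dt, dW=0.0)

# A random increment has distribution N(0, dt), i.e. std sqrt(dt).
rng = random.Random(7)
dW = rng.gauss(0.0, math.sqrt(dt))

# The same dW reproduces the same next state.
x_first = em_step(x, u, dt, dW)
x_again = em_step(x, u, dt, dW)

# Antithetic pair: reuse the increment with its sign flipped.
x_plus = em_step(x, u, dt, +dW)
x_minus = em_step(x, u, dt, -dW)
pair_mean = 0.5 * (x_plus + x_minus)
```

Because a single Euler-Maruyama step is linear in dW, pair_mean equals the zero-noise step here. Over many steps of a nonlinear system that equality no longer holds exactly, although pairing paths in this way typically still lowers the variance of Monte Carlo averages.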

to_device

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.to_device(
    device,
)

Move computations to specified device.

Parameters

Name Type Description Default
device str ‘cpu’, ‘gpu’, ‘gpu:0’, ‘gpu:1’, etc. required

Examples

>>> integrator.to_device('gpu')
>>> # Now all computations use GPU

vectorized_integrate

systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.vectorized_integrate(
    x0_batch,
    u_func,
    t_span,
    t_eval=None,
)

Vectorized integration over batch of initial conditions.

Uses jax.vmap for efficient batching.

Parameters

Name Type Description Default
x0_batch ArrayLike Batch of initial states (batch, nx) required
u_func Callable Control policy required
t_span Tuple[float, float] Time span required
t_eval Optional[ArrayLike] Evaluation times None

Returns

Name Type Description
List[SDEIntegrationResult] Results for each initial condition

Examples

>>> x0_batch = jnp.array([[1.0], [2.0], [3.0]])
>>> results = integrator.vectorized_integrate(
...     x0_batch, u_func, (0, 10)
... )