# systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator(
    sde_system,
    dt=None,
    step_mode=StepMode.FIXED,
    backend='jax',
    solver='Euler',
    sde_type=None,
    convergence_type=ConvergenceType.STRONG,
    seed=None,
    levy_area='none',
    **options,
)
```
JAX-based SDE integrator using the Diffrax library.
Provides high-performance SDE integration with automatic differentiation,
JIT compilation, GPU support, and custom noise support via JAX.
## Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|------------------|---------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|
| sde_system | StochasticDynamicalSystem | SDE system to integrate (controlled or autonomous) | _required_ |
| dt | Optional\[float\] | Time step size (required for fixed step, initial guess for adaptive) | `None` |
| step_mode | StepMode | FIXED or ADAPTIVE stepping mode | `StepMode.FIXED` |
| backend | str | Must be 'jax' for this integrator | `'jax'` |
| solver | str | Diffrax SDE solver name. Options: 'Euler', 'EulerHeun', 'Heun', 'ItoMilstein', 'SEA', 'SHARK', 'SRA1', 'ReversibleHeun'. | `'Euler'` |
| sde_type | Optional\[SDEType\] | SDE interpretation (None = use system's type) | `None` |
| convergence_type | ConvergenceType | Strong or weak convergence | `ConvergenceType.STRONG` |
| seed | Optional\[int\] | Random seed for reproducibility | `None` |
| levy_area | str | Lévy area approximation: 'none', 'space-time', or 'full'. Required for Milstein-type solvers. | `'none'` |
| **options | | Additional options: `rtol` (float, default 1e-3): relative tolerance (adaptive only); `atol` (float, default 1e-6): absolute tolerance (adaptive only); `max_steps` (int, default 10000): maximum number of steps; `adjoint` (str): one of 'recursive_checkpoint', 'direct', 'implicit'. | `{}` |
## Raises {.doc-section .doc-section-raises}
| Name | Type | Description |
|--------|-------------|---------------------------------|
| | ValueError | If backend is not 'jax' |
|        | ImportError | If JAX or Diffrax is not installed |
## Notes {.doc-section .doc-section-notes}
- Backend must be `'jax'` (Diffrax is JAX-only).
- Supports custom Brownian increments via the `dW` parameter of `step`.
- JIT compilation happens on the first call, so that call may be slow.
- For gradient-based optimization, use `adjoint='recursive_checkpoint'`.
- Milstein-type solvers require a `levy_area` approximation.
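
The correction term that separates Milstein-type schemes from Euler-Maruyama is easiest to see with a single Wiener process, where it reduces to a closed form in `dW`. In the general multi-dimensional case extra iterated-integral (Lévy area) terms appear, which is the situation the `levy_area` setting is concerned with. The following is a minimal standard-library sketch on geometric Brownian motion; it is not the Diffrax or `DiffraxSDEIntegrator` implementation, and all function names and parameter values are illustrative assumptions.

```python
import math
import random


def euler_step(x, dt, dW, mu, sigma):
    """One Euler-Maruyama step for dx = mu*x dt + sigma*x dW."""
    return x + mu * x * dt + sigma * x * dW


def milstein_step(x, dt, dW, mu, sigma):
    """One Ito-Milstein step: Euler plus 0.5 * g * g' * (dW**2 - dt)."""
    g = sigma * x   # diffusion g(x)
    dg = sigma      # its derivative g'(x)
    return x + mu * x * dt + g * dW + 0.5 * g * dg * (dW * dW - dt)


def mean_strong_errors(n_paths=200, n_steps=100, t_end=1.0,
                       x0=1.0, mu=0.1, sigma=0.5, seed=0):
    """Mean absolute end-point error of both schemes vs. the exact solution."""
    rng = random.Random(seed)
    dt = t_end / n_steps
    err_euler = 0.0
    err_milstein = 0.0
    for _ in range(n_paths):
        x_e = x_m = x0
        w = 0.0  # accumulated Brownian motion W(t_end) for this path
        for _ in range(n_steps):
            dW = rng.gauss(0.0, math.sqrt(dt))
            x_e = euler_step(x_e, dt, dW, mu, sigma)
            x_m = milstein_step(x_m, dt, dW, mu, sigma)
            w += dW
        # Closed-form solution of geometric Brownian motion on the same path.
        exact = x0 * math.exp((mu - 0.5 * sigma**2) * t_end + sigma * w)
        err_euler += abs(x_e - exact)
        err_milstein += abs(x_m - exact)
    return err_euler / n_paths, err_milstein / n_paths


err_euler, err_milstein = mean_strong_errors()
```

Both schemes consume the same increments, so the gap between `err_euler` and `err_milstein` comes only from the extra `(dW**2 - dt)` term.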
## Examples {.doc-section .doc-section-examples}
```python
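>>> import jax.numpy as jnp
>>> from cdesym.systems.base.numerical_integration.stochastic import (
...     DiffraxSDEIntegrator,
... )
>>>
>>> # Basic usage
>>> integrator = DiffraxSDEIntegrator(
...     sde_system,
...     dt=0.01,
...     solver='Euler'
... )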
>>>
>>> # Deterministic testing with zero noise
>>> x_next = integrator.step(x, u, dt=0.01, dW=jnp.zeros(nw))
>>>
>>> # Custom noise pattern
>>> x_next = integrator.step(x, u, dt=0.01, dW=jnp.array([0.5]))
>>>
>>> # Optimized for additive noise
>>> integrator = DiffraxSDEIntegrator(
...     additive_noise_system,
...     dt=0.001,
...     solver='SEA'  # Specialized for additive noise
... )
```
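
To illustrate what `dt` and `seed` control in fixed-step mode, here is a minimal pure-Python Euler-Maruyama loop. It is a conceptual stand-in under assumed scalar dynamics, not the integrator's actual code path; the `euler_maruyama` helper and the Ornstein-Uhlenbeck coefficients are hypothetical.

```python
import math
import random


def euler_maruyama(drift, diffusion, x0, t_span, dt, seed=None):
    """Fixed-step Euler-Maruyama for a scalar SDE dx = f(x) dt + g(x) dW."""
    t0, t1 = t_span
    n_steps = round((t1 - t0) / dt)
    rng = random.Random(seed)  # seeded generator gives reproducible noise
    xs = [x0]
    for _ in range(n_steps):
        dW = rng.gauss(0.0, math.sqrt(dt))  # increment with variance dt
        x = xs[-1]
        xs.append(x + drift(x) * dt + diffusion(x) * dW)
    return xs


# Ornstein-Uhlenbeck process: dx = -x dt + 0.3 dW (additive noise).
def drift(x):
    return -x


def diffusion(x):
    return 0.3


path_a = euler_maruyama(drift, diffusion, 1.0, (0.0, 1.0), dt=0.01, seed=42)
path_b = euler_maruyama(drift, diffusion, 1.0, (0.0, 1.0), dt=0.01, seed=42)
path_c = euler_maruyama(drift, diffusion, 1.0, (0.0, 1.0), dt=0.01, seed=7)
```

`path_a` and `path_b` share a seed and are therefore identical, while `path_c` follows a different noise realization.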
## Attributes
| Name | Description |
| --- | --- |
| [name](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.name) | Return integrator name. |
## Methods
| Name | Description |
| --- | --- |
| [get_solver_info](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.get_solver_info) | Get information about a specific solver. |
| [integrate](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.integrate) | Integrate SDE over time interval. |
| [integrate_with_gradient](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.integrate_with_gradient) | Integrate and compute gradients w.r.t. initial conditions. |
| [jit_compile](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.jit_compile) | JIT-compile the integration for faster execution. |
| [list_solvers](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.list_solvers) | List available Diffrax SDE solvers by category. |
| [recommend_solver](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.recommend_solver) | Recommend Diffrax solver based on problem characteristics. |
| [step](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.step) | Take one SDE integration step. |
| [to_device](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.to_device) | Move computations to specified device. |
| [vectorized_integrate](#cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.vectorized_integrate) | Vectorized integration over batch of initial conditions. |
### get_solver_info { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.get_solver_info }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.get_solver_info(
    solver,
)
```
Get information about a specific solver.
#### Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|--------|--------|---------------|------------|
| solver | str | Solver name | _required_ |
#### Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|------------------|-----------------------------------------------------|
| | Dict\[str, Any\] | Solver properties including required levy_area type |
### integrate { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.integrate }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.integrate(
    x0,
    u_func,
    t_span,
    t_eval=None,
    dense_output=False,
    brownian_path=None,
)
```
Integrate SDE over time interval.
#### Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|---------------|--------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|
| x0 | StateVector | Initial state (nx,) | _required_ |
| u_func | Callable | Control policy: (t, x) -> u (or None for autonomous) | _required_ |
| t_span | TimeSpan | Integration interval (t_start, t_end) | _required_ |
| t_eval | Optional\[TimePoints\] | Specific times at which to save solution | `None` |
| dense_output | bool | If True, enable dense output (not supported for SDEs in Diffrax) | `False` |
| brownian_path | Optional\[CustomBrownianPath\] | Custom Brownian path object, or None. If provided, the integration uses this deterministic custom noise instead of random noise. Must be compatible with Diffrax's AbstractPath interface. | `None` |
#### Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|----------------------|------------------------------------|
| | SDEIntegrationResult | Integration result with trajectory |
#### Examples {.doc-section .doc-section-examples}
```python
>>> # Autonomous SDE with random noise
>>> result = integrator.integrate(
...     x0=jnp.array([1.0]),
...     u_func=lambda t, x: None,
...     t_span=(0.0, 10.0)
... )
>>>
>>> # With custom Brownian path (deterministic)
>>> from cdesym.systems.base.numerical_integration.stochastic.custom_brownian import (
...     CustomBrownianPath
... )
>>> dW = jnp.zeros(nw) # Zero noise
>>> brownian = CustomBrownianPath(0.0, 10.0, dW)
>>> result = integrator.integrate(
...     x0=jnp.array([1.0]),
...     u_func=lambda t, x: None,
...     t_span=(0.0, 10.0),
...     brownian_path=brownian
... )
>>>
>>> # Controlled SDE
>>> K = jnp.array([[1.0, 2.0]])
>>> result = integrator.integrate(
...     x0=jnp.array([1.0, 0.0]),
...     u_func=lambda t, x: -K @ x,
...     t_span=(0.0, 10.0)
... )
```
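
The `brownian_path` argument replaces random noise with a fixed path. As a rough illustration of that idea, the hypothetical class below stores increments on a uniform grid and answers increment queries W(t1) - W(t0) by linear interpolation. It is only a sketch of an `evaluate(t0, t1)`-style path object; the internals of `CustomBrownianPath` are not shown in this reference and may differ.

```python
class PiecewiseLinearPath:
    """Hypothetical fixed path built from increments on a uniform grid."""

    def __init__(self, t0, t1, increments):
        self.t0 = t0
        self.dt = (t1 - t0) / len(increments)
        # Cumulative sums give the path values at grid points, with W(t0) = 0.
        self.values = [0.0]
        for dW in increments:
            self.values.append(self.values[-1] + dW)

    def value(self, t):
        """Path value at time t, linearly interpolated between grid points."""
        s = (t - self.t0) / self.dt
        k = min(max(int(s), 0), len(self.values) - 2)  # clamp to a valid cell
        frac = s - k
        return self.values[k] + frac * (self.values[k + 1] - self.values[k])

    def evaluate(self, t0, t1):
        """Increment W(t1) - W(t0) over the interval [t0, t1]."""
        return self.value(t1) - self.value(t0)


path = PiecewiseLinearPath(0.0, 1.0, [0.5, -0.25, 0.25, 0.0])
zero = PiecewiseLinearPath(0.0, 1.0, [0.0] * 4)

total = path.evaluate(0.0, 1.0)     # sum of all supplied increments
no_noise = zero.evaluate(0.2, 0.9)  # a zero path yields zero increments
```

A solver that queries such an object for each step always sees the same increments, which is what makes the resulting trajectory deterministic.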
### integrate_with_gradient { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.integrate_with_gradient }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.integrate_with_gradient(
    x0,
    u_func,
    t_span,
    loss_fn,
    t_eval=None,
)
```
Integrate and compute gradients w.r.t. initial conditions.
#### Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|---------|-----------------------|----------------------------------------|------------|
| x0 | ArrayLike | Initial state | _required_ |
| u_func | Callable | Control policy | _required_ |
| t_span | Tuple\[float, float\] | Time span | _required_ |
| loss_fn | Callable | Loss function taking IntegrationResult | _required_ |
| t_eval | Optional\[ArrayLike\] | Evaluation times | `None` |
#### Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|--------|-------------------------------|
| | tuple | (loss_value, gradient_wrt_x0) |
#### Examples {.doc-section .doc-section-examples}
```python
>>> def loss_fn(result):
...     return jnp.sum(result.x[-1]**2)
>>>
>>> loss, grad = integrator.integrate_with_gradient(
...     x0, u_func, t_span, loss_fn
... )
```
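
Conceptually, differentiating through an SDE solve means holding the noise realization fixed and differentiating the resulting deterministic map from `x0` to the loss. The sketch below shows this in plain Python for Euler-Maruyama on geometric Brownian motion: it propagates the sensitivity dx/dx0 alongside the state and checks the result against a finite difference. The dynamics, helper names, and parameter values are assumptions for illustration; the integrator itself is documented as relying on JAX automatic differentiation rather than hand-written sensitivities.

```python
import math
import random


def solve_with_sensitivity(x0, increments, dt, mu=0.1, sigma=0.4):
    """Euler-Maruyama for dx = mu*x dt + sigma*x dW with fixed increments.

    Returns the final state and its derivative with respect to x0.
    """
    x, s = x0, 1.0  # s tracks dx/dx0
    for dW in increments:
        # Derivative of the update x + mu*x*dt + sigma*x*dW with respect to x.
        jac = 1.0 + mu * dt + sigma * dW
        x, s = x + mu * x * dt + sigma * x * dW, jac * s
    return x, s


def loss_and_grad(x0, increments, dt):
    """Loss x(T)**2 and its gradient with respect to x0 via the chain rule."""
    x_final, s_final = solve_with_sensitivity(x0, increments, dt)
    return x_final**2, 2.0 * x_final * s_final


dt = 0.01
rng = random.Random(0)
increments = [rng.gauss(0.0, math.sqrt(dt)) for _ in range(100)]

loss, grad = loss_and_grad(1.5, increments, dt)

# Central finite difference on the same noise realization as a sanity check.
eps = 1e-6
fd = (loss_and_grad(1.5 + eps, increments, dt)[0]
      - loss_and_grad(1.5 - eps, increments, dt)[0]) / (2 * eps)
```

Reusing `increments` for every evaluation is essential: with fresh noise on each call, the finite difference would measure sampling noise rather than the derivative.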
### jit_compile { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.jit_compile }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.jit_compile()
```
JIT-compile the integration for faster execution.
The first call is slow because it triggers compilation; subsequent calls are fast.
#### Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|----------|-----------------------------------|
| | Callable | JIT-compiled integration function |
#### Examples {.doc-section .doc-section-examples}
```python
>>> jit_integrate = integrator.jit_compile()
>>> result = jit_integrate(x0, u_func, t_span) # Fast!
```
### list_solvers { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.list_solvers }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.list_solvers(
)
```
List available Diffrax SDE solvers by category.
#### Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|--------------------------|-------------------------------|
| | Dict\[str, List\[str\]\] | Solvers organized by category |
#### Examples {.doc-section .doc-section-examples}
```python
>>> solvers = DiffraxSDEIntegrator.list_solvers()
>>> print(solvers['additive_noise'])
['SEA', 'SHARK', 'SRA1']
```
### recommend_solver { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.recommend_solver }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.recommend_solver(
    noise_type,
    accuracy='medium',
    has_derivatives=False,
)
```
Recommend Diffrax solver based on problem characteristics.
#### Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|-----------------|--------|------------------------------------------------|------------|
| noise_type | str | 'additive', 'diagonal', 'scalar', or 'general' | _required_ |
| accuracy | str | 'low', 'medium', or 'high' | `'medium'` |
| has_derivatives | bool | Whether diffusion derivatives are available | `False` |
#### Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|--------|-------------------------|
| | str | Recommended solver name |
#### Examples {.doc-section .doc-section-examples}
```python
>>> solver = DiffraxSDEIntegrator.recommend_solver(
...     noise_type='additive',
...     accuracy='high'
... )
>>> print(solver)
SHARK
```
### step { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.step }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.step(
    x,
    u=None,
    dt=None,
    dW=None,
)
```
Take one SDE integration step.
Supports custom Brownian increments via the `dW` parameter.
#### Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|--------|---------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|
| x | StateVector | Current state (nx,) or (batch, nx) | _required_ |
| u | Optional\[ControlVector\] | Control input (nu,) or (batch, nu), or None for autonomous | `None` |
| dt | Optional\[ScalarLike\] | Step size (uses self.dt if None) | `None` |
| dW | Optional\[NoiseVector\] | Brownian increments, shape (nw,), where nw is the number of Wiener processes. If provided, this deterministic noise is used instead of random noise. Providing dW enables: deterministic testing (`dW=jnp.zeros(nw)`); custom noise patterns (`dW=jnp.array([0.5, -0.3])`); quasi-Monte Carlo methods; antithetic variates for variance reduction. | `None` |
#### Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|-------------|----------------------|
| | StateVector | Next state x(t + dt) |
#### Notes {.doc-section .doc-section-notes}
- The single-step interface is less efficient than full integration with `integrate`.
- Custom noise via `dW` is fully supported by the JAX/Diffrax backend.
- The same `dW` always gives the same result (deterministic).
- This is the recommended backend when custom noise is needed.
#### Examples {.doc-section .doc-section-examples}
```python
>>> # Random noise (default)
>>> x_next = integrator.step(x, u, dt=0.01)
>>>
>>> # Zero noise (deterministic dynamics only)
>>> x_next = integrator.step(x, u, dt=0.01, dW=jnp.zeros(nw))
>>>
>>> # Custom noise pattern
>>> x_next = integrator.step(x, u, dt=0.01, dW=jnp.array([0.5]))
>>>
>>> # Verify determinism
>>> x1 = integrator.step(x, u, dt=0.01, dW=jnp.zeros(1))
>>> x2 = integrator.step(x, u, dt=0.01, dW=jnp.zeros(1))
>>> assert jnp.allclose(x1, x2) # Passes!
```
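
Passing `dW` explicitly is what makes variance-reduction schemes such as antithetic variates possible: each sampled increment is reused with its sign flipped. The following standard-library sketch uses a hand-written Euler-Maruyama step with assumed additive-noise dynamics as a stand-in for `integrator.step`; the `em_step` helper and its coefficients are hypothetical.

```python
import math
import random


def em_step(x, dt, dW, theta=1.0, sigma=0.3):
    """Euler-Maruyama step for dx = -theta*x dt + sigma dW (additive noise)."""
    return x - theta * x * dt + sigma * dW


dt = 0.01
x = 2.0
rng = random.Random(123)

# Deterministic part of the step, obtained by supplying zero noise.
x_det = em_step(x, dt, dW=0.0)

pair_means = []
for _ in range(5):
    dW = rng.gauss(0.0, math.sqrt(dt))
    x_plus = em_step(x, dt, dW)     # step driven by +dW
    x_minus = em_step(x, dt, -dW)   # antithetic step driven by -dW
    pair_means.append(0.5 * (x_plus + x_minus))
```

With additive noise and a single step, the noise cancels in each pair mean, so every entry of `pair_means` matches `x_det` up to rounding. For state-dependent diffusion or multi-step paths the cancellation is only partial.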
### to_device { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.to_device }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.to_device(
    device,
)
```
Move computations to specified device.
#### Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|--------|--------|--------------------------------------|------------|
| device | str | 'cpu', 'gpu', 'gpu:0', 'gpu:1', etc. | _required_ |
#### Examples {.doc-section .doc-section-examples}
```python
>>> integrator.to_device('gpu')
>>> # Now all computations use GPU
```
### vectorized_integrate { #cdesym.systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.vectorized_integrate }
```python
systems.base.numerical_integration.stochastic.DiffraxSDEIntegrator.vectorized_integrate(
    x0_batch,
    u_func,
    t_span,
    t_eval=None,
)
```
Vectorized integration over batch of initial conditions.
Uses jax.vmap for efficient batching.
#### Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|----------|-----------------------|-------------------------------------|------------|
| x0_batch | ArrayLike | Batch of initial states (batch, nx) | _required_ |
| u_func | Callable | Control policy | _required_ |
| t_span | Tuple\[float, float\] | Time span | _required_ |
| t_eval | Optional\[ArrayLike\] | Evaluation times | `None` |
#### Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|------------------------------|------------------------------------|
| | List\[SDEIntegrationResult\] | Results for each initial condition |
#### Examples {.doc-section .doc-section-examples}
```python
>>> x0_batch = jnp.array([[1.0], [2.0], [3.0]])
>>> results = integrator.vectorized_integrate(
...     x0_batch, u_func, (0.0, 10.0)
... )
```