systems.base.utils.stochastic.DiffusionHandler

systems.base.utils.stochastic.DiffusionHandler(
    diffusion_expr,
    state_vars,
    control_vars,
    time_var=None,
    parameters=None,
)

Handles code generation and caching for diffusion terms.

Responsibilities:
- Generate backend-specific g(x, u) functions via codegen_utils
- Cache generated functions (avoid recompilation)
- Compose with NoiseCharacterizer for automatic analysis
- Provide constant noise matrices for the additive case
- Track generation statistics

This class mirrors CodeGenerator’s interface and patterns for consistency.
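For orientation, the diffusion term this class handles is g in a stochastic differential equation of the form below. This is a sketch that assumes the standard drift-plus-diffusion form; the drift f and the Wiener process W are not defined elsewhere in this reference. The shape (nx, nw) matches the one given for get_constant_noise, and the additive case is the one where g reduces to a constant matrix G:

```latex
\begin{aligned}
dx &= f(x, u, t)\,dt + g(x, u, t)\,dW_t,
  \qquad g(x, u, t) \in \mathbb{R}^{n_x \times n_w},\quad W_t \in \mathbb{R}^{n_w} \\
\text{additive noise:}\quad g(x, u, t) &\equiv G \quad \text{for all } x,\ u,\ t
\end{aligned}
```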

Examples

>>> from sympy import symbols, Matrix
>>> x1, x2, u = symbols('x1 x2 u')
>>>
>>> # State-independent (additive) noise
>>> diffusion = Matrix([[0.1], [0.2]])
>>> handler = DiffusionHandler(diffusion, [x1, x2], [u])
>>>
>>> # Automatic analysis
>>> print(handler.characteristics.noise_type)
NoiseType.ADDITIVE
>>>
>>> print(handler.characteristics.recommended_solvers('jax'))
['sea', 'shark', 'sra1']
>>>
>>> # Generate function with caching
>>> g_numpy = handler.generate_function('numpy')
>>> g_val = g_numpy(1.0, 0.5, 0.0)  # g(x1=1.0, x2=0.5, u=0.0)
>>>
>>> # Second call returns cached
>>> g_numpy_cached = handler.generate_function('numpy')
>>> assert g_numpy is g_numpy_cached
>>>
>>> # For additive noise, get constant matrix
>>> if handler.characteristics.is_additive:
...     G = handler.get_constant_noise('numpy')
...     print(G)  # [[0.1], [0.2]] - precomputed!

Attributes

Name             Description
characteristics  Get noise characteristics from automatic analysis.

Methods

Name                            Description
can_optimize_for_additive       Check if additive-noise optimizations are applicable.
compile_all                     Pre-compile diffusion functions for multiple backends.
generate_function               Generate g(x, u) function for specified backend.
get_constant_noise              Get constant noise matrix for additive noise.
get_function                    Get cached diffusion function without generating.
get_info                        Get comprehensive information about diffusion handler state.
get_optimization_opportunities  Identify optimization opportunities based on noise structure.
get_stats                       Get generation statistics.
has_constant_noise              Check if constant noise has been computed and cached.
is_compiled                     Check if diffusion function is compiled for backend.
reset_cache                     Clear cached functions for specified backends.
reset_stats                     Reset generation statistics.
warmup                          Warm up function compilation (especially useful for JAX JIT).

can_optimize_for_additive

systems.base.utils.stochastic.DiffusionHandler.can_optimize_for_additive()

Check if additive-noise optimizations are applicable.

Returns

Type  Description
bool  True if diffusion is constant and can be precomputed
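One way such a check could work is sketched below with SymPy, under the assumption that "constant" means no entry of the symbolic diffusion matrix contains a state, control, or time symbol (any remaining symbols are treated as fixed parameters). The helper `is_additive` is hypothetical and is not part of this class:

```python
from sympy import Matrix, symbols


def is_additive(diffusion, state_vars, control_vars, time_var=None):
    """Return True if no entry of the diffusion matrix depends on x, u, or t."""
    # Symbols the diffusion must not contain: states, controls, and time.
    dependent = set(state_vars) | set(control_vars)
    if time_var is not None:
        dependent.add(time_var)
    # free_symbols collects every symbol that appears in any matrix entry.
    return diffusion.free_symbols.isdisjoint(dependent)


x1, x2, u, sigma = symbols('x1 x2 u sigma')

additive = Matrix([[0.1], [0.2]])
multiplicative = Matrix([[0.1 * x1], [0.2]])
parametric = Matrix([[sigma], [0.2]])  # depends only on a parameter

print(is_additive(additive, [x1, x2], [u]))        # True
print(is_additive(multiplicative, [x1, x2], [u]))  # False
print(is_additive(parametric, [x1, x2], [u]))      # True
```

Treating parameter symbols as constants matches the idea that parameters are fixed for the lifetime of the handler; a state-dependent entry such as 0.1*x1 makes the noise multiplicative.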

compile_all

systems.base.utils.stochastic.DiffusionHandler.compile_all(
    backends=None,
    verbose=False,
    **kwargs,
)

Pre-compile diffusion functions for multiple backends.

Mirrors CodeGenerator.compile_all() for consistency.

Parameters

Name      Type           Description                                 Default
backends  List[Backend]  Backends to compile (None = all available)  None
verbose   bool           Print compilation progress                  False
**kwargs                 Backend-specific options                    {}

Returns

Type              Description
Dict[str, float]  Mapping backend → compilation time in seconds

Examples

>>> timings = handler.compile_all(verbose=True)
Compiling diffusion for numpy: 0.05s
Compiling diffusion for torch: 0.12s
Compiling diffusion for jax: 0.08s
>>>
>>> print(timings)
{'numpy': 0.05, 'torch': 0.12, 'jax': 0.08}
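A minimal sketch of the timing loop that compile_all describes, assuming a hypothetical `generate(backend)` callable that builds the function for one backend. This is not the class's actual implementation; the printed line only imitates the verbose output shown above:

```python
import time
from typing import Callable, Dict, Iterable


def compile_all(
    generate: Callable[[str], Callable],
    backends: Iterable[str],
    verbose: bool = False,
) -> Dict[str, float]:
    """Time one generation per backend and return {backend: seconds}."""
    timings: Dict[str, float] = {}
    for backend in backends:
        start = time.perf_counter()
        generate(backend)  # build the backend-specific diffusion function
        elapsed = time.perf_counter() - start
        timings[backend] = elapsed
        if verbose:
            print(f"Compiling diffusion for {backend}: {elapsed:.2f}s")
    return timings


# Hypothetical generator: returns a trivial constant diffusion function.
def fake_generate(backend: str) -> Callable:
    return lambda x1, x2, u: [[0.1], [0.2]]


timings = compile_all(fake_generate, ["numpy", "torch", "jax"], verbose=True)
```

time.perf_counter is used because it is a monotonic, high-resolution clock suited to measuring short intervals.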

generate_function

systems.base.utils.stochastic.DiffusionHandler.generate_function(
    backend,
    **kwargs,
)

Generate g(x, u) function for specified backend.

Parameters

Name      Type     Description                                 Default
backend   Backend  Target backend ('numpy', 'torch', 'jax')    required
**kwargs           Additional arguments for generate_function  {}

Returns

Type               Description
DiffusionFunction  Callable g(x, u) → diffusion matrix
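The caching behavior described for generate_function, get_function, is_compiled, reset_cache, and get_stats can be sketched as a per-backend dictionary. This is a hypothetical illustration: the class name `BackendCache`, the `build` callable, and the `cache_hits` key are assumptions, while `generations` and `cache_hit_rate` are the statistic names used in the examples in this reference:

```python
from typing import Callable, Dict, Iterable, Optional


class BackendCache:
    """Hypothetical per-backend function cache with generation statistics."""

    def __init__(self, build: Callable[[str], Callable]) -> None:
        self._build = build  # generates a function for one backend
        self._cache: Dict[str, Callable] = {}
        self._generations = 0
        self._cache_hits = 0

    def generate_function(self, backend: str) -> Callable:
        if backend in self._cache:  # cached: skip regeneration
            self._cache_hits += 1
            return self._cache[backend]
        fn = self._build(backend)
        self._cache[backend] = fn
        self._generations += 1
        return fn

    def get_function(self, backend: str) -> Optional[Callable]:
        return self._cache.get(backend)  # None if not generated yet

    def is_compiled(self, backend: str) -> bool:
        return backend in self._cache

    def reset_cache(self, backends: Optional[Iterable[str]] = None) -> None:
        if backends is None:
            self._cache.clear()
        else:
            for backend in backends:
                self._cache.pop(backend, None)

    def get_stats(self) -> Dict[str, float]:
        calls = self._generations + self._cache_hits
        return {
            "generations": self._generations,
            "cache_hits": self._cache_hits,
            "cache_hit_rate": self._cache_hits / calls if calls else 0.0,
        }


cache = BackendCache(lambda backend: (lambda x1, x2, u: [[0.1], [0.2]]))
g1 = cache.generate_function("numpy")
g2 = cache.generate_function("numpy")  # served from the cache
```

Returning the identical object on a cache hit is what makes the `assert g_numpy is g_numpy_cached` check in the class example hold.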

get_constant_noise

systems.base.utils.stochastic.DiffusionHandler.get_constant_noise(
    backend='numpy',
)

Get constant noise matrix for additive noise.

For additive noise, the diffusion does not depend on state, control, or time, so it can be computed once and reused. This avoids evaluating g(x, u) at every solver step, which can be a significant performance saving when solving SDEs.

Parameters

Name     Type     Description                                  Default
backend  Backend  Backend that determines the returned array type  'numpy'

Returns

Type             Description
DiffusionMatrix  Constant diffusion matrix, shape (nx, nw)

Raises

Type        Description
ValueError  If noise is not additive (use generate_function instead)

Examples

>>> if handler.characteristics.is_additive:
...     G = handler.get_constant_noise('numpy')
...     # G is constant - precompute once, reuse everywhere!
...     for t in range(1000):
...         # No need to evaluate diffusion - just use G
...         pass
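To make the saving concrete, here is a minimal Euler-Maruyama sketch (plain Python lists, no backend) in which a precomputed additive-noise matrix G is reused at every step. The function names, the drift `f`, and the step size are illustrative assumptions, not part of this class:

```python
import math
import random
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(G: Matrix, w: Vector) -> Vector:
    """Matrix-vector product G @ w for nested lists."""
    return [sum(g_ij * w_j for g_ij, w_j in zip(row, w)) for row in G]


def euler_maruyama_additive(
    f: Callable[[Vector], Vector],
    G: Matrix,
    x0: Vector,
    dt: float,
    n_steps: int,
    seed: int = 0,
) -> Vector:
    """Integrate dx = f(x) dt + G dW with a constant (nx, nw) matrix G."""
    rng = random.Random(seed)
    nw = len(G[0])
    x = list(x0)
    for _ in range(n_steps):
        # Brownian increments: each component is drawn from N(0, dt)
        dW = [rng.gauss(0.0, math.sqrt(dt)) for _ in range(nw)]
        drift = f(x)
        noise = matvec(G, dW)  # G is reused as-is; g(x, u) is never evaluated
        x = [xi + fi * dt + ni for xi, fi, ni in zip(x, drift, noise)]
    return x


G = [[0.1], [0.2]]  # same constant matrix as the class example
x_final = euler_maruyama_additive(lambda x: [-xi for xi in x], G, [1.0, 0.5], 0.01, 1000)
```

With state-dependent noise, the `matvec(G, dW)` line would instead need a fresh g(x, u) evaluation inside the loop.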

get_function

systems.base.utils.stochastic.DiffusionHandler.get_function(backend)

Get cached diffusion function without generating.

Mirrors CodeGenerator.get_dynamics().

Parameters

Name     Type     Description     Default
backend  Backend  Target backend  required

Returns

Type                         Description
Optional[DiffusionFunction]  Cached function or None if not yet generated

Examples

>>> g = handler.get_function('numpy')
>>> if g is None:
...     g = handler.generate_function('numpy')

get_info

systems.base.utils.stochastic.DiffusionHandler.get_info()

Get comprehensive information about diffusion handler state.

Mirrors CodeGenerator.get_info() for consistent API.

Returns

Type            Description
Dict[str, Any]  Status, characteristics, compilation state, statistics

Examples

>>> info = handler.get_info()
>>> print(info)
{
    'dimensions': {'nx': 2, 'nw': 1},
    'noise_type': 'additive',
    'characteristics': {...},
    'compiled': {'numpy': True, 'torch': False, 'jax': True},
    'constant_noise_cached': True,
    'statistics': {...}
}

get_optimization_opportunities

systems.base.utils.stochastic.DiffusionHandler.get_optimization_opportunities()

Identify optimization opportunities based on noise structure.

Returns

Type             Description
Dict[str, bool]  Flags for various optimization strategies

Examples

>>> opts = handler.get_optimization_opportunities()
>>> if opts['precompute_diffusion']:
...     G = handler.get_constant_noise()
...     # Use G directly instead of calling diffusion(x, u)
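A sketch of how a caller might act on these flags: pick, once, a per-step diffusion evaluator that either returns a precomputed matrix or calls g(x, u) every time. Only the `precompute_diffusion` key comes from the example above; `make_diffusion_step`, `constant_G`, and `g_state` are hypothetical names:

```python
from typing import Callable, Dict, List

Matrix = List[List[float]]


def make_diffusion_step(
    opts: Dict[str, bool],
    constant_G: Callable[[], Matrix],
    g: Callable[..., Matrix],
) -> Callable[..., Matrix]:
    """Return a per-step diffusion evaluator chosen from the optimization flags."""
    if opts.get("precompute_diffusion", False):
        G = constant_G()          # evaluated once, up front
        return lambda *args: G    # every step reuses the same matrix
    return g                      # state-dependent: evaluate g at each step


def g_state(x1: float, x2: float, u: float) -> Matrix:
    return [[0.1 * x1], [0.2]]


step_const = make_diffusion_step({"precompute_diffusion": True}, lambda: [[0.1], [0.2]], g_state)
step_state = make_diffusion_step({"precompute_diffusion": False}, lambda: [[0.1], [0.2]], g_state)
```

Making the choice once, outside the integration loop, keeps the branch out of the hot path.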

get_stats

systems.base.utils.stochastic.DiffusionHandler.get_stats()

Get generation statistics.

Returns

Type            Description
Dict[str, Any]  Generation count, cache hits, timing

Examples

>>> stats = handler.get_stats()
>>> print(f"Cache hit rate: {stats['cache_hit_rate']:.1%}")

has_constant_noise

systems.base.utils.stochastic.DiffusionHandler.has_constant_noise()

Check if constant noise has been computed and cached.

Returns

Type  Description
bool  True if any backend has cached constant noise

is_compiled

systems.base.utils.stochastic.DiffusionHandler.is_compiled(backend)

Check if diffusion function is compiled for backend.

Mirrors CodeGenerator.is_compiled(), but is simpler because there is only one function type.

Parameters

Name     Type     Description       Default
backend  Backend  Backend to check  required

Returns

Type  Description
bool  True if function is cached/compiled

Examples

>>> if not handler.is_compiled('jax'):
...     handler.generate_function('jax')

reset_cache

systems.base.utils.stochastic.DiffusionHandler.reset_cache(backends=None)

Clear cached functions for specified backends.

Mirrors CodeGenerator.reset_cache().

Parameters

Name      Type           Description                     Default
backends  List[Backend]  Backends to reset (None = all)  None

Examples

>>> # Clear all caches
>>> handler.reset_cache()
>>>
>>> # Clear only torch (e.g., after device change)
>>> handler.reset_cache(['torch'])

reset_stats

systems.base.utils.stochastic.DiffusionHandler.reset_stats()

Reset generation statistics.

Examples

>>> handler.reset_stats()
>>> stats = handler.get_stats()
>>> assert stats['generations'] == 0

warmup

systems.base.utils.stochastic.DiffusionHandler.warmup(backend, **kwargs)

Warm up function compilation (especially useful for JAX JIT).

Parameters

Name      Type     Description                  Default
backend   Backend  Backend to warm up           required
**kwargs           Test inputs for warmup call  {}

Examples

>>> # JAX JIT warmup
>>> handler.warmup('jax')
>>> # Compilation cost is paid here; subsequent calls reuse the compiled function
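Why warming up helps can be illustrated without JAX. The sketch below wraps a hypothetical `build` step that runs only on the first call, standing in for JIT tracing and compilation, plus a `warmup` helper that makes that first call with zero-valued dummy inputs. The names and the zero-input convention are assumptions, not this class's implementation:

```python
from typing import Callable, List


def make_lazy(build: Callable[[], Callable]) -> Callable:
    """Defer building the real function until the first call (JIT stand-in)."""
    compiled: List[Callable] = []

    def call(*args):
        if not compiled:
            compiled.append(build())  # one-time cost, paid on the first call
        return compiled[0](*args)

    return call


def warmup(fn: Callable, n_args: int) -> None:
    """Trigger the one-time build by calling fn once with zero inputs."""
    fn(*([0.0] * n_args))


build_log: List[str] = []


def build_diffusion() -> Callable:
    build_log.append("built")  # records each (expensive) build
    return lambda x1, x2, u: [[0.1], [0.2]]


g = make_lazy(build_diffusion)
warmup(g, n_args=3)   # the build happens here
g(1.0, 0.5, 0.0)      # reuses the already-built function
```

The design point is that the one-time cost moves to a moment you choose (for example, before a timing-sensitive simulation loop) instead of landing on the first real call.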