systems.base.utils.stochastic.DiffusionHandler
systems.base.utils.stochastic.DiffusionHandler(
diffusion_expr,
state_vars,
control_vars,
time_var=None,
parameters=None,
)
Handles code generation and caching for diffusion terms.
Responsibilities:
- Generate backend-specific g(x, u) functions via codegen_utils
- Cache generated functions (avoid recompilation)
- Compose with NoiseCharacterizer for automatic analysis
- Provide constant noise matrices for additive case
- Track generation statistics
This class mirrors CodeGenerator’s interface and patterns for consistency.
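The caching and statistics responsibilities can be illustrated with a small, self-contained sketch. `BackendFunctionCache` and its attributes are hypothetical names chosen for illustration; they are not the library's implementation, and the "numpy" builder here is a plain Python lambda standing in for real code generation.

```python
from typing import Callable, Dict


class BackendFunctionCache:
    """Hypothetical stand-in for a per-backend function cache (not the library's class)."""

    def __init__(self, builders: Dict[str, Callable[[], Callable]]):
        self._builders = builders            # backend name -> zero-argument factory
        self._cache: Dict[str, Callable] = {}
        self.generations = 0                 # times a factory actually ran
        self.cache_hits = 0                  # times a cached function was reused

    def generate_function(self, backend: str) -> Callable:
        if backend in self._cache:
            self.cache_hits += 1
            return self._cache[backend]
        fn = self._builders[backend]()       # expensive step, done once per backend
        self._cache[backend] = fn
        self.generations += 1
        return fn

    def cache_hit_rate(self) -> float:
        total = self.generations + self.cache_hits
        return self.cache_hits / total if total else 0.0


# Constant diffusion g(x1, x2, u) = [[0.1], [0.2]], built lazily for a pretend backend.
cache = BackendFunctionCache({"numpy": lambda: (lambda x1, x2, u: [[0.1], [0.2]])})
g_first = cache.generate_function("numpy")
g_second = cache.generate_function("numpy")   # served from the cache
```

Because the second call returns the stored object, an identity check (`g_first is g_second`) is enough to confirm that nothing was regenerated.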
Examples
>>> from sympy import symbols, Matrix
>>> x1, x2, u = symbols('x1 x2 u')
>>>
>>> # State-independent (additive) noise
>>> diffusion = Matrix([[0.1], [0.2]])
>>> handler = DiffusionHandler(diffusion, [x1, x2], [u])
>>>
>>> # Automatic analysis
>>> print(handler.characteristics.noise_type)
NoiseType.ADDITIVE
>>>
>>> print(handler.characteristics.recommended_solvers('jax'))
['sea', 'shark', 'sra1']
>>>
>>> # Generate function with caching
>>> g_numpy = handler.generate_function('numpy')
>>> g_val = g_numpy(1.0, 0.5, 0.0) # g(x1=1.0, x2=0.5, u=0.0)
>>>
>>> # Second call returns cached
>>> g_numpy_cached = handler.generate_function('numpy')
>>> assert g_numpy is g_numpy_cached
>>>
>>> # For additive noise, get constant matrix
>>> if handler.characteristics.is_additive:
...     G = handler.get_constant_noise('numpy')
...     print(G)  # [[0.1], [0.2]] - precomputed!
Attributes
| Name | Description |
|---|---|
| characteristics | Get noise characteristics from automatic analysis. |
Methods
| Name | Description |
|---|---|
| can_optimize_for_additive | Check if additive-noise optimizations are applicable. |
| compile_all | Pre-compile diffusion functions for multiple backends. |
| generate_function | Generate g(x, u) function for specified backend. |
| get_constant_noise | Get constant noise matrix for additive noise. |
| get_function | Get cached diffusion function without generating. |
| get_info | Get comprehensive information about diffusion handler state. |
| get_optimization_opportunities | Identify optimization opportunities based on noise structure. |
| get_stats | Get generation statistics. |
| has_constant_noise | Check if constant noise has been computed and cached. |
| is_compiled | Check if diffusion function is compiled for backend. |
| reset_cache | Clear cached functions for specified backends. |
| reset_stats | Reset generation statistics. |
| warmup | Warm up function compilation (especially useful for JAX JIT). |
can_optimize_for_additive
systems.base.utils.stochastic.DiffusionHandler.can_optimize_for_additive()
Check if additive-noise optimizations are applicable.
Returns
| Name | Type | Description |
|---|---|---|
| | bool | True if diffusion is constant and can be precomputed |
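One way such a check could work is to test whether any state, control, or time symbol appears in the symbolic diffusion expression. The sketch below uses SymPy's `free_symbols`; the helper `is_constant_diffusion` is a hypothetical name, and the library's actual criterion may differ.

```python
from sympy import Matrix, symbols

x1, x2, u, t = symbols("x1 x2 u t")


def is_constant_diffusion(diffusion, variables):
    """Return True if no state, control, or time symbol appears in the expression."""
    return diffusion.free_symbols.isdisjoint(variables)


additive = Matrix([[0.1], [0.2]])               # no dependence on x1, x2, u, or t
multiplicative = Matrix([[0.1 * x1], [0.2]])    # first entry scales with the state x1

variables = {x1, x2, u, t}
additive_is_constant = is_constant_diffusion(additive, variables)
multiplicative_is_constant = is_constant_diffusion(multiplicative, variables)
```

Under this criterion, symbolic parameters that are not states, controls, or time do not prevent precomputation, since they are fixed once substituted.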
compile_all
systems.base.utils.stochastic.DiffusionHandler.compile_all(
backends=None,
verbose=False,
**kwargs,
)
Pre-compile diffusion functions for multiple backends.
Mirrors CodeGenerator.compile_all() for consistency.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| backends | List[Backend] | Backends to compile (None = all available) | None |
| verbose | bool | Print compilation progress | False |
| **kwargs | | Backend-specific options | {} |
Returns
| Name | Type | Description |
|---|---|---|
| | Dict[str, float] | Mapping backend → compilation_time |
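The shape of the returned mapping can be sketched with a generic timing loop. This is an illustration only: `time_compilation` and the `generators` dictionary are hypothetical, and the dummy callables stand in for real per-backend code generation.

```python
import time
from typing import Callable, Dict, Iterable, Optional


def time_compilation(
    generators: Dict[str, Callable[[], object]],
    backends: Optional[Iterable[str]] = None,
    verbose: bool = False,
) -> Dict[str, float]:
    """Run each backend's generator once and record its wall-clock time in seconds."""
    selected = list(generators) if backends is None else list(backends)
    timings: Dict[str, float] = {}
    for name in selected:
        start = time.perf_counter()
        generators[name]()                        # stand-in for real code generation
        timings[name] = time.perf_counter() - start
        if verbose:
            print(f"Compiling diffusion for {name}: {timings[name]:.2f}s")
    return timings


# Dummy generators; a real one would build and compile g(x, u).
dummy = {"numpy": lambda: None, "torch": lambda: None}
timings = time_compilation(dummy)                 # backends=None selects every entry
torch_only = time_compilation(dummy, backends=["torch"])
```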
Examples
>>> timings = handler.compile_all(verbose=True)
Compiling diffusion for numpy: 0.05s
Compiling diffusion for torch: 0.12s
Compiling diffusion for jax: 0.08s
>>>
>>> print(timings)
{'numpy': 0.05, 'torch': 0.12, 'jax': 0.08}
generate_function
systems.base.utils.stochastic.DiffusionHandler.generate_function(
backend,
**kwargs,
)
Generate g(x, u) function for specified backend.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| backend | Backend | Target backend ('numpy', 'torch', 'jax') | required |
| **kwargs | | Additional arguments for generate_function | {} |
Returns
| Name | Type | Description |
|---|---|---|
| | DiffusionFunction | Callable g(x, u) → diffusion matrix |
get_constant_noise
systems.base.utils.stochastic.DiffusionHandler.get_constant_noise(
backend='numpy',
)
Get constant noise matrix for additive noise.
For additive noise, diffusion doesn’t depend on state, control, or time, so it can be precomputed once and reused. This is a significant performance optimization for SDE solving.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| backend | Backend | Backend for array type | 'numpy' |
Returns
| Name | Type | Description |
|---|---|---|
| | DiffusionMatrix | Constant diffusion matrix, shape (nx, nw) |
Raises
| Name | Type | Description |
|---|---|---|
| | ValueError | If noise is not additive (use generate_function instead) |
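To see why a constant matrix helps, consider an Euler-Maruyama step for dx = f(x) dt + G dW, which reads x_{k+1} = x_k + f(x_k) dt + G ΔW_k with ΔW_k drawn from N(0, dt). When G is constant it can be computed once outside the loop. The sketch below is a standard-library illustration for a single noise source (nw = 1); the drift `decay` and the function name are assumptions, not part of the library.

```python
import math
import random


def euler_maruyama_additive(f, G, x0, dt, n_steps, rng):
    """Integrate dx = f(x) dt + G dW with a constant diffusion column G (nw = 1)."""
    x = list(x0)
    sqrt_dt = math.sqrt(dt)
    for _ in range(n_steps):
        dw = rng.gauss(0.0, 1.0) * sqrt_dt       # scalar Wiener increment ~ N(0, dt)
        drift = f(x)
        # G is reused as-is at every step; no diffusion function is evaluated here.
        x = [xi + fi * dt + gi * dw for xi, fi, gi in zip(x, drift, G)]
    return x


def decay(x):
    """Hypothetical drift f(x) = -x, used only for illustration."""
    return [-xi for xi in x]


G = [0.1, 0.2]                                   # flattened form of [[0.1], [0.2]]
x_final = euler_maruyama_additive(
    decay, G, [1.0, 0.5], dt=0.01, n_steps=1000, rng=random.Random(0)
)
```

With state-dependent noise, the same loop would have to evaluate g(x, u) at every step, which is the cost that precomputation avoids.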
Examples
>>> if handler.characteristics.is_additive:
...     G = handler.get_constant_noise('numpy')
...     # G is constant - precompute once, reuse everywhere!
...     for t in range(1000):
...         # No need to evaluate diffusion - just use G
...         pass
get_function
systems.base.utils.stochastic.DiffusionHandler.get_function(backend)
Get cached diffusion function without generating.
Mirrors CodeGenerator.get_dynamics().
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| backend | Backend | Target backend | required |
Returns
| Name | Type | Description |
|---|---|---|
| | Optional[DiffusionFunction] | Cached function or None if not yet generated |
Examples
>>> g = handler.get_function('numpy')
>>> if g is None:
...     g = handler.generate_function('numpy')
get_info
systems.base.utils.stochastic.DiffusionHandler.get_info()
Get comprehensive information about diffusion handler state.
Mirrors CodeGenerator.get_info() for consistent API.
Returns
| Name | Type | Description |
|---|---|---|
| | Dict[str, Any] | Status, characteristics, compilation state, statistics |
Examples
>>> info = handler.get_info()
>>> print(info)
{
    'dimensions': {'nx': 2, 'nw': 1},
    'noise_type': 'additive',
    'characteristics': {...},
    'compiled': {'numpy': True, 'torch': False, 'jax': True},
    'constant_noise_cached': True,
    'statistics': {...}
}
get_optimization_opportunities
systems.base.utils.stochastic.DiffusionHandler.get_optimization_opportunities()
Identify optimization opportunities based on noise structure.
Returns
| Name | Type | Description |
|---|---|---|
| | Dict[str, bool] | Flags for various optimization strategies |
Examples
>>> opts = handler.get_optimization_opportunities()
>>> if opts['precompute_diffusion']:
...     G = handler.get_constant_noise()
...     # Use G directly instead of calling diffusion(x, u)
get_stats
systems.base.utils.stochastic.DiffusionHandler.get_stats()
Get generation statistics.
Returns
| Name | Type | Description |
|---|---|---|
| | Dict[str, Any] | Generation count, cache hits, timing |
Examples
>>> stats = handler.get_stats()
>>> print(f"Cache hit rate: {stats['cache_hit_rate']:.1%}")
has_constant_noise
systems.base.utils.stochastic.DiffusionHandler.has_constant_noise()
Check if constant noise has been computed and cached.
Returns
| Name | Type | Description |
|---|---|---|
| | bool | True if any backend has cached constant noise |
is_compiled
systems.base.utils.stochastic.DiffusionHandler.is_compiled(backend)
Check if diffusion function is compiled for backend.
Mirrors CodeGenerator.is_compiled() but simpler (only one function type).
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| backend | Backend | Backend to check | required |
Returns
| Name | Type | Description |
|---|---|---|
| | bool | True if function is cached/compiled |
Examples
>>> if not handler.is_compiled('jax'):
...     handler.generate_function('jax')
reset_cache
systems.base.utils.stochastic.DiffusionHandler.reset_cache(backends=None)
Clear cached functions for specified backends.
Mirrors CodeGenerator.reset_cache().
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| backends | List[Backend] | Backends to reset (None = all) | None |
Examples
>>> # Clear all caches
>>> handler.reset_cache()
>>>
>>> # Clear only torch (e.g., after device change)
>>> handler.reset_cache(['torch'])
reset_stats
systems.base.utils.stochastic.DiffusionHandler.reset_stats()
Reset generation statistics.
Examples
>>> handler.reset_stats()
>>> stats = handler.get_stats()
>>> assert stats['generations'] == 0
warmup
systems.base.utils.stochastic.DiffusionHandler.warmup(backend, **kwargs)
Warm up function compilation (especially useful for JAX JIT).
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| backend | Backend | Backend to warm up | required |
| **kwargs | | Test inputs for warmup call | {} |
Examples
>>> # JAX JIT warmup
>>> handler.warmup('jax')
>>> # First call compiled, subsequent calls fast
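The idea behind a warmup call can be sketched without any JIT backend: invoke the function once with placeholder inputs so that one-time setup happens before performance-critical code runs. Everything below is hypothetical and for illustration only; `LazyDiffusion` merely imitates a function whose first call is special, and the standalone `warmup` helper is not the library's method.

```python
import time


def warmup(fn, n_state, n_control):
    """Call fn once with zero-valued placeholder inputs; return the elapsed seconds."""
    dummy_args = [0.0] * (n_state + n_control)
    start = time.perf_counter()
    fn(*dummy_args)
    return time.perf_counter() - start


class LazyDiffusion:
    """Toy function that does one-time setup on its first call, loosely imitating JIT."""

    def __init__(self):
        self.compiled = False

    def __call__(self, x1, x2, u):
        if not self.compiled:
            self.compiled = True      # a real JIT would pay its tracing cost here
        return [[0.1], [0.2]]


g = LazyDiffusion()
elapsed = warmup(g, n_state=2, n_control=1)   # setup happens here, not in the hot loop
```

With a real JIT backend, the placeholder inputs would generally need the same shapes and dtypes as the production inputs, since a change in either typically triggers recompilation.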