Integrators
Common
- class jaxsim.integrators.common.ExplicitRungeKutta(fsal_enabled_if_supported, index_of_fsal, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- classmethod build(*, dynamics, fsal_enabled_if_supported=True, **kwargs)
Build the integrator object.
- Parameters:
dynamics (SystemDynamics[State, StateDerivative]) – The system dynamics.
fsal_enabled_if_supported (bool | Array | ndarray | bool | number | int | float | complex) – Whether to enable the FSAL property, if supported.
**kwargs – Additional keyword arguments to build the integrator.
- Return type:
Self
- Returns:
The integrator object.
- static butcher_tableau_is_explicit(A)
Check whether the Butcher tableau corresponds to an explicit integration scheme.
- Parameters:
A (Array) – The Runge-Kutta matrix.
- Return type:
Array
- Returns:
True if the Butcher tableau is explicit, False otherwise.
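A minimal sketch of this check, assuming it reduces to testing that A is strictly lower triangular, so that each stage depends only on earlier stages. The function is a standalone illustration that mirrors the documented name, not jaxsim's implementation.

```python
import jax.numpy as jnp


def butcher_tableau_is_explicit(A):
    # Explicit RK: zero diagonal and zero upper triangle, i.e. A equals its
    # strictly lower-triangular part.
    A = jnp.asarray(A)
    return jnp.allclose(A, jnp.tril(A, k=-1))


# Classic RK4 matrix (explicit) and implicit midpoint (not explicit).
A_rk4 = [[0.0, 0.0, 0.0, 0.0], [0.5, 0.0, 0.0, 0.0], [0.0, 0.5, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
A_midpoint = [[0.5]]
```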
- static butcher_tableau_is_valid(A, b, c)
Check whether the Butcher tableau is valid.
- Parameters:
A (Array) – The Runge-Kutta matrix.
b (Array) – The weight coefficients.
c (Array) – The node coefficients.
- Return type:
Array
- Returns:
True if the Butcher tableau is valid, False otherwise.
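A sketch of common consistency conditions a tableau satisfies: matching shapes, the row-sum condition \(c_i = \sum_j a_{ij}\), and weights summing to one for each solution. Which conditions jaxsim actually checks is an assumption here; `b` is assumed to have shape `(s, m)`, one column per solution, consistent with the `b.T` convention used by `butcher_tableau_supports_fsal`.

```python
import jax.numpy as jnp


def butcher_tableau_is_valid(A, b, c):
    A, b, c = jnp.asarray(A), jnp.asarray(b), jnp.asarray(c)
    # Treat a 1-D b as a single solution column: shape (s, m).
    b = b.reshape(b.shape[0], -1)
    s = A.shape[0]
    square = A.ndim == 2 and A.shape[0] == A.shape[1]
    if not (square and b.shape[0] == s and c.shape == (s,)):
        return jnp.array(False)
    # Row-sum condition: each node equals the sum of its row of A.
    row_sum_ok = jnp.allclose(A.sum(axis=1), c)
    # Each solution's weights sum to one (consistency).
    weights_ok = jnp.allclose(b.sum(axis=0), 1.0)
    return row_sum_ok & weights_ok


A_rk4 = [[0.0, 0.0, 0.0, 0.0], [0.5, 0.0, 0.0, 0.0], [0.0, 0.5, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
b_rk4 = [1 / 6, 1 / 3, 1 / 3, 1 / 6]
c_rk4 = [0.0, 0.5, 0.5, 1.0]
# Heun-Euler embedded pair: columns are the Heun and Euler weights.
b_heun_euler = [[0.5, 1.0], [0.5, 0.0]]
```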
- static butcher_tableau_supports_fsal(A, b, c, index_of_solution=0)
Check whether the Butcher tableau supports the FSAL (first-same-as-last) property.
- Parameters:
A (Array) – The Runge-Kutta matrix.
b (Array) – The weight coefficients.
c (Array) – The node coefficients.
index_of_solution (int | Array | ndarray | bool | number | bool | float | complex) – The index of the row of b.T corresponding to the solution.
- Return type:
tuple[bool, int | None]
- Returns:
A tuple containing a boolean indicating whether the Butcher tableau supports FSAL, and the index i of the intermediate kᵢ derivative corresponding to the initial derivative f(x0, t0) of the next step.
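A sketch of how FSAL support can be detected, assuming the usual criterion: some stage i is evaluated at \(t_0 + \Delta t\) (\(c_i = 1\)) with its row of A equal to the solution weights, so kᵢ equals the derivative at the final state and can serve as the first derivative of the next step. Plain NumPy is used because the result is a static Python tuple; the function mirrors the documented name but is not jaxsim's code.

```python
import numpy as np


def butcher_tableau_supports_fsal(A, b, c, index_of_solution=0):
    A, b, c = np.asarray(A, float), np.asarray(b, float), np.asarray(c, float)
    # One column of b per solution (embedded pairs); a 1-D b is a single solution.
    b = b.reshape(b.shape[0], -1)
    weights = b[:, index_of_solution]
    for i in range(A.shape[0]):
        # Stage i is evaluated at t0 + dt on the final state: reuse it next step.
        if np.isclose(c[i], 1.0) and np.allclose(A[i], weights):
            return True, i
    return False, None


# Bogacki-Shampine 3(2): the last stage coincides with the third-order solution.
A_bs3 = [[0, 0, 0, 0], [1 / 2, 0, 0, 0], [0, 3 / 4, 0, 0], [2 / 9, 1 / 3, 4 / 9, 0]]
b_bs3 = [2 / 9, 1 / 3, 4 / 9, 0]
c_bs3 = [0, 1 / 2, 3 / 4, 1]

# Classic RK4: no stage matches the solution weights.
A_rk4 = [[0, 0, 0, 0], [1 / 2, 0, 0, 0], [0, 1 / 2, 0, 0], [0, 0, 1, 0]]
b_rk4 = [1 / 6, 1 / 3, 1 / 3, 1 / 6]
c_rk4 = [0, 1 / 2, 1 / 2, 1]

bs3_result = butcher_tableau_supports_fsal(A_bs3, b_bs3, c_bs3)
rk4_result = butcher_tableau_supports_fsal(A_rk4, b_rk4, c_rk4)
```

Here `bs3_result` is `(True, 3)` and `rk4_result` is `(False, None)`.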
- classmethod integrate_rk_stage(x0, t0, dt, k)
Integrate a single stage of the Runge-Kutta method.
- Parameters:
x0 (State) – The initial state of the system.
t0 (float | Array | ndarray | bool | number | bool | int | complex) – The initial time of the system.
dt (float | Array | ndarray | bool | number | bool | int | complex) – The time step of the RK integration scheme. Note that this is not the stage time step, as it depends on the A matrix used to compute the k argument.
k (StateDerivative) – The RK state derivative of the current stage, weighted with the A matrix.
- Return type:
State
- Returns:
The state at the next stage of the integration.
Note
In the most generic case, k can be an arbitrary linear combination of the kᵢ derivatives, weighted by the entries of the RK matrix A.
Note
Overriding this method allows users to use different classes defining State and StateDerivative. Be aware that the timestep dt is not the stage timestep, therefore the map used to convert the state derivative must be time-independent.
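With a purely additive stage map, a stage reduces to \(x_0 + \Delta t\,k\) applied leaf by leaf. The sketch below assumes State and StateDerivative are pytrees with identical structure; an override (for instance, one for SO(3) states) would replace this additive update. `t0` is unused, in line with the requirement that the map be time-independent.

```python
import jax
import jax.numpy as jnp


def integrate_rk_stage(x0, t0, dt, k):
    # Additive stage map, leaf by leaf: x = x0 + dt * k.
    # t0 is deliberately unused (time-independent map).
    return jax.tree_util.tree_map(lambda x, xdot: x + dt * xdot, x0, k)


x0 = {"q": jnp.array([1.0, 2.0])}
k = {"q": jnp.array([0.5, -1.0])}  # e.g. a_21 * k_1 for the second stage
x_stage = integrate_rk_stage(x0, 0.0, 0.1, k)
```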
- classmethod post_process_state(x0, t0, xf, dt)
Post-process the integrated state at \(t_f = t_0 + \Delta t\).
- Parameters:
x0 (State) – The initial state of the system.
t0 (float | Array | ndarray | bool | number | bool | int | complex) – The initial time of the system.
xf (State) – The final state of the system obtained through the integration.
dt (float | Array | ndarray | bool | number | bool | int | complex) – The time step used for the integration.
- Return type:
State
- Returns:
The post-processed integrated state.
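For illustration, a hypothetical override might renormalize a unit quaternion after the additive RK update. The dict-based state and the key `"quaternion"` are assumptions made for this sketch, not jaxsim's state layout.

```python
import jax.numpy as jnp


def post_process_state(x0, t0, xf, dt):
    # Hypothetical override: project the (assumed) scalar-first quaternion
    # stored under "quaternion" back onto the unit sphere; other leaves are
    # passed through unchanged. x0, t0 and dt are unused here.
    q = xf["quaternion"]
    return {**xf, "quaternion": q / jnp.linalg.norm(q)}


xf = {"quaternion": jnp.array([2.0, 0.0, 0.0, 0.0]), "v": jnp.array([1.0])}
xf_post = post_process_state(None, 0.0, xf, 0.1)
```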
- class jaxsim.integrators.common.ExplicitRungeKuttaSO3Mixin
Mixin class to apply on top of explicit RK integrators whose PyTreeType is ODEState, so that the quaternion is integrated on SO(3).
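Integrating on SO(3) usually means updating the quaternion through the exponential map instead of adding a scaled derivative. A sketch, assuming scalar-first quaternions and a constant angular velocity `omega` expressed in the inertial frame over the step (hence the left multiplication); `quat_mul` and `integrate_quaternion_so3` are hypothetical helpers, and jaxsim's actual frame convention and implementation may differ.

```python
import jax.numpy as jnp


def quat_mul(p, q):
    # Hamilton product of scalar-first quaternions [w, x, y, z].
    p, q = jnp.asarray(p), jnp.asarray(q)
    pw, pv = p[0], p[1:]
    qw, qv = q[0], q[1:]
    w = pw * qw - jnp.dot(pv, qv)
    v = pw * qv + qw * pv + jnp.cross(pv, qv)
    return jnp.concatenate([w[None], v])


def integrate_quaternion_so3(q, omega, dt):
    # Rotation vector accumulated over the step.
    theta = jnp.asarray(omega) * dt
    angle = jnp.linalg.norm(theta)
    # sin(angle / 2) / angle tends to 1/2 as angle -> 0; avoid dividing by zero.
    safe = jnp.where(angle > 1e-12, angle, 1.0)
    scale = jnp.where(angle > 1e-12, jnp.sin(0.5 * angle) / safe, 0.5)
    dq = jnp.concatenate([jnp.cos(0.5 * angle)[None], scale * theta])
    # Left-multiply: omega is assumed to be expressed in the inertial frame.
    return quat_mul(dq, q)


q0 = jnp.array([1.0, 0.0, 0.0, 0.0])
q_quarter_turn = integrate_quaternion_so3(q0, jnp.array([0.0, 0.0, jnp.pi / 2]), 1.0)
q_still = integrate_quaternion_so3(q0, jnp.zeros(3), 1.0)
```

Because the increment is itself a unit quaternion, the result stays on the unit sphere up to round-off, without a renormalization step.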
- class jaxsim.integrators.common.Integrator(*, dynamics, params=<factory>)
- Parameters:
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- classmethod build(*, dynamics, **kwargs)
Build the integrator object.
- Parameters:
dynamics (SystemDynamics[State, StateDerivative]) – The system dynamics.
**kwargs – Additional keyword arguments to build the integrator.
- Return type:
Self
- Returns:
The integrator object.
- init(x0, t0, dt, *, include_dynamics_aux_dict=False, **kwargs)
Initialize the integrator.
- Parameters:
x0 (State) – The initial state of the system.
t0 (float | Array | ndarray | bool | number | bool | int | complex) – The initial time of the system.
dt (float | Array | ndarray | bool | number | bool | int | complex) – The time step of the integration.
include_dynamics_aux_dict (bool)
- Return type:
dict[str, Any]
- Returns:
The auxiliary dictionary of the integrator.
Note
This method should have the same signature as the inherited __call__ method, including additional kwargs.
Note
If the integrator supports FSAL, the pair (x0, t0) must match the real initial state and time of the system, otherwise the initial derivative of the first step will be wrong.
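The FSAL note above can be made concrete with the Bogacki–Shampine third-order scheme, whose last stage \(k_4 = f(x_1, t_0 + \Delta t)\) is exactly the first stage of the next step. In this sketch, evaluating `f` once at the true `(x0, t0)` plays the role of `init`; if that first derivative were computed at a wrong pair, every reused stage would inherit the error. The toy dynamics and the loop are assumptions for illustration, not jaxsim code.

```python
calls = {"n": 0}


def f(x, t):
    # Toy dynamics xdot = -x; counts evaluations to expose the FSAL saving.
    calls["n"] += 1
    return -x


def bs3_step(x, t, dt, k1):
    k2 = f(x + 0.5 * dt * k1, t + 0.5 * dt)
    k3 = f(x + 0.75 * dt * k2, t + 0.75 * dt)
    x_next = x + dt * (2 / 9 * k1 + 1 / 3 * k2 + 4 / 9 * k3)
    # Last stage = derivative at (x_next, t + dt): reused as the next k1.
    k4 = f(x_next, t + dt)
    return x_next, k4


x, t, dt = 1.0, 0.0, 0.1
k1 = f(x, t)  # must be evaluated at the true initial state and time
for _ in range(10):
    x, k1 = bs3_step(x, t, dt, k1)
    t += dt
```

Ten steps cost 31 evaluations of `f` instead of 40.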
- step(x0, t0, dt, *, params, **kwargs)
Perform a single integration step.
- Parameters:
x0 (State) – The initial state of the system.
t0 (float | Array | ndarray | bool | number | bool | int | complex) – The initial time of the system.
dt (float | Array | ndarray | bool | number | bool | int | complex) – The time step of the integration.
params (dict[str, Any]) – The auxiliary dictionary of the integrator.
**kwargs – Additional keyword arguments.
- Return type:
tuple[State, dict[str, Any]]
- Returns:
The final state of the system and the updated auxiliary dictionary.
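A minimal sketch of the `build`/`init`/`step` contract: `init` returns the auxiliary dictionary, and each `step` returns the new state together with the updated dictionary, which must be threaded into the next call. `ToyForwardEuler` is a hypothetical stand-in written for this example, not jaxsim's class; `jax.lax.scan` carries the `(state, params)` pair.

```python
import jax
import jax.numpy as jnp


class ToyForwardEuler:
    # Hypothetical class mirroring the documented contract only.
    def __init__(self, dynamics):
        self.dynamics = dynamics

    @classmethod
    def build(cls, *, dynamics):
        return cls(dynamics)

    def init(self, x0, t0, dt):
        # Auxiliary dictionary threaded through subsequent steps.
        return {"num_steps": jnp.array(0, dtype=jnp.int32)}

    def step(self, x0, t0, dt, *, params):
        xdot = self.dynamics(x0, t0)
        xf = jax.tree_util.tree_map(lambda x, v: x + dt * v, x0, xdot)
        return xf, {"num_steps": params["num_steps"] + 1}


integrator = ToyForwardEuler.build(dynamics=lambda x, t: -x)
x0, t0, dt = jnp.array(1.0, dtype=jnp.float32), 0.0, 0.1
params = integrator.init(x0, t0, dt)


def body(carry, i):
    x, params = carry
    x, params = integrator.step(x, t0 + i * dt, dt, params=params)
    return (x, params), x


(xf, params), trajectory = jax.lax.scan(body, (x0, params), jnp.arange(10))
```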
Fixed Step
- class jaxsim.integrators.fixed_step.ForwardEuler(fsal_enabled_if_supported, index_of_fsal, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- class jaxsim.integrators.fixed_step.ForwardEulerSO3(fsal_enabled_if_supported, index_of_fsal, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- class jaxsim.integrators.fixed_step.Heun2(fsal_enabled_if_supported, index_of_fsal, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- class jaxsim.integrators.fixed_step.Heun2SO3(fsal_enabled_if_supported, index_of_fsal, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- class jaxsim.integrators.fixed_step.RungeKutta4(fsal_enabled_if_supported, index_of_fsal, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- class jaxsim.integrators.fixed_step.RungeKutta4SO3(fsal_enabled_if_supported, index_of_fsal, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
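The fixed-step classes above differ mainly in their Butcher tableaux. The sketch below pairs the textbook tableaux for forward Euler, Heun's second-order method, and classic fourth-order Runge-Kutta with a generic explicit step; that these match jaxsim's arrays exactly is an assumption, and the SO3 variants (which treat the quaternion differently) are not modeled.

```python
TABLEAUX = {
    # name: (A, b, c)
    "ForwardEuler": ([[0.0]], [1.0], [0.0]),
    "Heun2": ([[0.0, 0.0], [1.0, 0.0]], [0.5, 0.5], [0.0, 1.0]),
    "RungeKutta4": (
        [[0.0, 0.0, 0.0, 0.0], [0.5, 0.0, 0.0, 0.0], [0.0, 0.5, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
        [1 / 6, 1 / 3, 1 / 3, 1 / 6],
        [0.0, 0.5, 0.5, 1.0],
    ),
}


def explicit_rk_step(f, x, t, dt, A, b, c):
    k = []
    for i in range(len(b)):
        # Explicit scheme: stage i only uses the already computed k_0..k_{i-1}.
        xi = x + dt * sum(A[i][j] * k[j] for j in range(i))
        k.append(f(xi, t + c[i] * dt))
    return x + dt * sum(b[i] * k[i] for i in range(len(b)))


def integrate(name, x0=1.0, dt=0.1, num_steps=10):
    # Integrate xdot = -x from t = 0 to t = num_steps * dt.
    A, b, c = TABLEAUX[name]
    x, t = x0, 0.0
    for _ in range(num_steps):
        x = explicit_rk_step(lambda x, t: -x, x, t, dt, A, b, c)
        t += dt
    return x
```

On this linear test problem each step multiplies the state by the method's stability polynomial, for example \(1 - \Delta t + \Delta t^2/2 = 0.905\) per step for Heun2.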
Variable Step
- class jaxsim.integrators.variable_step.BogackiShampineSO3(fsal_enabled_if_supported, index_of_fsal, dt_max=inf, dt_min=-inf, rtol=0.0001, atol=1e-05, safety=0.9, beta_max=2.5, beta_min=0.1, max_step_rejections=5, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dt_max (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
dt_min (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
rtol (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
atol (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
safety (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
beta_max (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
beta_min (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
max_step_rejections (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- class jaxsim.integrators.variable_step.EmbeddedRungeKutta(fsal_enabled_if_supported, index_of_fsal, dt_max=inf, dt_min=-inf, rtol=0.0001, atol=1e-05, safety=0.9, beta_max=2.5, beta_min=0.1, max_step_rejections=5, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dt_max (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
dt_min (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
rtol (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
atol (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
safety (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
beta_max (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
beta_min (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
max_step_rejections (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
- classmethod build(*, dynamics, fsal_enabled_if_supported=True, dt_max=inf, dt_min=-inf, rtol=0.0001, atol=1e-05, safety=0.9, beta_max=2.5, beta_min=0.1, max_step_rejections=5, **kwargs)
Build the integrator object.
- Parameters:
dynamics (SystemDynamics[State, StateDerivative]) – The system dynamics.
fsal_enabled_if_supported (bool | Array | ndarray | bool | number | int | float | complex) – Whether to enable the FSAL property, if supported.
dt_max (float | Array | ndarray | bool | number | bool | int | complex)
dt_min (float | Array | ndarray | bool | number | bool | int | complex)
rtol (float | Array | ndarray | bool | number | bool | int | complex)
atol (float | Array | ndarray | bool | number | bool | int | complex)
safety (float | Array | ndarray | bool | number | bool | int | complex)
beta_max (float | Array | ndarray | bool | number | bool | int | complex)
beta_min (float | Array | ndarray | bool | number | bool | int | complex)
max_step_rejections (int | Array | ndarray | bool | number | bool | float | complex)
**kwargs – Additional keyword arguments to build the integrator.
- Return type:
Self
- Returns:
The integrator object.
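The `safety`, `beta_min`, `beta_max`, `dt_min`, and `dt_max` parameters are typical of a standard step-size controller. A sketch of such a controller, assuming the scaled error is accepted when it does not exceed one and the exponent uses the order of the lower-order solution; `propose_step` is a hypothetical helper, and neither jaxsim's exact formula nor its handling of `max_step_rejections` is reproduced here.

```python
import math


def propose_step(dt, err, order, *, safety=0.9, beta_min=0.1, beta_max=2.5,
                 dt_min=-math.inf, dt_max=math.inf):
    # err: scaled local error (accept if err <= 1).
    # order: order of the lower-order solution of the embedded pair (assumption).
    if err == 0.0:
        factor = beta_max
    else:
        factor = safety * err ** (-1.0 / (order + 1))
    # Limit how fast the step may shrink or grow in one go.
    factor = min(max(factor, beta_min), beta_max)
    # Keep the proposed step inside [dt_min, dt_max].
    dt_new = min(max(dt * factor, dt_min), dt_max)
    return err <= 1.0, dt_new
```

For example, `propose_step(0.1, 1.0, 2)` accepts the step and proposes `0.09`, while a large error such as `1000.0` rejects it and shrinks the step by at most `beta_min`.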
- init(x0, t0, dt=None, *, include_dynamics_aux_dict=False, **kwargs)
Initialize the integrator.
- Parameters:
x0 (State) – The initial state of the system.
t0 (float | Array | ndarray | bool | number | bool | int | complex) – The initial time of the system.
dt (int | Array | ndarray | bool | number | bool | float | complex | None) – The time step of the integration.
include_dynamics_aux_dict (bool)
- Return type:
dict[str, Any]
- Returns:
The auxiliary dictionary of the integrator.
Note
This method should have the same signature as the inherited __call__ method, including additional kwargs.
Note
If the integrator supports FSAL, the pair (x0, t0) must match the real initial state and time of the system, otherwise the initial derivative of the first step will be wrong.
- class jaxsim.integrators.variable_step.HeunEulerSO3(fsal_enabled_if_supported, index_of_fsal, dt_max=inf, dt_min=-inf, rtol=0.0001, atol=1e-05, safety=0.9, beta_max=2.5, beta_min=0.1, max_step_rejections=5, *, dynamics, params=<factory>)
- Parameters:
fsal_enabled_if_supported (Annotated[bool, '__jax_dataclasses_static_field__'])
index_of_fsal (Annotated[int | Array | ndarray | bool | number | bool | float | complex | None, '__jax_dataclasses_static_field__'])
dt_max (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
dt_min (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
rtol (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
atol (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
safety (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
beta_max (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
beta_min (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
max_step_rejections (Annotated[float | Array | ndarray | bool | number | bool | int | complex, '__jax_dataclasses_static_field__'])
dynamics (Annotated[SystemDynamics[State, StateDerivative], '__jax_dataclasses_static_field__'])
params (dict[str, Any])
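The Heun–Euler pair shows how an embedded scheme obtains an error estimate almost for free: both solutions reuse the same two stages. In this plain-Python sketch (not jaxsim's code), which of the two solutions gets propagated is an assumption; their difference is what a local error estimator consumes.

```python
def heun_euler_step(f, x, t, dt):
    # Two stages shared by both solutions of the embedded pair.
    k1 = f(x, t)
    k2 = f(x + dt * k1, t + dt)
    x_heun = x + dt * (0.5 * k1 + 0.5 * k2)  # second order (assumed propagated)
    x_euler = x + dt * k1                    # first order, error estimate only
    return x_heun, x_euler


x_heun, x_euler = heun_euler_step(lambda x, t: -x, 1.0, 0.0, 0.1)
local_error = abs(x_heun - x_euler)  # about dt**2 / 2 for xdot = -x, x = 1
```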
- jaxsim.integrators.variable_step.compute_pytree_scale(x1, x2=None, rtol=0.0001, atol=1e-05)
Compute the component-wise scale factors used to scale dynamical states.
- Parameters:
x1 (dict[Hashable, PyTree] | list[PyTree] | tuple[PyTree] | None | Array | Any) – The first state (often the initial state).
x2 (dict[Hashable, PyTree] | list[PyTree] | tuple[PyTree] | None | Array | Any) – The optional second state (often the final state).
rtol (float | Array | ndarray | bool | number | bool | int | complex) – The relative tolerance used to scale the state.
atol (float | Array | ndarray | bool | number | bool | int | complex) – The absolute tolerance used to scale the state.
- Return type:
dict[Hashable, PyTree] | list[PyTree] | tuple[PyTree] | None | Array | Any
- Returns:
A pytree with the same structure as the state, containing the scaling factors.
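A sketch of the usual component-wise scale \(sc_i = \mathrm{atol} + \mathrm{rtol}\,\max(|x_{1,i}|, |x_{2,i}|)\), applied leaf by leaf. Treating a missing `x2` as zeros is an assumption made for this sketch, which mirrors the documented name but is not jaxsim's implementation.

```python
import jax
import jax.numpy as jnp


def compute_pytree_scale(x1, x2=None, rtol=1e-4, atol=1e-5):
    # Assumption: a missing second state contributes nothing to the maximum.
    if x2 is None:
        x2 = jax.tree_util.tree_map(jnp.zeros_like, x1)
    # sc = atol + rtol * max(|x1|, |x2|), computed leaf by leaf.
    return jax.tree_util.tree_map(
        lambda a, b: atol + rtol * jnp.maximum(jnp.abs(a), jnp.abs(b)), x1, x2
    )


sc = compute_pytree_scale({"a": jnp.array([1.0, -2.0])}, {"a": jnp.array([3.0, 0.0])})
```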
- jaxsim.integrators.variable_step.estimate_step_size(x0, t0, f, order, rtol=0.0001, atol=1e-05)
Compute the initial step size to warm-start variable-step integrators.
- Parameters:
x0 (dict[Hashable, PyTree] | list[PyTree] | tuple[PyTree] | None | Array | Any) – The initial state.
t0 (float | Array | ndarray | bool | number | bool | int | complex) – The initial time.
f (SystemDynamics) – The state derivative function \(f(x, t)\).
order (int | Array | ndarray | bool | number | bool | float | complex) – The order \(p\) of an integrator with truncation error \(\mathcal{O}(\Delta t^{p+1})\).
rtol (float | Array | ndarray | bool | number | bool | int | complex) – The relative tolerance used to scale the state.
atol (float | Array | ndarray | bool | number | bool | int | complex) – The absolute tolerance used to scale the state.
- Return type:
tuple[Array, dict[Hashable, PyTree] | list[PyTree] | tuple[PyTree] | None | Array | Any]
- Returns:
A tuple containing the computed initial step size and the state derivative \(\dot{x} = f(x_0, t_0)\).
Note
Interested readers can find implementation details in:
E. Hairer, S. P. Nørsett, G. Wanner. Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
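The starting-step heuristic from that section can be sketched as follows, for a single NumPy array state instead of a pytree and using the RMS norm; the lettered comments follow the steps of the book's algorithm. This is an illustration under those assumptions, not jaxsim's implementation.

```python
import numpy as np


def estimate_step_size(x0, t0, f, order, rtol=1e-4, atol=1e-5):
    x0 = np.asarray(x0, dtype=float)

    def rms(v):
        return np.sqrt(np.mean(np.square(v)))

    sc = atol + rtol * np.abs(x0)
    # (a) magnitudes of the state and of its derivative, scaled
    f0 = f(x0, t0)
    d0, d1 = rms(x0 / sc), rms(f0 / sc)
    # (b) first guess
    h0 = 1e-6 if (d0 < 1e-5 or d1 < 1e-5) else 0.01 * d0 / d1
    # (c) one explicit Euler step
    f1 = f(x0 + h0 * f0, t0 + h0)
    # (d) estimate of the second derivative
    d2 = rms((f1 - f0) / sc) / h0
    # (e) step giving a local error of about 0.01 for an order-p method
    dmax = max(d1, d2)
    h1 = max(1e-6, h0 * 1e-3) if dmax <= 1e-15 else (0.01 / dmax) ** (1.0 / (order + 1))
    # (f) do not exceed 100 times the first guess
    return min(100 * h0, h1), f0


h, f0 = estimate_step_size([1.0], 0.0, lambda x, t: -x, order=3)
```

Returning `f0` alongside the step size matches the documented return value, so the derivative at \((x_0, t_0)\) can be reused instead of being recomputed.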
- jaxsim.integrators.variable_step.local_error_estimation(xf, xf_estimate=None, x0=None, rtol=0.0001, atol=1e-05, norm_ord=inf)
Estimate the local integration error, often used in embedded RK schemes.
- Parameters:
xf (dict[Hashable, PyTree] | list[PyTree] | tuple[PyTree] | None | Array | Any) – The final state, often computed with the more accurate integrator.
xf_estimate (dict[Hashable, PyTree] | list[PyTree] | tuple[PyTree] | None | Array | Any) – The estimated final state, often computed with the less accurate integrator. If missing, it is initialized to zero.
x0 (dict[Hashable, PyTree] | list[PyTree] | tuple[PyTree] | None | Array | Any) – The initial state used to compute the scaling factors. If missing, it is initialized to zero.
rtol (float | Array | ndarray | bool | number | bool | int | complex) – The relative tolerance used to scale the state.
atol (float | Array | ndarray | bool | number | bool | int | complex) – The absolute tolerance used to scale the state.
norm_ord (int | Array | ndarray | bool | number | bool | float | complex) – The order of the norm used to compute the error. Default is the infinity norm.
- Return type:
Array
- Returns:
The estimated local integration error.
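A sketch of an embedded-RK error estimate: scale the difference between the two final states with the component-wise scale of `compute_pytree_scale`, then take the norm selected by `norm_ord` over all leaves. Missing `xf_estimate` and `x0` default to zeros as documented; deriving the scale from both `x0` and `xf` is an assumption of this sketch, which is not jaxsim's implementation.

```python
import jax
import jax.numpy as jnp


def local_error_estimation(xf, xf_estimate=None, x0=None, rtol=1e-4, atol=1e-5,
                           norm_ord=jnp.inf):
    zeros = jax.tree_util.tree_map(jnp.zeros_like, xf)
    xf_estimate = zeros if xf_estimate is None else xf_estimate
    x0 = zeros if x0 is None else x0
    # Component-wise scale from the larger magnitude of x0 and xf (assumption).
    sc = jax.tree_util.tree_map(
        lambda a, b: atol + rtol * jnp.maximum(jnp.abs(a), jnp.abs(b)), x0, xf
    )
    # Scaled difference, flattened across all leaves into a single vector.
    e = jax.tree_util.tree_map(lambda a, b, s: (a - b) / s, xf, xf_estimate, sc)
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in jax.tree_util.tree_leaves(e)])
    return jnp.linalg.norm(flat, ord=norm_ord)


xf = {"a": jnp.array([1.0, 2.0])}
err = local_error_estimation(xf, {"a": jnp.array([1.0, 2.0005])}, x0=xf)
```

A scaled error above one would typically cause an adaptive integrator to reject the step.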