Source code for jaxsim.integrators.common

import abc
import dataclasses
from typing import Any, ClassVar, Generic, Protocol, TypeVar

import jax
import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static

import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability

try:
    from typing import override
except ImportError:
    from typing_extensions import override

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self


# =============
# Generic types
# =============

Time = jtp.FloatLike
TimeStep = jtp.FloatLike
State = NextState = TypeVar("State")
StateDerivative = TypeVar("StateDerivative")
PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)


[docs] class SystemDynamics(Protocol[State, StateDerivative]): def __call__( self, x: State, t: Time, **kwargs ) -> tuple[StateDerivative, dict[str, Any]]: ...
# ======================= # Base integrator classes # =======================
[docs] @jax_dataclasses.pytree_dataclass class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]): AfterInitKey: ClassVar[str] = "after_init" InitializingKey: ClassVar[str] = "initializing" AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics" dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field( repr=False, hash=False, compare=False, kw_only=True ) params: dict[str, Any] = dataclasses.field( default_factory=dict, repr=False, hash=False, compare=False, kw_only=True )
[docs] @classmethod def build( cls: type[Self], *, dynamics: SystemDynamics[State, StateDerivative], **kwargs, ) -> Self: """ Build the integrator object. Args: dynamics: The system dynamics. **kwargs: Additional keyword arguments to build the integrator. Returns: The integrator object. """ return cls(dynamics=dynamics, **kwargs)
[docs] def step( self, x0: State, t0: Time, dt: TimeStep, *, params: dict[str, Any], **kwargs, ) -> tuple[State, dict[str, Any]]: """ Perform a single integration step. Args: x0: The initial state of the system. t0: The initial time of the system. dt: The time step of the integration. params: The auxiliary dictionary of the integrator. **kwargs: Additional keyword arguments. Returns: The final state of the system and the updated auxiliary dictionary. """ with self.editable(validate=False) as integrator: integrator.params = params with integrator.mutable_context(mutability=Mutability.MUTABLE): xf, aux_dict = integrator(x0, t0, dt, **kwargs) return ( xf, integrator.params | {Integrator.AfterInitKey: jnp.array(False).astype(bool)} | aux_dict, )
@abc.abstractmethod def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState: pass
[docs] def init( self, x0: State, t0: Time, dt: TimeStep, *, include_dynamics_aux_dict: bool = False, **kwargs, ) -> dict[str, Any]: """ Initialize the integrator. Args: x0: The initial state of the system. t0: The initial time of the system. dt: The time step of the integration. Returns: The auxiliary dictionary of the integrator. Note: This method should have the same signature as the inherited `__call__` method, including additional kwargs. Note: If the integrator supports FSAL, the pair `(x0, t0)` must match the real initial state and time of the system, otherwise the initial derivative of the first step will be wrong. """ with self.editable(validate=False) as integrator: # Initialize the integrator parameters. # For initialization purpose, the integrators can check if the # `Integrator.InitializingKey` is present in their parameters. # The AfterInitKey is used in the first step after initialization. integrator.params = { Integrator.InitializingKey: jnp.array(True), Integrator.AfterInitKey: jnp.array(False), } # Run a dummy call of the integrator. # It is used only to get the params so that we know the structure # of the corresponding pytree. _ = integrator(x0, t0, dt, **kwargs) # Remove the injected key. _ = integrator.params.pop(Integrator.InitializingKey) # Make sure that all leafs of the dictionary are JAX arrays. # Also, since these are dummy parameters, set them all to zero. params_after_init = jax.tree.map(lambda l: jnp.zeros_like(l), integrator.params) # Mark the next step as first step after initialization. params_after_init = params_after_init | { Integrator.AfterInitKey: jnp.array(True) } # Store the zero parameters in the integrator. # When the integrator is stepped, this is used to check if the passed # parameters are valid. with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): self.params = params_after_init return params_after_init
[docs] @jax_dataclasses.pytree_dataclass class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]): # The Runge-Kutta matrix. A: ClassVar[jtp.Matrix] # The weights coefficients. # Note that in practice we typically use its transpose `b.transpose()`. b: ClassVar[jtp.Matrix] # The nodes coefficients. c: ClassVar[jtp.Vector] # Define the order of the solution. # It should have as many elements as the number of rows of `b.transpose()`. order_of_bT_rows: ClassVar[tuple[int, ...]] # Define the row of the integration output corresponding to the final solution. # This is the row of b.T that produces the final state. row_index_of_solution: ClassVar[int] # Attributes of FSAL (first-same-as-last) property. fsal_enabled_if_supported: Static[bool] = dataclasses.field(repr=False) index_of_fsal: Static[jtp.IntLike | None] = dataclasses.field(repr=False) @property def has_fsal(self) -> bool: return self.fsal_enabled_if_supported and self.index_of_fsal is not None @property def order(self) -> int: return self.order_of_bT_rows[self.row_index_of_solution]
[docs] @override @classmethod def build( cls: type[Self], *, dynamics: SystemDynamics[State, StateDerivative], fsal_enabled_if_supported: jtp.BoolLike = True, **kwargs, ) -> Self: """ Build the integrator object. Args: dynamics: The system dynamics. fsal_enabled_if_supported: Whether to enable the FSAL property, if supported. **kwargs: Additional keyword arguments to build the integrator. Returns: The integrator object. """ # Check validity of the Butcher tableau. if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c): raise ValueError("The Butcher tableau of this class is not valid.") # Check that b.T has enough rows based on the configured index of the solution. if cls.row_index_of_solution >= cls.b.T.shape[0]: msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})." raise ValueError(msg.format(cls.row_index_of_solution, cls.b.T.shape[0])) # Check that the tuple containing the order of the b.T rows matches the number # of the b.T rows. if len(cls.order_of_bT_rows) != cls.b.T.shape[0]: msg = "Wrong size of 'order_of_bT_rows' ({}), should be {}." raise ValueError(msg.format(len(cls.order_of_bT_rows), cls.b.T.shape[0])) # Check if the Butcher tableau supports FSAL (first-same-as-last). # If it does, store the index of the intermediate derivative to be used as the # first derivative of the next iteration. has_fsal, index_of_fsal = ( # noqa: F841 ExplicitRungeKutta.butcher_tableau_supports_fsal( A=cls.A, b=cls.b, c=cls.c, index_of_solution=cls.row_index_of_solution ) ) # Build the integrator object. integrator = super().build( dynamics=dynamics, index_of_fsal=index_of_fsal, fsal_enabled_if_supported=bool(fsal_enabled_if_supported), **kwargs, ) return integrator
def __call__( self, x0: State, t0: Time, dt: TimeStep, **kwargs ) -> tuple[NextState, dict[str, Any]]: # Here z is a batched state with as many batch elements as b.T rows. # Note that z has multiple batches only if b.T has more than one row, # e.g. in Butcher tableau of embedded schemes. z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs) # The next state is the batch element located at the configured index of solution. next_state = jax.tree.map(lambda l: l[self.row_index_of_solution], z) return next_state, aux_dict
[docs] @classmethod def integrate_rk_stage( cls, x0: State, t0: Time, dt: TimeStep, k: StateDerivative ) -> NextState: """ Integrate a single stage of the Runge-Kutta method. Args: x0: The initial state of the system. t0: The initial time of the system. dt: The time step of the RK integration scheme. Note that this is not the stage timestep, as it depends on the `A` matrix used to compute the `k` argument. k: The RK state derivative of the current stage, weighted with the `A` matrix. Returns: The state at the next stage of the integration. Note: In the most generic case, `k` could be an arbitrary composition of the kᵢ derivatives, depending on the RK matrix A. Note: Overriding this method allows users to use different classes defining `State` and `StateDerivative`. Be aware that the timestep `dt` is not the stage timestep, therefore the map used to convert the state derivative must be time-independent. """ op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf return jax.tree.map(op, x0, k)
[docs] @classmethod def post_process_state( cls, x0: State, t0: Time, xf: NextState, dt: TimeStep ) -> NextState: r""" Post-process the integrated state at :math:`t_f = t_0 + \Delta t`. Args: x0: The initial state of the system. t0: The initial time of the system. xf: The final state of the system obtain through the integration. dt: The time step used for the integration. Returns: The post-processed integrated state. """ return xf
def _compute_next_state( self, x0: State, t0: Time, dt: TimeStep, **kwargs ) -> tuple[NextState, dict[str, Any]]: """ Compute the next state of the system, returning all the output states. Args: x0: The initial state of the system. t0: The initial time of the system. dt: The time step of the integration. **kwargs: Additional keyword arguments. Returns: A batched state with as many batch elements as `b.T` rows. """ # Call variables with better symbols. Δt = dt c = self.c b = self.b A = self.A # Close f over optional kwargs. f = lambda x, t: self.dynamics(x=x, t=t, **kwargs) # Initialize the carry of the for loop with the stacked kᵢ vectors. carry0 = jax.tree.map( lambda l: jnp.repeat(jnp.zeros_like(l)[jnp.newaxis, ...], c.size, axis=0), x0, ) # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration. get_ẋ0_and_aux_dict = lambda: self.params.get("dxdt0", f(x0, t0)) # We use a `jax.lax.scan` to compile the `f` function only once. # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code # would include 4 repetitions of the `f` logic, making everything extremely slow. def scan_body( carry: jax.Array, i: int | jax.Array ) -> tuple[jax.Array, dict[str, Any]]: """""" # Unpack the carry, i.e. the stacked kᵢ vectors. K = carry # Define the computation of the Runge-Kutta stage. def compute_ki() -> tuple[jax.Array, dict[str, Any]]: # Compute ∑ⱼ aᵢⱼ kⱼ. op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k) sum_ak = jax.tree.map(op_sum_ak, K) # Compute the next state for the kᵢ evaluation. # Note that this is not a Δt integration since aᵢⱼ could be fractional. xi = self.integrate_rk_stage(x0, t0, Δt, sum_ak) # Compute the next time for the kᵢ evaluation. ti = t0 + c[i] * Δt # This is kᵢ, aux_dict = f(xᵢ, tᵢ). return f(xi, ti) # This selector enables FSAL property in the first iteration (i=0). ki, aux_dict = jax.lax.cond( pred=jnp.logical_and(i == 0, self.has_fsal), true_fun=get_ẋ0_and_aux_dict, false_fun=compute_ki, ) # Store the kᵢ derivative in K. op = lambda l_k, l_ki: l_k.at[i].set(l_ki) K = jax.tree.map(op, K, ki) carry = K return carry, aux_dict # Compute the state derivatives kᵢ. K, aux_dict = jax.lax.scan( f=scan_body, init=carry0, xs=jnp.arange(c.size), ) # Update the FSAL property for the next iteration. if self.has_fsal: self.params["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K) # Compute the output state. # Note that z contains as many new states as the rows of `b.T`. op = lambda x0, k: x0 + Δt * jnp.einsum("zs,s...->z...", b.T, k) z = jax.tree.map(op, x0, K) # Transform the final state of the integration. # This allows to inject custom logic, if needed. z_transformed = jax.vmap( lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt) )(z) return z_transformed, aux_dict
[docs] @staticmethod def butcher_tableau_is_valid( A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector ) -> jtp.Bool: """ Check if the Butcher tableau is valid. Args: A: The Runge-Kutta matrix. b: The weights coefficients. c: The nodes coefficients. Returns: `True` if the Butcher tableau is valid, `False` otherwise. """ valid = True valid = valid and A.ndim == 2 valid = valid and b.ndim == 2 valid = valid and c.ndim == 1 valid = valid and b.T.shape[0] <= 2 valid = valid and A.shape[0] == A.shape[1] valid = valid and A.shape == (c.size, b.T.shape[1]) valid = valid and bool(jnp.all(b.T.sum(axis=1) == 1)) return valid
[docs] @staticmethod def butcher_tableau_is_explicit(A: jtp.Matrix) -> jtp.Bool: """ Check if the Butcher tableau corresponds to an explicit integration scheme. Args: A: The Runge-Kutta matrix. Returns: `True` if the Butcher tableau is explicit, `False` otherwise. """ return jnp.allclose(A, jnp.tril(A, k=-1))
[docs] @staticmethod def butcher_tableau_supports_fsal( A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector, index_of_solution: jtp.IntLike = 0, ) -> tuple[bool, int | None]: """ Check if the Butcher tableau supports the FSAL (first-same-as-last) property. Args: A: The Runge-Kutta matrix. b: The weights coefficients. c: The nodes coefficients. index_of_solution: The index of the row of `b.T` corresponding to the solution. Returns: A tuple containing a boolean indicating whether the Butcher tableau supports FSAL, and the index i of the intermediate kᵢ derivative corresponding to the initial derivative `f(x0, t0)` of the next step. """ if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c): raise ValueError("The Butcher tableau is not valid.") if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A): return False if index_of_solution >= b.T.shape[0]: msg = "The index of the solution (i-th row of `b.T`) is out of range." raise ValueError(msg) if c[0] != 0: return False, None # Find all the rows of A where c = 1 (therefore at t=tf). The Butcher tableau # supports FSAL if any of these rows (there might be more rows with c=1) matches # the rows of b.T corresponding to the next state (marked by `index_of_solution`). # This last condition means that the last kᵢ derivative is computed at (tf, xf), # that corresponds to the (t0, x0) pair of the next integration call. rows_of_A_with_fsal = (A == b.T[None, index_of_solution]).all(axis=1) rows_of_A_with_fsal = jnp.logical_and(rows_of_A_with_fsal, (c == 1)) # If there is no match, it means that the Butcher tableau does not support FSAL. if not rows_of_A_with_fsal.any(): return False, None # Return the index of the row of A providing the fsal derivative (that is the # possibly intermediate kᵢ derivative). # Note that if multiple rows match (it should not), we return the first match. return True, int(jnp.where(rows_of_A_with_fsal)[0].tolist()[0])
[docs] class ExplicitRungeKuttaSO3Mixin: """ Mixin class to apply over explicit RK integrators defined on `PyTreeType = ODEState` to integrate the quaternion on SO(3). """ @classmethod def post_process_state( cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep ) -> js.ode_data.ODEState: # Extract the initial base quaternion. W_Q_B_t0 = x0.physics_model.base_quaternion # We assume that the initial quaternion is already unary. exceptions.raise_runtime_error_if( condition=jnp.logical_not(jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0)), msg="The SO(3) integrator received a quaternion at t0 that is not unary.", ) # Get the angular velocity ω to integrate the quaternion. # This velocity ω[t0] is computed in the previous timestep by averaging the kᵢ # corresponding to the active RK-based scheme. Therefore, by using the ω[t0], # we obtain an explicit RK scheme operating on the SO(3) manifold. # Note that the current integrator is not a semi-implicit scheme, therefore # using the final ω[tf] would be not correct. W_ω_WB_t0 = x0.physics_model.base_angular_velocity # Integrate the quaternion on SO(3). W_Q_B_tf = jaxsim.math.Quaternion.integration( quaternion=W_Q_B_t0, dt=dt, omega=W_ω_WB_t0, omega_in_body_fixed=False, ) # Replace the quaternion in the final state. return xf.replace( physics_model=xf.physics_model.replace(base_quaternion=W_Q_B_tf), validate=True, )