# Source code for jaxsim.api.common

import abc
import contextlib
import dataclasses
import enum
import functools
from collections.abc import Iterator

import jax
import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from jaxsim.utils import JaxsimDataclass, Mutability

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self


@enum.unique
class VelRepr(enum.IntEnum):
    """
    Enumeration of all supported 6D velocity representations.

    The same representations are used for the dual 6D forces.

    Members:
        Body: Quantities expressed in the body-fixed frame ``O``.
        Mixed: Quantities expressed in a frame with the origin of ``O`` and
            the orientation of the inertial frame ``W``.
        Inertial: Quantities expressed in the inertial-fixed frame ``W``.
    """

    # Explicit values, equal to the ones ``enum.auto()`` assigns to an
    # ``IntEnum`` (counting starts at 1).
    Body = 1
    Mixed = 2
    Inertial = 3
[docs] @jax_dataclasses.pytree_dataclass class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): """ Base class for model data structures with velocity representation. """ velocity_representation: Static[VelRepr] = dataclasses.field( default=VelRepr.Inertial, kw_only=True )
[docs] @contextlib.contextmanager def switch_velocity_representation( self, velocity_representation: VelRepr ) -> Iterator[Self]: """ Context manager to temporarily switch the velocity representation. Args: velocity_representation: The new velocity representation. Yields: The same object with the new velocity representation. """ original_representation = self.velocity_representation try: # First, we replace the velocity representation. with self.mutable_context( mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=True, ): self.velocity_representation = velocity_representation # Then, we yield the data with changed representation. # We run this in a mutable context with restoration so that any exception # occurring, we restore the original object in case it was modified. with self.mutable_context( mutability=self.mutability(), restore_after_exception=True ): yield self finally: with self.mutable_context( mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=True, ): self.velocity_representation = original_representation
[docs] @staticmethod @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"]) def inertial_to_other_representation( array: jtp.Array, other_representation: VelRepr, transform: jtp.Matrix, *, is_force: bool, ) -> jtp.Array: r""" Convert a 6D quantity from inertial-fixed to another representation. Args: array: The 6D quantity to convert. other_representation: The representation to convert to. transform: The :math:`W \mathbf{H}_O` transform, where :math:`O` is the reference frame of the other representation. is_force: Whether the quantity is a 6D force or a 6D velocity. Returns: The 6D quantity in the other representation. """ W_array = array.squeeze() W_H_O = transform.squeeze() if W_array.size != 6: raise ValueError(W_array.size, 6) if W_H_O.shape != (4, 4): raise ValueError(W_H_O.shape, (4, 4)) match other_representation: case VelRepr.Inertial: return W_array case VelRepr.Body: if not is_force: O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) O_array = O_Xv_W @ W_array else: O_Xf_W = Adjoint.from_transform(transform=W_H_O).T O_array = O_Xf_W @ W_array return O_array case VelRepr.Mixed: W_p_O = W_H_O[0:3, 3] W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) if not is_force: OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True) OW_array = OW_Xv_W @ W_array else: OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T OW_array = OW_Xf_W @ W_array return OW_array case _: raise ValueError(other_representation)
[docs] @staticmethod @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"]) def other_representation_to_inertial( array: jtp.Array, other_representation: VelRepr, transform: jtp.Matrix, *, is_force: bool, ) -> jtp.Array: r""" Convert a 6D quantity from another representation to inertial-fixed. Args: array: The 6D quantity to convert. other_representation: The representation to convert from. transform: The `math:W \mathbf{H}_O` transform, where `math:O` is the reference frame of the other representation. is_force: Whether the quantity is a 6D force or a 6D velocity. Returns: The 6D quantity in the inertial-fixed representation. """ W_array = array.squeeze() W_H_O = transform.squeeze() if W_array.size != 6: raise ValueError(W_array.size, 6) if W_H_O.shape != (4, 4): raise ValueError(W_H_O.shape, (4, 4)) match other_representation: case VelRepr.Inertial: W_array = array return W_array case VelRepr.Body: O_array = array if not is_force: W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O) W_array = W_Xv_O @ O_array else: W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T W_array = W_Xf_O @ O_array return W_array case VelRepr.Mixed: BW_array = array W_p_O = W_H_O[0:3, 3] W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) if not is_force: W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW) W_array = W_Xv_BW @ BW_array else: W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T W_array = W_Xf_BW @ BW_array return W_array case _: raise ValueError(other_representation)