from __future__ import annotations
import jax.numpy as jnp
import jax_dataclasses
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.rbda import ContactsState
from jaxsim.rbda.contacts.relaxed_rigid import (
RelaxedRigidContacts,
RelaxedRigidContactsState,
)
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
from jaxsim.utils import JaxsimDataclass
# =============================================================================
# Define the input and state of the ODE system defining the integrated dynamics
# =============================================================================
# Note: the ODE system is the combination of the floating-base dynamics and the
# soft-contacts dynamics.
[docs]
@jax_dataclasses.pytree_dataclass
class ODEState(JaxsimDataclass):
"""
The state of the ODE system.
Attributes:
physics_model: The state of the physics model.
contact: The state of the contacts model.
"""
physics_model: PhysicsModelState
contact: ContactsState
[docs]
@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
joint_positions: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
**kwargs,
) -> ODEState:
"""
Build an `ODEState` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the ODE state.
joint_positions: The vector of joint positions.
joint_velocities: The vector of joint velocities.
base_position: The 3D position of the base link.
base_quaternion: The quaternion defining the orientation of the base link.
base_linear_velocity:
The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity:
The angular velocity of the base link in inertial-fixed representation.
kwargs: Additional arguments needed to build the contact state.
Returns:
The `ODEState` built from the `JaxSimModel`.
Note:
If any of the state components are not provided, they are built from the
`JaxSimModel` and initialized to zero.
"""
# Get the contact model from the `JaxSimModel`.
match model.contact_model:
case SoftContacts():
tangential_deformation = kwargs.get("tangential_deformation", None)
contact = SoftContactsState.build_from_jaxsim_model(
model=model,
**(
dict(tangential_deformation=tangential_deformation)
if tangential_deformation is not None
else dict()
),
)
case RigidContacts():
contact = RigidContactsState.build()
case RelaxedRigidContacts():
contact = RelaxedRigidContactsState.build()
case _:
raise ValueError("Unable to determine contact state class prefix.")
return ODEState.build(
model=model,
physics_model_state=PhysicsModelState.build_from_jaxsim_model(
model=model,
joint_positions=joint_positions,
joint_velocities=joint_velocities,
base_position=base_position,
base_quaternion=base_quaternion,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
),
contact=contact,
)
[docs]
@staticmethod
def build(
physics_model_state: PhysicsModelState | None = None,
contact: ContactsState | None = None,
model: js.model.JaxSimModel | None = None,
) -> ODEState:
"""
Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`.
Args:
physics_model_state: The state of the physics model.
contact: The state of the contacts model.
model: The `JaxSimModel` associated with the ODE state.
Returns:
A `ODEState` instance.
"""
physics_model_state = (
physics_model_state
if physics_model_state is not None
else PhysicsModelState.zero(model=model)
)
# Get the contact model from the `JaxSimModel`.
match contact:
case (
SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
):
pass
case None:
contact = SoftContactsState.zero(model=model)
case _:
raise ValueError("Unable to determine contact state class prefix.")
return ODEState(physics_model=physics_model_state, contact=contact)
[docs]
@staticmethod
def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
"""
Build a zero `ODEState` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the ODE state.
Returns:
A zero `ODEState` instance.
"""
model_state = ODEState.build(
model=model, contact=data.state.contact.zero(model=model)
)
return model_state
[docs]
def valid(self, model: js.model.JaxSimModel) -> bool:
"""
Check if the `ODEState` is valid for a given `JaxSimModel`.
Args:
model: The `JaxSimModel` to validate the `ODEState` against.
Returns:
`True` if the ODE state is valid for the given model, `False` otherwise.
"""
return self.physics_model.valid(model=model) and self.contact.valid(model=model)
# ==================================================
# Define the input and state of floating-base robots
# ==================================================
[docs]
@jax_dataclasses.pytree_dataclass
class PhysicsModelState(JaxsimDataclass):
"""
Class storing the state of the physics model dynamics.
Attributes:
joint_positions: The vector of joint positions.
joint_velocities: The vector of joint velocities.
base_position: The 3D position of the base link.
base_quaternion: The quaternion defining the orientation of the base link.
base_linear_velocity:
The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity:
The angular velocity of the base link in inertial-fixed representation.
"""
# Joint state
joint_positions: jtp.Vector
joint_velocities: jtp.Vector
# Base state
base_position: jtp.Vector = jax_dataclasses.field(
default_factory=lambda: jnp.zeros(3)
)
base_quaternion: jtp.Vector = jax_dataclasses.field(
default_factory=lambda: jnp.array([1.0, 0, 0, 0])
)
base_linear_velocity: jtp.Vector = jax_dataclasses.field(
default_factory=lambda: jnp.zeros(3)
)
base_angular_velocity: jtp.Vector = jax_dataclasses.field(
default_factory=lambda: jnp.zeros(3)
)
def __hash__(self) -> int:
from jaxsim.utils.wrappers import HashedNumpyArray
return hash(
(
HashedNumpyArray.hash_of_array(self.joint_positions),
HashedNumpyArray.hash_of_array(self.joint_velocities),
HashedNumpyArray.hash_of_array(self.base_position),
HashedNumpyArray.hash_of_array(self.base_quaternion),
HashedNumpyArray.hash_of_array(self.base_linear_velocity),
HashedNumpyArray.hash_of_array(self.base_angular_velocity),
)
)
def __eq__(self, other: PhysicsModelState) -> bool:
if not isinstance(other, PhysicsModelState):
return False
return hash(self) == hash(other)
[docs]
@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
joint_positions: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
) -> PhysicsModelState:
"""
Build a `PhysicsModelState` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the state.
joint_positions: The vector of joint positions.
joint_velocities: The vector of joint velocities.
base_position: The 3D position of the base link.
base_quaternion: The quaternion defining the orientation of the base link.
base_linear_velocity:
The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity:
The angular velocity of the base link in inertial-fixed representation.
Note:
If any of the state components are not provided, they are built from the
`JaxSimModel` and initialized to zero.
Returns:
A `PhysicsModelState` instance.
"""
return PhysicsModelState.build(
joint_positions=joint_positions,
joint_velocities=joint_velocities,
base_position=base_position,
base_quaternion=base_quaternion,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
number_of_dofs=model.dofs(),
)
[docs]
@staticmethod
def build(
joint_positions: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
number_of_dofs: jtp.Int | None = None,
) -> PhysicsModelState:
"""
Build a `PhysicsModelState`.
Args:
joint_positions: The vector of joint positions.
joint_velocities: The vector of joint velocities.
base_position: The 3D position of the base link.
base_quaternion: The quaternion defining the orientation of the base link.
base_linear_velocity:
The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity:
The angular velocity of the base link in inertial-fixed representation.
number_of_dofs:
The number of degrees of freedom of the physics model.
Returns:
A `PhysicsModelState` instance.
"""
joint_positions = (
joint_positions
if joint_positions is not None
else jnp.zeros(number_of_dofs)
)
joint_velocities = (
joint_velocities
if joint_velocities is not None
else jnp.zeros(number_of_dofs)
)
base_position = base_position if base_position is not None else jnp.zeros(3)
base_quaternion = (
base_quaternion
if base_quaternion is not None
else jnp.array([1.0, 0, 0, 0])
)
base_linear_velocity = (
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
)
base_angular_velocity = (
base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
)
physics_model_state = PhysicsModelState(
joint_positions=jnp.array(joint_positions, dtype=float),
joint_velocities=jnp.array(joint_velocities, dtype=float),
base_position=jnp.array(base_position, dtype=float),
base_quaternion=jnp.array(base_quaternion, dtype=float),
base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
)
# TODO (diegoferigo): assert state.valid(physics_model)
return physics_model_state
[docs]
@staticmethod
def zero(model: js.model.JaxSimModel) -> PhysicsModelState:
"""
Build a `PhysicsModelState` with all components initialized to zero.
Args:
model: The `JaxSimModel` associated with the state.
Returns:
A `PhysicsModelState` instance.
"""
return PhysicsModelState.build_from_jaxsim_model(model=model)
[docs]
def valid(self, model: js.model.JaxSimModel) -> bool:
"""
Check if the `PhysicsModelState` is valid for a given `JaxSimModel`.
Args:
model: The `JaxSimModel` to validate the `PhysicsModelState` against.
Returns:
`True` if the `PhysicsModelState` is valid for the given model,
`False` otherwise.
"""
shape = self.joint_positions.shape
expected_shape = (model.dofs(),)
if shape != expected_shape:
return False
shape = self.joint_velocities.shape
expected_shape = (model.dofs(),)
if shape != expected_shape:
return False
shape = self.base_position.shape
expected_shape = (3,)
if shape != expected_shape:
return False
shape = self.base_quaternion.shape
expected_shape = (4,)
if shape != expected_shape:
return False
shape = self.base_linear_velocity.shape
expected_shape = (3,)
if shape != expected_shape:
return False
shape = self.base_angular_velocity.shape
expected_shape = (3,)
if shape != expected_shape:
return False
return True