# Source code for jaxsim.parsers.descriptions.joint

from __future__ import annotations

import dataclasses
from typing import ClassVar

import jax_dataclasses
import numpy as np

import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass, Mutability

from .link import LinkDescription


@dataclasses.dataclass(frozen=True)
class JointType:
    """
    Integer identifiers of the supported joint types.

    The identifiers are plain class-level constants (``ClassVar``), hence the
    dataclass itself has no instance fields.
    """

    # Joint that does not move.
    Fixed: ClassVar[int] = 0

    # Joint rotating about an axis.
    Revolute: ClassVar[int] = 1

    # Joint translating along an axis.
    Prismatic: ClassVar[int] = 2
@jax_dataclasses.pytree_dataclass
class JointGenericAxis:
    """
    A joint requiring the specification of a 3D axis.
    """

    # The axis of rotation or translation of the joint (must have norm 1).
    axis: jtp.Vector

    def __hash__(self) -> int:
        """Return a hash computed from the element values of the axis."""

        return hash(tuple(self.axis.tolist()))

    def __eq__(self, other: JointGenericAxis) -> bool:
        """
        Check whether ``other`` is a ``JointGenericAxis`` with the same axis.

        Args:
            other: The object to compare with.

        Returns:
            True if ``other`` is a ``JointGenericAxis`` whose axis has the same
            element values, False otherwise.
        """

        if not isinstance(other, JointGenericAxis):
            return False

        # Compare the values themselves rather than their hashes: different
        # axes can share a hash (in CPython hash(-1.0) == hash(-2.0)), and a
        # hash collision must not make two distinct axes compare equal.
        # Equal tuples always produce equal hashes, so this stays consistent
        # with __hash__.
        return tuple(self.axis.tolist()) == tuple(other.axis.tolist())
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class JointDescription(JaxsimDataclass):
    """
    In-memory description of a robot joint.

    Attributes:
        name (str): The name of the joint.
        axis (npt.NDArray): The axis of rotation or translation for the joint.
            When provided, it is normalized to unit length after initialization.
        pose (npt.NDArray): The pose transformation matrix of the joint.
        jtype (JointType): The type of the joint.
        child (LinkDescription): The child link attached to the joint.
        parent (LinkDescription): The parent link attached to the joint.
        index (Optional[int]): An optional index for the joint.
        friction_static (float): The static friction coefficient for the joint.
        friction_viscous (float): The viscous friction coefficient for the joint.
        position_limit_damper (float): The damper coefficient for position limits.
        position_limit_spring (float): The spring coefficient for position limits.
        position_limit (Tuple[float, float]): The position limits for the joint.
        initial_position (Union[float, npt.NDArray]): The initial position of the joint.
        motor_inertia (float): The motor inertia associated with the joint.
        motor_viscous_friction (float): The motor viscous friction associated with the joint.
        motor_gear_ratio (float): The motor gear ratio associated with the joint.
    """

    name: jax_dataclasses.Static[str]
    axis: jtp.Vector
    pose: jtp.Matrix
    jtype: jax_dataclasses.Static[jtp.IntLike]

    # The links are excluded from the repr. `dataclasses.field` is the proper
    # way to express it (the `dataclasses.dataclass` decorator is not a field
    # specifier). A default is kept so that both arguments remain optional.
    child: LinkDescription = dataclasses.field(default=None, repr=False)
    parent: LinkDescription = dataclasses.field(default=None, repr=False)

    index: jtp.IntLike | None = None
    friction_static: jtp.FloatLike = 0.0
    friction_viscous: jtp.FloatLike = 0.0
    position_limit_damper: jtp.FloatLike = 0.0
    position_limit_spring: jtp.FloatLike = 0.0
    position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0)
    initial_position: jtp.FloatLike | jtp.VectorLike = 0.0
    motor_inertia: jtp.FloatLike = 0.0
    motor_viscous_friction: jtp.FloatLike = 0.0
    motor_gear_ratio: jtp.FloatLike = 1.0

    def __post_init__(self) -> None:
        """Normalize the axis to unit length, when an axis is provided."""

        if self.axis is not None:
            with self.mutable_context(
                mutability=Mutability.MUTABLE, restore_after_exception=False
            ):
                norm_of_axis = np.linalg.norm(self.axis)

                # A zero-length axis cannot be normalized: dividing by its
                # norm would silently fill it with NaN. Leave it untouched.
                if norm_of_axis > 0:
                    self.axis = self.axis / norm_of_axis

    def __eq__(self, other: JointDescription) -> bool:
        """
        Check whether ``other`` describes the same joint.

        Static attributes and links are compared with ``==``, numerical
        attributes are compared with ``np.allclose``.

        Args:
            other: The object to compare with.

        Returns:
            True if ``other`` is a ``JointDescription`` with matching
            attributes, False otherwise.
        """

        if not isinstance(other, JointDescription):
            return False

        def values_close(lhs, rhs) -> bool:
            # `np.allclose` raises on None (e.g. a joint without axis), so
            # None is only considered equal to None.
            if lhs is None or rhs is None:
                return lhs is None and rhs is None

            return bool(np.allclose(lhs, rhs))

        numeric_attributes = (
            "axis",
            "pose",
            "friction_static",
            "friction_viscous",
            "position_limit_damper",
            "position_limit_spring",
            "position_limit",
            "initial_position",
            "motor_inertia",
            "motor_viscous_friction",
            "motor_gear_ratio",
        )

        # NOTE: the condition must be a plain boolean expression. Wrapping it
        # in a one-element tuple (trailing comma) makes it always truthy and
        # therefore makes every pair of joints compare equal.
        # NOTE(review): equality is tolerance-based while __hash__ relies on
        # HashedNumpyArray.hash_of_array; confirm that the two agree for
        # nearly-equal values.
        return bool(
            self.name == other.name
            and self.jtype == other.jtype
            and self.child == other.child
            and self.parent == other.parent
            and self.index == other.index
            and all(
                values_close(getattr(self, attr), getattr(other, attr))
                for attr in numeric_attributes
            )
        )

    def __hash__(self) -> int:
        """Return a hash combining all the attributes of the joint."""

        # Imported locally, presumably to avoid a circular import (to be confirmed).
        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                hash(self.name),
                HashedNumpyArray.hash_of_array(self.axis),
                HashedNumpyArray.hash_of_array(self.pose),
                hash(int(self.jtype)),
                hash(self.child),
                hash(self.parent),
                hash(int(self.index)) if self.index is not None else 0,
                HashedNumpyArray.hash_of_array(self.friction_static),
                HashedNumpyArray.hash_of_array(self.friction_viscous),
                HashedNumpyArray.hash_of_array(self.position_limit_damper),
                HashedNumpyArray.hash_of_array(self.position_limit_spring),
                HashedNumpyArray.hash_of_array(self.position_limit),
                HashedNumpyArray.hash_of_array(self.initial_position),
                HashedNumpyArray.hash_of_array(self.motor_inertia),
                HashedNumpyArray.hash_of_array(self.motor_viscous_friction),
                HashedNumpyArray.hash_of_array(self.motor_gear_ratio),
            ),
        )