Source code for jaxsim.rbda.collidable_points

import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Skew

from . import utils


[docs] def collidable_points_pos_vel( model: js.model.JaxSimModel, *, base_position: jtp.Vector, base_quaternion: jtp.Vector, joint_positions: jtp.Vector, base_linear_velocity: jtp.Vector, base_angular_velocity: jtp.Vector, joint_velocities: jtp.Vector, ) -> tuple[jtp.Matrix, jtp.Matrix]: """ Compute the position and linear velocity of collidable points in the world frame. Args: model: The model to consider. base_position: The position of the base link. base_quaternion: The quaternion of the base link. joint_positions: The positions of the joints. base_linear_velocity: The linear velocity of the base link in inertial-fixed representation. base_angular_velocity: The angular velocity of the base link in inertial-fixed representation. joint_velocities: The velocities of the joints. Returns: A tuple containing the position and linear velocity of collidable points. """ if len(model.kin_dyn_parameters.contact_parameters.body) == 0: return jnp.array(0).astype(float), jnp.empty(0).astype(float) W_p_B, W_Q_B, s, W_v_WB, , _, _, _, _, _ = utils.process_inputs( model=model, base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, joint_velocities=joint_velocities, ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Compute the base transform. W_H_B = jaxlie.SE3.from_rotation_and_translation( rotation=jaxlie.SO3(wxyz=W_Q_B), translation=W_p_B, ) # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( joint_positions=s, base_transform=W_H_B.as_matrix() ) # Allocate buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0])) # Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity. W_v_Wi = jnp.zeros(shape=(model.number_of_links(), 6)) W_v_Wi = W_v_Wi.at[0].set(W_v_WB) # ==================== # Propagate kinematics # ==================== PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix] propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi) def propagate_kinematics( carry: PropagateTransformsCarry, i: jtp.Int ) -> tuple[PropagateTransformsCarry, None]: ii = i - 1 W_X_i, W_v_Wi = carry # Compute the parent to child 6D transform. λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i]) # Compute the world to child 6D transform. W_Xi_i = W_X_i[λ[i]] @ λi_X_i W_X_i = W_X_i.at[i].set(W_Xi_i) # Propagate the 6D velocity. W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * [ii]).squeeze() W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi) return (W_X_i, W_v_Wi), None (W_X_i, W_v_Wi), _ = ( jax.lax.scan( f=propagate_kinematics, init=propagate_transforms_carry, xs=jnp.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [(W_X_i, W_v_Wi), None] ) # ================================================== # Compute position and velocity of collidable points # ================================================== def process_point_kinematics( Li_p_C: jtp.Vector, parent_body: jtp.Int ) -> tuple[jtp.Vector, jtp.Vector]: # Compute the position of the collidable point. W_p_Ci = ( Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1]) )[0:3] # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}. CW_vl_WCi = ( jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()]) @ W_v_Wi[parent_body].squeeze() ) return W_p_Ci, CW_vl_WCi # Process all the collidable points in parallel. W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)( model.kin_dyn_parameters.contact_parameters.point, jnp.array(model.kin_dyn_parameters.contact_parameters.body), ) return W_p_Ci, CW_vl_WC