Source code for jaxsim.rbda.contacts.soft

from __future__ import annotations

import dataclasses

import jax
import jax.numpy as jnp
import jax_dataclasses

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Skew, StandardGravity
from jaxsim.terrain import FlatTerrain, Terrain

from .common import ContactModel, ContactsParams, ContactsState


[docs] @jax_dataclasses.pytree_dataclass class SoftContactsParams(ContactsParams): """Parameters of the soft contacts model.""" K: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(1e6, dtype=float) ) D: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(2000, dtype=float) ) mu: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.5, dtype=float) ) def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray return hash( ( HashedNumpyArray.hash_of_array(self.K), HashedNumpyArray.hash_of_array(self.D), HashedNumpyArray.hash_of_array(self.mu), ) ) def __eq__(self, other: SoftContactsParams) -> bool: if not isinstance(other, SoftContactsParams): return NotImplemented return hash(self) == hash(other)
[docs] @staticmethod def build( K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5 ) -> SoftContactsParams: """ Create a SoftContactsParams instance with specified parameters. Args: K: The stiffness parameter. D: The damping parameter of the soft contacts model. mu: The static friction coefficient. Returns: A SoftContactsParams instance with the specified parameters. """ return SoftContactsParams( K=jnp.array(K, dtype=float), D=jnp.array(D, dtype=float), mu=jnp.array(mu, dtype=float), )
[docs] @staticmethod def build_default_from_jaxsim_model( model: js.model.JaxSimModel, *, standard_gravity: jtp.FloatLike = StandardGravity, static_friction_coefficient: jtp.FloatLike = 0.5, max_penetration: jtp.FloatLike = 0.001, number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, ) -> SoftContactsParams: """ Create a SoftContactsParams instance with good default parameters. Args: model: The target model. standard_gravity: The standard gravity constant. static_friction_coefficient: The static friction coefficient between the model and the terrain. max_penetration: The maximum penetration depth. number_of_active_collidable_points_steady_state: The number of contacts supporting the weight of the model in steady state. damping_ratio: The ratio controlling the damping behavior. Returns: A `SoftContactsParams` instance with the specified parameters. Note: The `damping_ratio` parameter allows to operate on the following conditions: - ξ > 1.0: over-damped - ξ = 1.0: critically damped - ξ < 1.0: under-damped """ # Use symbols for input parameters. ξ = damping_ratio δ_max = max_penetration μc = static_friction_coefficient # Compute the total mass of the model. m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() # Rename the standard gravity. g = standard_gravity # Compute the average support force on each collidable point. f_average = m * g / number_of_active_collidable_points_steady_state # Compute the stiffness to get the desired steady-state penetration. K = f_average / jnp.power(δ_max, 3 / 2) # Compute the damping using the damping ratio. critical_damping = 2 * jnp.sqrt(K * m) D = ξ * critical_damping return SoftContactsParams.build(K=K, D=D, mu=μc)
[docs] def valid(self) -> bool: """ Check if the parameters are valid. Returns: `True` if the parameters are valid, `False` otherwise. """ return ( jnp.all(self.K >= 0.0) and jnp.all(self.D >= 0.0) and jnp.all(self.mu >= 0.0) )
[docs] @jax_dataclasses.pytree_dataclass class SoftContacts(ContactModel): """Soft contacts model.""" parameters: SoftContactsParams = dataclasses.field( default_factory=SoftContactsParams ) terrain: jax_dataclasses.Static[Terrain] = dataclasses.field( default_factory=FlatTerrain )
[docs] def compute_contact_forces( self, position: jtp.Vector, velocity: jtp.Vector, tangential_deformation: jtp.Vector, ) -> tuple[jtp.Vector, tuple[jtp.Vector]]: """ Compute the contact forces and material deformation rate. Args: position: The position of the collidable point. velocity: The linear velocity of the collidable point. tangential_deformation: The tangential deformation. Returns: A tuple containing the contact force and material deformation rate. """ # Short name of parameters K = self.parameters.K D = self.parameters.D μ = self.parameters.mu # Material 3D tangential deformation and its derivative m = tangential_deformation.squeeze() = jnp.zeros_like(m) # Note: all the small hardcoded tolerances in this method have been introduced # to allow jax differentiating through this algorithm. They should not affect # the accuracy of the simulation, although they might make it less readable. # ======================== # Normal force computation # ======================== # Unpack the position of the collidable point. px, py, pz = W_p_C = position.squeeze() W_ṗ_C = velocity.squeeze() # Compute the terrain normal and the contact depth. = self.terrain.normal(x=px, y=py).squeeze() h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz]) # Compute the penetration depth normal to the terrain. δ = jnp.maximum(0.0, jnp.dot(h, )) # Compute the penetration normal velocity. δ̇ = -jnp.dot(W_ṗ_C, ) # Non-linear spring-damper model. # This is the force magnitude along the direction normal to the terrain. force_normal_mag = jax.lax.select( pred=δ >= 1e-9, on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇), on_false=jnp.array(0.0), ) # Prevent negative normal forces that might occur when δ̇ is largely negative. force_normal_mag = jnp.maximum(0.0, force_normal_mag) # Compute the 3D linear force in C[W] frame. force_normal = force_normal_mag * # ==================================== # No friction and no tangential forces # ==================================== # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial. # Note: this is equal to the 6D velocities transform: CW_X_W.transpose(). W_Xf_CW = jnp.vstack( [ jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]), jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]), ] ) def with_no_friction(): # Compute 6D mixed force in C[W]. CW_f_lin = force_normal CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)]) # Compute lin-ang 6D forces (inertial representation). W_f = W_Xf_CW @ CW_f return W_f, (,) # ========================= # Compute tangential forces # ========================= def with_friction(): # Initialize the tangential deformation rate ṁ. # For inactive contacts with m≠0, this is the dynamics of the material # relaxation converging exponentially to steady state. = (-K / D) * m # Check if the collidable point is below ground. # Note: when δ=0, we consider the point still not it contact such that # we prevent divisions by 0 in the computations below. active_contact = pz < self.terrain.height(x=px, y=py) def above_terrain(): return jnp.zeros(6), (,) def below_terrain(): # Decompose the velocity in normal and tangential components. v_normal = jnp.dot(W_ṗ_C, ) * v_tangential = W_ṗ_C - v_normal # Compute the tangential force. If inside the friction cone, the contact. f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential) def sticking_contact(): # Sum the normal and tangential forces, and create the 6D force. CW_f_stick = force_normal + f_tangential CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)]) # In this case the 3D material deformation is the tangential velocity. = v_tangential # Return the 6D force in the contact frame and # the deformation derivative. return CW_f, def slipping_contact(): # Project the force to the friction cone boundary. f_tangential_projected = (μ * force_normal_mag) * ( f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9) ) # Sum the normal and tangential forces, and create the 6D force. CW_f_slip = force_normal + f_tangential_projected CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)]) # Correct the material deformation derivative for slipping contacts. # Basically we compute ṁ such that we get `f_tangential` on the cone # given the current (m, δ). ε = 1e-9 δε = jnp.maximum(δ, ε) α = -K * jnp.sqrt(δε) β = -D * jnp.sqrt(δε) = (f_tangential_projected - α * m) / β # Return the 6D force in the contact frame and # the deformation derivative. return CW_f, CW_f, = jax.lax.cond( pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2, true_fun=lambda _: slipping_contact(), false_fun=lambda _: sticking_contact(), operand=None, ) # Express the 6D force in the world frame. W_f = W_Xf_CW @ CW_f # Return the 6D force in the world frame and the deformation derivative. return W_f, (,) # (W_f, (ṁ,)) return jax.lax.cond( pred=active_contact, true_fun=lambda _: below_terrain(), false_fun=lambda _: above_terrain(), operand=None, ) # (W_f, (ṁ,)) return jax.lax.cond( pred=(μ == 0.0), true_fun=lambda _: with_no_friction(), false_fun=lambda _: with_friction(), operand=None, )
[docs] @jax_dataclasses.pytree_dataclass class SoftContactsState(ContactsState): """ Class storing the state of the soft contacts model. Attributes: tangential_deformation: The matrix of 3D tangential material deformations corresponding to each collidable point. """ tangential_deformation: jtp.Matrix def __hash__(self) -> int: return hash( tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist()) ) def __eq__(self, other: SoftContactsState) -> bool: if not isinstance(other, SoftContactsState): return False return hash(self) == hash(other)
[docs] @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, tangential_deformation: jtp.Matrix | None = None, ) -> SoftContactsState: """ Build a `SoftContactsState` from a `JaxSimModel`. Args: model: The `JaxSimModel` associated with the soft contacts state. tangential_deformation: The matrix of 3D tangential material deformations. Returns: The `SoftContactsState` built from the `JaxSimModel`. Note: If any of the state components are not provided, they are built from the `JaxSimModel` and initialized to zero. """ return SoftContactsState.build( tangential_deformation=tangential_deformation, number_of_collidable_points=len( model.kin_dyn_parameters.contact_parameters.body ), )
[docs] @staticmethod def build( tangential_deformation: jtp.Matrix | None = None, number_of_collidable_points: int | None = None, ) -> SoftContactsState: """ Create a `SoftContactsState`. Args: tangential_deformation: The matrix of 3D tangential material deformations corresponding to each collidable point. number_of_collidable_points: The number of collidable points. Returns: A `SoftContactsState` instance. """ tangential_deformation = ( tangential_deformation if tangential_deformation is not None else jnp.zeros(shape=(number_of_collidable_points, 3)) ) if tangential_deformation.shape[1] != 3: raise RuntimeError("The tangential deformation matrix must have 3 columns.") if ( number_of_collidable_points is not None and tangential_deformation.shape[0] != number_of_collidable_points ): msg = "The number of collidable points must match the number of rows " msg += "in the tangential deformation matrix." raise RuntimeError(msg) return SoftContactsState( tangential_deformation=jnp.array(tangential_deformation).astype(float) )
[docs] @staticmethod def zero(model: js.model.JaxSimModel) -> SoftContactsState: """ Build a zero `SoftContactsState` from a `JaxSimModel`. Args: model: The `JaxSimModel` associated with the soft contacts state. Returns: A zero `SoftContactsState` instance. """ return SoftContactsState.build_from_jaxsim_model(model=model)
[docs] def valid(self, model: js.model.JaxSimModel) -> bool: """ Check if the `SoftContactsState` is valid for a given `JaxSimModel`. Args: model: The `JaxSimModel` to validate the `SoftContactsState` against. Returns: `True` if the soft contacts state is valid for the given `JaxSimModel`, `False` otherwise. """ shape = self.tangential_deformation.shape expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3) if shape != expected: return False return True