import functools
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import exceptions
# =======================
# Index-related functions
# =======================
[docs]
@functools.partial(jax.jit, static_argnames="joint_name")
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
"""
Convert the name of a joint to its index.
Args:
model: The model to consider.
joint_name: The name of the joint.
Returns:
The index of the joint.
"""
if joint_name not in model.joint_names():
raise ValueError(f"Joint '{joint_name}' not found in the model.")
# Note: the index of the joint for RBDAs starts from 1, but the index for
# accessing the right element starts from 0. Therefore, there is a -1.
return (
jnp.array(
model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
)
.astype(int)
.squeeze()
)
[docs]
def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
"""
Convert the index of a joint to its name.
Args:
model: The model to consider.
joint_index: The index of the joint.
Returns:
The name of the joint.
"""
exceptions.raise_value_error_if(
condition=jnp.array(
[joint_index < 0, joint_index >= model.number_of_joints()]
).any(),
msg="Invalid joint index '{idx}'",
idx=joint_index,
)
return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]
[docs]
@functools.partial(jax.jit, static_argnames="joint_names")
def names_to_idxs(
model: js.model.JaxSimModel, *, joint_names: Sequence[str]
) -> jax.Array:
"""
Convert a sequence of joint names to their corresponding indices.
Args:
model: The model to consider.
joint_names: The names of the joints.
Returns:
The indices of the joints.
"""
return jnp.array(
[name_to_idx(model=model, joint_name=name) for name in joint_names],
).astype(int)
[docs]
def idxs_to_names(
model: js.model.JaxSimModel,
*,
joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,
) -> tuple[str, ...]:
"""
Convert a sequence of joint indices to their corresponding names.
Args:
model: The model to consider.
joint_indices: The indices of the joints.
Returns:
The names of the joints.
"""
return tuple(idx_to_name(model=model, joint_index=idx) for idx in joint_indices)
# ============
# Joint limits
# ============
[docs]
@jax.jit
def position_limit(
model: js.model.JaxSimModel, *, joint_index: jtp.IntLike
) -> tuple[jtp.Float, jtp.Float]:
"""
Get the position limits of a joint.
Args:
model: The model to consider.
joint_index: The index of the joint.
Returns:
The position limits of the joint.
"""
if model.number_of_joints() <= 1:
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
exceptions.raise_value_error_if(
condition=jnp.array(
[joint_index < 0, joint_index >= model.number_of_joints()]
).any(),
msg="Invalid joint index '{idx}'",
idx=joint_index,
)
s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]
return s_min.astype(float), s_max.astype(float)
[docs]
@functools.partial(jax.jit, static_argnames=["joint_names"])
def position_limits(
model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Get the position limits of a list of joint.
Args:
model: The model to consider.
joint_names: The names of the joints.
Returns:
The position limits of the joints.
"""
joint_names = joint_names if joint_names is not None else model.joint_names()
if len(joint_names) == 0:
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
joint_idxs = names_to_idxs(joint_names=joint_names, model=model)
return jax.vmap(lambda i: position_limit(model=model, joint_index=i))(joint_idxs)
# ======================
# Random data generation
# ======================
[docs]
@functools.partial(jax.jit, static_argnames=["joint_names"])
def random_joint_positions(
model: js.model.JaxSimModel,
*,
joint_names: Sequence[str] | None = None,
key: jax.Array | None = None,
) -> jtp.Vector:
"""
Generate random joint positions.
Args:
model: The model to consider.
joint_names: The names of the considered joints (all if None).
key: The random key (initialized from seed 0 if None).
Note:
If the joint range or revolute joints is larger than 2π, their joint positions
will be sampled from an interval of size 2π.
Returns:
The random joint positions.
"""
# Consider the key corresponding to a zero seed if it was not passed.
key = key if key is not None else jax.random.PRNGKey(seed=0)
# Get the joint limits parsed from the model description.
s_min, s_max = position_limits(model=model, joint_names=joint_names)
# Get the joint indices.
# Note that it will trigger an exception if the given `joint_names` are not valid.
joint_names = joint_names if joint_names is not None else model.joint_names()
joint_indices = names_to_idxs(model=model, joint_names=joint_names)
from jaxsim.parsers.descriptions.joint import JointType
# Filter for revolute joints.
is_revolute = jnp.where(
jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
== JointType.Revolute,
True,
False,
)
# Shorthand for π.
π = jnp.pi
# Filter for revolute with full range (or continuous).
is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
# Clip the lower limit to -π if the joint range is larger than [-π, π].
s_min = jnp.where(
jnp.logical_and(
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
),
-π,
s_min,
)
# Clip the upper limit to +π if the joint range is larger than [-π, π].
s_max = jnp.where(
jnp.logical_and(
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
),
π,
s_max,
)
# Shift the lower limit if the upper limit is smaller than +π.
s_min = jnp.where(
jnp.logical_and(is_revolute_full_range, s_max < π),
s_max - 2 * π,
s_min,
)
# Shift the upper limit if the lower limit is larger than -π.
s_max = jnp.where(
jnp.logical_and(is_revolute_full_range, s_min > -π),
s_min + 2 * π,
s_max,
)
# Sample the joint positions.
s_random = jax.random.uniform(
minval=s_min,
maxval=s_max,
key=key,
shape=s_min.shape,
)
return s_random