Utils

class jaxsim.utils.JaxsimDataclass

Class that extends jax_dataclasses.pytree_dataclass instances with utility methods.

static check_compatibility(*trees)

Check whether the PyTrees are compatible in structure, shape, and dtype.

Parameters:

*trees (Sequence[Any]) – The PyTrees to compare.

Raises:

ValueError – If the PyTrees have incompatible structures, shapes, or dtypes.

Return type:

None
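The check described above compares trees first by structure, then by leaf shapes, then by leaf dtypes. The following is a minimal pure-Python sketch of that idea, not jaxsim's implementation: it assumes trees built from dicts, lists, tuples, and None, and uses a hypothetical FakeArray leaf type in place of JAX arrays.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for an array; only shape and dtype matter here."""

    shape: tuple[int, ...]
    dtype: str


def tree_flatten(tree: Any) -> tuple[Any, list[Any]]:
    """Split a nested dict/list/tuple into (structure, leaves); None has no leaves."""
    if tree is None:
        return None, []
    if isinstance(tree, dict):
        keys = tuple(sorted(tree))  # visit dict keys in sorted order, as JAX does
        parts = [tree_flatten(tree[k]) for k in keys]
        structure = ("dict", keys, tuple(s for s, _ in parts))
    elif isinstance(tree, (list, tuple)):
        parts = [tree_flatten(child) for child in tree]
        structure = (type(tree).__name__, tuple(s for s, _ in parts))
    else:
        return "leaf", [tree]
    return structure, [leaf for _, leaves in parts for leaf in leaves]


def _attr(leaves: list[Any], name: str) -> list[Any]:
    # None for leaves that are not array-like.
    return [getattr(leaf, name, None) for leaf in leaves]


def check_compatibility(*trees: Any) -> None:
    """Raise ValueError unless every tree matches the first one."""
    if len(trees) < 2:
        return
    ref_structure, ref_leaves = tree_flatten(trees[0])
    for tree in trees[1:]:
        structure, leaves = tree_flatten(tree)
        if structure != ref_structure:
            raise ValueError("PyTrees have incompatible structures")
        if _attr(leaves, "shape") != _attr(ref_leaves, "shape"):
            raise ValueError("PyTrees have incompatible shapes")
        if _attr(leaves, "dtype") != _attr(ref_leaves, "dtype"):
            raise ValueError("PyTrees have incompatible dtypes")
```

Note that two dicts with the same keys in a different insertion order are still compatible, while swapping a list for a tuple counts as a structural change.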

copy()

Return a copy of the object.

Return type:

Self

Returns:

A copy of the object.

Parameters:

self (Self)

editable(validate=True)

Context manager to operate on a mutable copy of the object.

Parameters:
  • validate (bool) – Whether to validate the output PyTree upon exiting the context.

  • self (Self)

Yields:

A mutable copy of the object.

Return type:

Iterator[Self]

Note

This context manager is useful for operating on a read/write copy of a PyTree while ensuring that the output object does not trigger JIT recompilations.
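The copy-then-validate pattern above can be sketched in plain Python. This is an illustration under stated assumptions, not jaxsim's code: State is a hypothetical toy class that refuses attribute assignment once constructed, and the validation step only checks that each field keeps its length, as a stand-in for the structure, shape, and dtype checks.

```python
import copy
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator


@dataclass
class State:
    """Hypothetical toy object that is frozen once constructed."""

    position: list[float]
    velocity: list[float]

    def __post_init__(self) -> None:
        # Freeze after the generated __init__ has set the fields.
        object.__setattr__(self, "_mutable", False)

    def __setattr__(self, name, value):
        # During __init__ the flag does not exist yet, so assignments pass.
        if not getattr(self, "_mutable", True):
            raise AttributeError(f"cannot assign {name!r}: object is frozen")
        super().__setattr__(name, value)


@contextmanager
def editable(obj: State, validate: bool = True) -> Iterator[State]:
    """Yield a mutable deep copy of obj; the original is never modified."""
    new = copy.deepcopy(obj)
    object.__setattr__(new, "_mutable", True)
    try:
        yield new
    finally:
        # Refreeze the copy so it never leaves the context mutable.
        object.__setattr__(new, "_mutable", False)
    if validate:
        # Stand-in for the structure/shape/dtype check: lengths must match.
        for name in ("position", "velocity"):
            if len(getattr(new, name)) != len(getattr(obj, name)):
                raise ValueError(f"field {name!r} changed shape")
```

In this sketch the edited copy is returned frozen again, and a copy whose fields changed length is rejected on exit unless validate=False.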

flatten()

Flatten the object into a 1D vector.

Return type:

Array

Returns:

A 1D vector containing the flattened object.

classmethod flatten_fn()

Return a function to flatten the object into a 1D vector.

Return type:

Callable[[Self], Array]

Returns:

A function to flatten the object into a 1D vector.
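As a rough illustration of flatten and flatten_fn, the sketch below uses a hypothetical State class whose leaves are 1D tuples of floats and a plain Python list in place of a JAX Array. Flattening concatenates the leaves in field order; the classmethod hands back the same operation as a free function, which is convenient wherever an API expects a plain callable rather than a bound method. It is an assumption-based sketch, not the library's implementation.

```python
from dataclasses import dataclass, fields
from typing import Callable


@dataclass(frozen=True)
class State:
    """Hypothetical toy object whose leaves are 1D tuples of floats."""

    position: tuple[float, ...]
    velocity: tuple[float, ...]

    def flatten(self) -> list[float]:
        """Concatenate all leaves, in field order, into one 1D list."""
        return [x for f in fields(self) for x in getattr(self, f.name)]

    @classmethod
    def flatten_fn(cls) -> Callable[["State"], list[float]]:
        """Return flatten as a free function, e.g. to pass to map()."""
        return lambda obj: obj.flatten()
```

Because flatten_fn needs no instance, it can be obtained once from the class and applied to many objects.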

static get_leaf_dtypes(tree)

Helper method to get the leaf dtypes of a PyTree.

Parameters:

tree (dict[Hashable, TypeVar(PyTree)] | list[TypeVar(PyTree)] | tuple[TypeVar(PyTree)] | None | Array | Any) – The PyTree to consider.

Return type:

tuple

Returns:

A tuple containing the dtype of each leaf of the PyTree, or None if the leaf is not a numpy-like array.

static get_leaf_shapes(tree)

Helper method to get the leaf shapes of a PyTree.

Parameters:

tree (dict[Hashable, TypeVar(PyTree)] | list[TypeVar(PyTree)] | tuple[TypeVar(PyTree)] | None | Array | Any) – The PyTree to consider.

Return type:

tuple[tuple[int, ...] | None]

Returns:

A tuple containing the shape of each leaf of the PyTree, or None if the leaf is not a numpy-like array.

static get_leaf_weak_types(tree)

Helper method to get the leaf weak types of a PyTree.

Parameters:

tree (dict[Hashable, TypeVar(PyTree)] | list[TypeVar(PyTree)] | tuple[TypeVar(PyTree)] | None | Array | Any) – The PyTree to consider.

Return type:

tuple[bool, ...]

Returns:

A tuple marking, for each leaf, whether it is a JAX array with a weak type.
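The three leaf helpers above share one pattern: collect the leaves in a fixed order, then read one attribute per leaf. The pure-Python sketch below assumes a hypothetical FakeArray leaf exposing shape, dtype, and weak_type, and reports None (or False for weak types) for leaves such as plain Python floats that carry no such attribute. It mirrors the documented return values, not jaxsim's actual code.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for a JAX array exposing shape, dtype, and weak_type."""

    shape: tuple[int, ...]
    dtype: str
    weak_type: bool = False


def tree_leaves(tree: Any) -> list[Any]:
    """Collect leaves depth-first; dict keys are sorted and None has no leaves."""
    if tree is None:
        return []
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    return [tree]


def get_leaf_shapes(tree: Any) -> tuple:
    # None for leaves without a shape attribute.
    return tuple(getattr(leaf, "shape", None) for leaf in tree_leaves(tree))


def get_leaf_dtypes(tree: Any) -> tuple:
    # None for leaves without a dtype attribute.
    return tuple(getattr(leaf, "dtype", None) for leaf in tree_leaves(tree))


def get_leaf_weak_types(tree: Any) -> tuple[bool, ...]:
    # False for leaves that do not carry a weak_type flag.
    return tuple(bool(getattr(leaf, "weak_type", False)) for leaf in tree_leaves(tree))
```

For example, the tree {"b": [FakeArray((3,), "float32"), 1.0], "a": FakeArray((), "float32", weak_type=True)} is visited as a, then the two leaves of b.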

is_mutable(validate=False)

Check whether the object is mutable.

Parameters:

validate (bool) – Whether to additionally check that the object has validation enabled.

Return type:

bool

Returns:

True if the object is mutable, False otherwise.

mutability()

Get the mutability type of the object.

Return type:

_Mutability

Returns:

The mutability type of the object.

mutable(mutable=True, validate=False)

Return a mutable reference to the object.

Parameters:
  • mutable (bool) – Whether to make the object mutable.

  • validate (bool) – Whether to enable validation on the object.

  • self (Self)

Return type:

Self

Returns:

A mutable reference to the object.
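The mutability methods above (mutability, set_mutability, is_mutable, and mutable) can be pictured as operations on a single state flag. The sketch below assumes three states named FROZEN, MUTABLE, and MUTABLE_NO_VALIDATION, modeled on jax_dataclasses; the document itself only shows _Mutability.MUTABLE, so the other names and the Toy class are hypothetical.

```python
import enum


class Mutability(enum.Enum):
    """Hypothetical mirror of the mutability states of a pytree dataclass."""

    FROZEN = enum.auto()
    MUTABLE = enum.auto()
    MUTABLE_NO_VALIDATION = enum.auto()


class Toy:
    """Hypothetical object storing its mutability in a private attribute."""

    def __init__(self, value: float) -> None:
        object.__setattr__(self, "_mutability", Mutability.FROZEN)
        object.__setattr__(self, "value", value)

    def __setattr__(self, name, value):
        if self._mutability is Mutability.FROZEN:
            raise AttributeError(f"cannot assign {name!r}: object is frozen")
        super().__setattr__(name, value)

    def mutability(self) -> Mutability:
        return self._mutability

    def set_mutability(self, mutability: Mutability) -> None:
        # In-place change that deliberately bypasses the frozen check.
        object.__setattr__(self, "_mutability", mutability)

    def is_mutable(self, validate: bool = False) -> bool:
        if validate:
            return self._mutability is Mutability.MUTABLE
        return self._mutability in (Mutability.MUTABLE, Mutability.MUTABLE_NO_VALIDATION)

    def mutable(self, mutable: bool = True, validate: bool = False) -> "Toy":
        if not mutable:
            target = Mutability.FROZEN
        elif validate:
            target = Mutability.MUTABLE
        else:
            target = Mutability.MUTABLE_NO_VALIDATION
        self.set_mutability(target)
        return self  # the same object, not a copy
```

In this sketch, mutable() returns a reference to the same object rather than a copy, which is why the returned value and the original share their mutability.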

mutable_context(mutability=_Mutability.MUTABLE, restore_after_exception=True)

Context manager to temporarily change the mutability of the object.

Parameters:
  • mutability (_Mutability) – The mutability to set.

  • restore_after_exception (bool) – Whether to restore the original object in case of an exception occurring within the context.

  • self (Self)

Yields:

The object with the new mutability.

Return type:

Iterator[Self]

Note

This context manager is useful for operating in place on a PyTree without making a copy, while optionally keeping the checks on the PyTree structure, shapes, and dtypes active.
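The in-place behavior described above, including the optional rollback on error, can be sketched as follows. This is a hypothetical pure-Python model, not jaxsim's implementation: Toy has a single list field, the Mutability member names are assumed, and "validation" is reduced to checking that the list keeps its length. A failed validation is treated like any other exception, so it also triggers the restore.

```python
import copy
import enum
from contextlib import contextmanager
from typing import Iterator


class Mutability(enum.Enum):
    """Hypothetical mutability states (names assumed)."""

    FROZEN = enum.auto()
    MUTABLE = enum.auto()
    MUTABLE_NO_VALIDATION = enum.auto()


class Toy:
    """Hypothetical object with one list field and a mutability flag."""

    def __init__(self, values: list[float]) -> None:
        self.__dict__["_mutability"] = Mutability.FROZEN
        self.__dict__["values"] = values

    def __setattr__(self, name, value):
        if self._mutability is Mutability.FROZEN:
            raise AttributeError(f"cannot assign {name!r}: object is frozen")
        self.__dict__[name] = value

    @contextmanager
    def mutable_context(
        self, mutability=Mutability.MUTABLE, restore_after_exception=True
    ) -> Iterator["Toy"]:
        original = self._mutability
        original_len = len(self.values)
        backup = copy.deepcopy(self.__dict__) if restore_after_exception else None
        self.__dict__["_mutability"] = mutability
        try:
            yield self  # the same object: edits happen in place
            # Stand-in validation, skipped for MUTABLE_NO_VALIDATION.
            if mutability is Mutability.MUTABLE and len(self.values) != original_len:
                raise ValueError("the shape of 'values' changed")
        except Exception:
            if backup is not None:
                self.__dict__.update(backup)  # roll back to the pre-context state
            raise
        finally:
            self.__dict__["_mutability"] = original
```

Unlike a copy-based approach, the object yielded here is the original one, so successful edits persist after the context exits.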

replace(validate=True, **kwargs)

Return a new object in which the specified fields are replaced with new values.

Parameters:
  • validate (bool) – Whether to validate that the new field values do not alter the structure, shapes, or dtypes of the PyTree.

  • **kwargs – The fields to replace.

  • self (Self)

Return type:

Self

Returns:

A reference to the new object with the specified fields replaced.
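A minimal sketch of the replace-with-validation idea, assuming a hypothetical frozen State class with tuple-valued fields. It builds the new object with the standard library's dataclasses.replace and, as a stand-in for the shape and dtype checks, requires each new value to have the same length and element types as the old one. It illustrates the documented behavior only; it is not the library's implementation.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class State:
    """Hypothetical frozen toy object with tuple-valued fields."""

    position: tuple[float, ...]
    velocity: tuple[float, ...]


def replace(obj: State, validate: bool = True, **kwargs) -> State:
    """Return a new object with some fields replaced; obj itself is untouched."""
    if validate:
        for name, new in kwargs.items():
            old = getattr(obj, name)  # raises AttributeError if the field does not exist
            # Stand-in for the shape/dtype check: same length and element types.
            if tuple(map(type, new)) != tuple(map(type, old)):
                raise ValueError(f"replacing {name!r} would alter the PyTree")
    return dataclasses.replace(obj, **kwargs)
```

With validate=True, replacing a two-element position by a one-element tuple, or float entries by int entries, is rejected; with validate=False both are accepted.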

set_mutability(mutability)

Set the mutability of the object in place.

Parameters:

mutability (_Mutability) – The desired mutability type.

Return type:

None

unflatten_fn()

Return a function to unflatten a 1D vector into the object.

Return type:

Callable[[Array], Self]

Returns:

A function to unflatten a 1D vector into the object.

Parameters:

self (Self)

Notes

Due to JAX internals, the function to unflatten a PyTree must be created from an existing instance of the PyTree, which provides the tree structure and the shapes and dtypes of its leaves.
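Why an existing instance is needed becomes clear in a small model: a bare 1D vector records neither which field each entry belongs to nor where one leaf ends and the next begins. JAX's jax.flatten_util.ravel_pytree follows the same template-based design, returning the flat vector together with an unravel function tied to the example tree. The pure-Python sketch below assumes a hypothetical State class with 1D float-tuple leaves and is not jaxsim's implementation.

```python
from dataclasses import dataclass, fields
from typing import Callable


@dataclass(frozen=True)
class State:
    """Hypothetical toy object whose leaves are 1D tuples of floats."""

    position: tuple[float, ...]
    velocity: tuple[float, ...]


def flatten(obj: State) -> list[float]:
    """Concatenate all leaves, in field order, into one 1D list."""
    return [x for f in fields(obj) for x in getattr(obj, f.name)]


def unflatten_fn(template: State) -> Callable[[list[float]], State]:
    """Build the inverse of flatten from an existing instance."""
    # Field names and leaf sizes are captured from the template now,
    # because a bare 1D vector carries no record of them.
    names = [f.name for f in fields(template)]
    sizes = [len(getattr(template, n)) for n in names]
    cls = type(template)

    def unflatten(vector: list[float]) -> State:
        if len(vector) != sum(sizes):
            raise ValueError("vector length does not match the template")
        values, start = {}, 0
        for name, size in zip(names, sizes):
            values[name] = tuple(vector[start:start + size])
            start += size
        return cls(**values)

    return unflatten
```

The returned function can be reused for any vector of the right length, for example to turn an updated flat vector back into a State with the template's layout.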