Utils
- class jaxsim.utils.JaxsimDataclass
Class extending jax_dataclasses.pytree_dataclass instances with utilities.
- static check_compatibility(*trees)
Check whether the PyTrees are compatible in structure, shape, and dtype.
- Parameters:
*trees (Sequence[Any]) – The PyTrees to compare.
- Raises:
ValueError – If the PyTrees have incompatible structures, shapes, or dtypes.
- Return type:
None
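The check described above can be sketched with `jax.tree_util`. This is a minimal illustration, not JaxSim's implementation: the names `check_compatibility_sketch`, `is_compatible`, and the example trees are hypothetical.

```python
import jax
import jax.numpy as jnp


def _signature(tree):
    """Return the tree structure plus the shape and dtype of every leaf."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Leaves that are not array-like contribute None for shape and dtype.
    shapes = tuple(getattr(leaf, "shape", None) for leaf in leaves)
    dtypes = tuple(getattr(leaf, "dtype", None) for leaf in leaves)
    return treedef, shapes, dtypes


def check_compatibility_sketch(*trees):
    """Raise ValueError if the trees differ in structure, shapes, or dtypes."""
    if not trees:
        return
    ref_def, ref_shapes, ref_dtypes = _signature(trees[0])
    for tree in trees[1:]:
        treedef, shapes, dtypes = _signature(tree)
        if treedef != ref_def:
            raise ValueError("PyTrees have incompatible structures")
        if shapes != ref_shapes:
            raise ValueError("PyTrees have incompatible leaf shapes")
        if dtypes != ref_dtypes:
            raise ValueError("PyTrees have incompatible leaf dtypes")


def is_compatible(*trees) -> bool:
    """Convenience wrapper (hypothetical) returning a bool instead of raising."""
    try:
        check_compatibility_sketch(*trees)
    except ValueError:
        return False
    return True


a = {"q": jnp.zeros(3), "v": jnp.zeros((2, 2))}
b = {"q": jnp.ones(3), "v": jnp.ones((2, 2))}                     # same layout as a
c = {"q": jnp.zeros(4), "v": jnp.zeros((2, 2))}                   # different shape
d = {"q": jnp.zeros(3, dtype=jnp.int32), "v": jnp.zeros((2, 2))}  # different dtype
e = {"q": jnp.zeros(3)}                                           # different structure
```

Here only `a` and `b` pass the check; `c`, `d`, and `e` each differ from `a` in one of the three properties.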
- copy()
Return a copy of the object.
- Return type:
Self
- Returns:
A copy of the object.
- Parameters:
self (Self)
- editable(validate=True)
Context manager to operate on a mutable copy of the object.
- Parameters:
validate (bool) – Whether to validate the output PyTree upon exiting the context.
self (Self)
- Yields:
A mutable copy of the object.
- Return type:
Iterator[Self]
Note
This context manager is useful for operating on a read/write copy of a PyTree while ensuring that the resulting object does not trigger JIT recompilations.
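The copy-then-validate pattern can be illustrated with the standard library alone. The sketch below is a conceptual stand-in, not JaxSim's code: the `State` class and its type/length check are hypothetical, and a real PyTree validation would compare structure, shapes, and dtypes instead.

```python
import contextlib
import copy
import dataclasses
from typing import Iterator


@dataclasses.dataclass
class State:
    """Hypothetical stand-in for a JaxsimDataclass subclass."""

    position: tuple
    velocity: tuple

    @contextlib.contextmanager
    def editable(self, validate: bool = True) -> Iterator["State"]:
        # Work on a copy so that the original object is never modified.
        obj = copy.deepcopy(self)
        yield obj
        # On exit, optionally check that the edits kept each field's layout.
        if validate:
            for field in dataclasses.fields(self):
                old, new = getattr(self, field.name), getattr(obj, field.name)
                if type(old) is not type(new) or len(old) != len(new):
                    raise ValueError(f"Field '{field.name}' changed type or length")


state = State(position=(0.0, 0.0), velocity=(1.0, 1.0))

# A valid edit: same type and length, so validation passes.
with state.editable() as edited:
    edited.position = (0.5, 0.5)

# An invalid edit: the length of 'position' changes, so exiting the context raises.
try:
    with state.editable() as bad:
        bad.position = (0.0, 0.0, 0.0)
    rejected = False
except ValueError:
    rejected = True
```

The original `state` is left unchanged in both cases; only the yielded copy carries the edits.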
- flatten()
Flatten the object into a 1D vector.
- Return type:
Array
- Returns:
A 1D vector containing the flattened object.
- classmethod flatten_fn()
Return a function to flatten the object into a 1D vector.
- Return type:
Callable[[Self], Array]
- Returns:
A function to flatten the object into a 1D vector.
- static get_leaf_dtypes(tree)
Helper method to get the leaf dtypes of a PyTree.
- Parameters:
tree (dict[Hashable, TypeVar(PyTree)] | list[TypeVar(PyTree)] | tuple[TypeVar(PyTree)] | None | Array | Any) – The PyTree to consider.
- Return type:
tuple
- Returns:
A tuple containing the dtype of each leaf of the PyTree, or None if the leaf is not a numpy-like array.
- static get_leaf_shapes(tree)
Helper method to get the leaf shapes of a PyTree.
- Parameters:
tree (dict[Hashable, TypeVar(PyTree)] | list[TypeVar(PyTree)] | tuple[TypeVar(PyTree)] | None | Array | Any) – The PyTree to consider.
- Return type:
tuple[tuple[int, ...] | None]
- Returns:
A tuple containing the shape of each leaf of the PyTree, or None if the leaf is not a numpy-like array.
- static get_leaf_weak_types(tree)
Helper method to get the leaf weak types of a PyTree.
- Parameters:
tree (dict[Hashable, TypeVar(PyTree)] | list[TypeVar(PyTree)] | tuple[TypeVar(PyTree)] | None | Array | Any) – The PyTree to consider.
- Return type:
tuple[bool, ...]
- Returns:
A tuple marking, for each leaf, whether it is a JAX array with a weak type.
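The three leaf helpers above can be approximated with `jax.tree_util.tree_leaves`. The following is a rough sketch under the assumption that non-array leaves report None for shape and dtype and False for weak type; the example tree is hypothetical and the code is not JaxSim's implementation.

```python
import jax
import jax.numpy as jnp

tree = {
    "mass": jnp.asarray(1.0),                     # built from a Python float: weakly typed
    "name": "link",                               # not an array
    "position": jnp.zeros(3, dtype=jnp.float32),  # explicit dtype: strongly typed
}

# Dictionary leaves are returned in sorted key order: mass, name, position.
leaves = jax.tree_util.tree_leaves(tree)

# Shape and dtype of each leaf, or None when the leaf is not array-like.
shapes = tuple(getattr(leaf, "shape", None) for leaf in leaves)
dtypes = tuple(getattr(leaf, "dtype", None) for leaf in leaves)

# Weak-type flag of each leaf; non-array leaves are reported as False here.
weak_types = tuple(bool(getattr(leaf, "weak_type", False)) for leaf in leaves)
```

Weak types matter for compatibility checks because a weakly typed scalar and a strongly typed array of the same dtype are treated as different inputs by JAX's tracing machinery.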
- is_mutable(validate=False)
Check whether the object is mutable.
- Parameters:
validate (bool) – Whether to additionally check that the object has validation enabled.
- Return type:
bool
- Returns:
True if the object is mutable, False otherwise.
- mutability()
Get the mutability type of the object.
- Return type:
_Mutability
- Returns:
The mutability type of the object.
- mutable(mutable=True, validate=False)
Return a mutable reference to the object.
- Parameters:
mutable (bool) – Whether to make the object mutable.
validate (bool) – Whether to enable validation on the object.
self (Self)
- Return type:
Self
- Returns:
A mutable reference to the object.
- mutable_context(mutability=_Mutability.MUTABLE, restore_after_exception=True)
Context manager to temporarily change the mutability of the object.
- Parameters:
mutability (_Mutability) – The mutability to set.
restore_after_exception (bool) – Whether to restore the original object in case of an exception occurring within the context.
self (Self)
- Yields:
The object with the new mutability.
- Return type:
Iterator[Self]
Note
This context manager is useful for operating in place on a PyTree without making a copy, while optionally keeping the checks on the PyTree structure, shapes, and dtypes active.
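The in-place behaviour, including the restore-on-exception option, can be sketched with the standard library. The `Guarded` class below is hypothetical and omits the `mutability` argument and any PyTree validation; it only illustrates the temporary unlock and the rollback.

```python
import contextlib
import copy


class Guarded:
    """Hypothetical object that rejects attribute writes unless unlocked."""

    def __init__(self, value: float) -> None:
        # Bypass the guarded __setattr__ while building the object.
        object.__setattr__(self, "_mutable", False)
        object.__setattr__(self, "value", value)

    def __setattr__(self, name, value):
        if not self._mutable:
            raise AttributeError(f"Cannot set '{name}': object is immutable")
        object.__setattr__(self, name, value)

    @contextlib.contextmanager
    def mutable_context(self, restore_after_exception: bool = True):
        snapshot = copy.copy(self.__dict__)  # shallow snapshot for a possible rollback
        previous = self._mutable
        object.__setattr__(self, "_mutable", True)
        try:
            yield self  # the same object, not a copy
        except Exception:
            if restore_after_exception:
                self.__dict__.clear()
                self.__dict__.update(snapshot)
            raise
        finally:
            # Always return to the mutability the object had before the context.
            object.__setattr__(self, "_mutable", previous)


obj = Guarded(1.0)

# Edits inside the context are applied to the object itself.
with obj.mutable_context() as same:
    same.value = 2.0

# An exception inside the context rolls the object back to its previous state.
try:
    with obj.mutable_context():
        obj.value = 99.0
        raise RuntimeError("failure inside the context")
except RuntimeError:
    pass

# Outside the context the object is immutable again.
try:
    obj.value = 3.0
    write_blocked = False
except AttributeError:
    write_blocked = True
```

Unlike a copy-based approach, the yielded object is the original one, so no extra allocation is needed and callers holding a reference see the edits.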
- replace(validate=True, **kwargs)
Return a new object in which the specified fields are replaced with new values.
- Parameters:
validate (bool) – Whether to validate that the new fields do not alter the PyTree.
**kwargs – The fields to replace.
self (Self)
- Return type:
Self
- Returns:
A reference of the object with the specified fields replaced.
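A validated replace can be illustrated with `dataclasses.replace` from the standard library. The `Params` class and its type check are hypothetical stand-ins; a PyTree-aware version would compare structure, shapes, and dtypes rather than Python types.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Params:
    """Hypothetical frozen dataclass standing in for a JaxsimDataclass."""

    gain: float
    limits: tuple

    def replace(self, validate: bool = True, **kwargs) -> "Params":
        # Build a new object with the requested fields swapped.
        new = dataclasses.replace(self, **kwargs)
        if validate:
            # Reject replacements that change the type of any field.
            for field in dataclasses.fields(self):
                old, cur = getattr(self, field.name), getattr(new, field.name)
                if type(old) is not type(cur):
                    raise ValueError(f"Field '{field.name}' changed type")
        return new


p = Params(gain=1.0, limits=(0.0, 1.0))
q = p.replace(gain=2.0)                       # valid: float replaced by float
r = p.replace(validate=False, gain="high")    # allowed only because validation is off

try:
    p.replace(gain="high")                    # float replaced by str: rejected
    type_change_rejected = False
except ValueError:
    type_change_rejected = True
```

The original `p` keeps its values in every case, since each call builds a separate object.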
- set_mutability(mutability)
Set the mutability of the object in-place.
- Parameters:
mutability (_Mutability) – The desired mutability type.
- Return type:
None
- unflatten_fn()
Return a function to unflatten a 1D vector into the object.
- Return type:
Callable[[Array], Self]
- Returns:
A function to unflatten a 1D vector into the object.
- Parameters:
self (Self)
Notes
Due to JAX internals, the function to unflatten a PyTree needs to be created from an existing instance of the PyTree.
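The dependence on an existing instance can be seen with `jax.flatten_util.ravel_pytree`, which returns both the flat vector and an unravel function built from the input tree. Whether JaxSim relies on this exact utility is an assumption; the example tree is hypothetical.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

tree = {"position": jnp.array([1.0, 2.0, 3.0]), "mass": jnp.array(4.0)}

# Flatten into a 1D vector. The unravel function is created from this specific
# tree, because it has to remember the structure and the shape of each leaf.
flat, unravel = ravel_pytree(tree)

# Leaves are concatenated in sorted key order: "mass" first, then "position".
flat_values = flat.tolist()

# Any 1D vector of the same length can be mapped back to the same structure.
restored = unravel(flat * 2.0)
```

A vector alone carries no information about leaf boundaries or shapes, which is why the unflattening function cannot be produced from the class without an instance.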