import abc
import contextlib
import dataclasses
import functools
from collections.abc import Callable, Iterator, Sequence
from typing import Any, ClassVar
import jax.flatten_util
import jax_dataclasses
import jaxsim.typing as jtp
from . import Mutability
try:
from typing import Self
except ImportError:
from typing_extensions import Self
@jax_dataclasses.pytree_dataclass
class JaxsimDataclass(abc.ABC):
    """Class extending `jax_dataclasses.pytree_dataclass` instances with utilities."""

    # This attribute is set by jax_dataclasses. The class-level default marks
    # instances as frozen until their mutability is explicitly changed.
    __mutability__: ClassVar[Mutability] = Mutability.FROZEN

    @contextlib.contextmanager
    def editable(self: Self, validate: bool = True) -> Iterator[Self]:
        """
        Context manager to operate on a mutable copy of the object.

        Args:
            validate: Whether to validate the output PyTree upon exiting the context.

        Yields:
            A mutable copy of the object.

        Note:
            This context manager is useful to operate on an r/w copy of a PyTree making
            sure that the output object does not trigger JIT recompilations.
        """

        mutability = (
            Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
        )

        # The copy (not `self`) is made mutable, so `self` is never modified.
        with self.copy().mutable_context(mutability=mutability) as obj:
            yield obj

    @contextlib.contextmanager
    def mutable_context(
        self: Self,
        mutability: Mutability = Mutability.MUTABLE,
        restore_after_exception: bool = True,
    ) -> Iterator[Self]:
        """
        Context manager to temporarily change the mutability of the object.

        Args:
            mutability: The mutability to set.
            restore_after_exception:
                Whether to restore the original object in case of an exception
                occurring within the context.

        Yields:
            The object with the new mutability.

        Raises:
            ValueError:
                If `mutability` is not `Mutability.MUTABLE_NO_VALIDATION` and the
                PyTree structure, or the shapes, dtypes, or weak types of its
                leaves, differ from the original ones when the context exits.

        Note:
            This context manager is useful to operate in place on a PyTree without
            the need to make a copy while optionally keeping active the checks on
            the PyTree structure, shapes, and dtypes.
        """

        # Snapshot used to roll back the fields of `self` if the context fails.
        # It is only taken when a rollback has been requested.
        self_copy = self.copy() if restore_after_exception else None

        original_mutability = self.mutability()

        # Reference values, paired with the function that recomputes them and
        # the error message to use if they changed. The order of the entries
        # is the order in which the checks are performed.
        checks = (
            (
                "Pytree structure has changed from {} to {}",
                jax.tree_util.tree_structure(self),
                jax.tree_util.tree_structure,
            ),
            (
                "Leaves shapes have changed from {} to {}",
                JaxsimDataclass.get_leaf_shapes(tree=self),
                JaxsimDataclass.get_leaf_shapes,
            ),
            (
                "Leaves dtypes have changed from {} to {}",
                JaxsimDataclass.get_leaf_dtypes(tree=self),
                JaxsimDataclass.get_leaf_dtypes,
            ),
            (
                "Leaves weak types have changed from {} to {}",
                JaxsimDataclass.get_leaf_weak_types(tree=self),
                JaxsimDataclass.get_leaf_weak_types,
            ),
        )

        def restore_self() -> None:
            # Validation is disabled so that the original fields can be written
            # back regardless of what the failed context left in `self`.
            self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION)

            for f in dataclasses.fields(self_copy):
                setattr(self, f.name, getattr(self_copy, f.name))

        try:
            self.set_mutability(mutability=mutability)
            yield self

            # The checks run inside the `try` so that a failed validation also
            # triggers the rollback below.
            if mutability is not Mutability.MUTABLE_NO_VALIDATION:
                for msg, original, probe in checks:
                    current = probe(self)
                    if original != current:
                        raise ValueError(msg.format(original, current))

        except Exception:
            if restore_after_exception:
                restore_self()

            # A bare `raise` propagates the active exception with its
            # traceback untouched.
            raise

        finally:
            # Runs on both the success and the failure path.
            self.set_mutability(original_mutability)

    @staticmethod
    def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None, ...]:
        """
        Helper method to get the leaf shapes of a PyTree.

        Args:
            tree: The PyTree to consider.

        Returns:
            A tuple containing the leaf shapes of the PyTree, with `None` in place
            of any leaf that has no `shape` attribute (i.e. is not a numpy-like
            array).
        """

        return tuple(
            getattr(leaf, "shape", None) for leaf in jax.tree_util.tree_leaves(tree)
        )

    @staticmethod
    def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:
        """
        Helper method to get the leaf dtypes of a PyTree.

        Args:
            tree: The PyTree to consider.

        Returns:
            A tuple containing the leaf dtypes of the PyTree, with `None` in place
            of any leaf that has no `dtype` attribute (i.e. is not a numpy-like
            array).
        """

        return tuple(
            getattr(leaf, "dtype", None) for leaf in jax.tree_util.tree_leaves(tree)
        )

    @staticmethod
    def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool | None, ...]:
        """
        Helper method to get the leaf weak types of a PyTree.

        Args:
            tree: The PyTree to consider.

        Returns:
            A tuple containing the `weak_type` attribute of each leaf, with `None`
            in place of any leaf that does not have such an attribute.
        """

        return tuple(
            getattr(leaf, "weak_type", None)
            for leaf in jax.tree_util.tree_leaves(tree)
        )

    @staticmethod
    def check_compatibility(*trees: Any) -> None:
        """
        Check whether the PyTrees are compatible in structure, shape, and dtype.

        All the PyTrees are compared against the first one. With fewer than two
        PyTrees there is nothing to compare and the check trivially succeeds.

        Args:
            *trees: The PyTrees to compare.

        Raises:
            ValueError: If the PyTrees have incompatible structures, shapes, or dtypes.
        """

        if len(trees) < 2:
            return

        reference, *others = trees

        target_structure = jax.tree_util.tree_structure(reference)

        if not all(
            jax.tree_util.tree_structure(tree) == target_structure for tree in others
        ):
            raise ValueError("Pytrees have incompatible structures.")

        target_shapes = JaxsimDataclass.get_leaf_shapes(reference)

        if not all(
            JaxsimDataclass.get_leaf_shapes(tree) == target_shapes for tree in others
        ):
            raise ValueError("Pytrees have incompatible shapes.")

        target_dtypes = JaxsimDataclass.get_leaf_dtypes(reference)

        if not all(
            JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes for tree in others
        ):
            raise ValueError("Pytrees have incompatible dtypes.")

    def is_mutable(self, validate: bool = False) -> bool:
        """
        Check whether the object is mutable.

        Args:
            validate: Additionally checks if the object also has validation enabled.

        Returns:
            True if the object is mutable (and, when `validate` is True, also has
            validation enabled), False otherwise.
        """

        if validate:
            return self.__mutability__ is Mutability.MUTABLE

        # Without the extra requirement on validation, both mutable modes count.
        return self.__mutability__ in (
            Mutability.MUTABLE,
            Mutability.MUTABLE_NO_VALIDATION,
        )

    def mutability(self) -> Mutability:
        """
        Get the mutability type of the object.

        Returns:
            The mutability type of the object.
        """

        return self.__mutability__

    def set_mutability(self, mutability: Mutability) -> None:
        """
        Set the mutability of the object in-place.

        Args:
            mutability: The desired mutability type.
        """

        # NOTE(review): this relies on a private helper of jax_dataclasses, which
        # may change between releases. The fresh `visited` set is presumably used
        # by the helper to track the objects it has already processed -- confirm
        # against the jax_dataclasses sources.
        jax_dataclasses._copy_and_mutate._mark_mutable(
            self, mutable=mutability, visited=set()
        )

    def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:
        """
        Return a mutable reference of the object.

        The mutability is changed in-place and the very same object is returned.

        Args:
            mutable: Whether to make the object mutable.
            validate: Whether to enable validation on the object.

        Returns:
            A mutable reference of the object.
        """

        if mutable:
            mutability = (
                Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
            )
        else:
            mutability = Mutability.FROZEN

        self.set_mutability(mutability=mutability)
        return self

    def copy(self: Self) -> Self:
        """
        Return a copy of the object.

        Returns:
            A copy of the object.
        """

        # Make a copy calling tree_map. The identity function hands every leaf
        # over unchanged to the new object.
        obj = jax.tree.map(lambda leaf: leaf, self)

        # Make sure that the copied object and all the copied leaves have the same
        # mutability of the original object.
        obj.set_mutability(mutability=self.mutability())

        return obj

    def replace(self: Self, validate: bool = True, **kwargs) -> Self:
        """
        Return a new object with the specified fields replaced with new values.

        The original object is left untouched.

        Args:
            validate: Whether to validate that the new fields do not alter the PyTree.
            **kwargs: The fields to replace.

        Returns:
            A new object with the specified fields replaced, having the same
            mutability of the original object.

        Raises:
            ValueError:
                If `validate` is True and the new object is not compatible with
                the original one in structure, shapes, or dtypes.
        """

        # Use the dataclasses replace method.
        obj = dataclasses.replace(self, **kwargs)

        if validate:
            JaxsimDataclass.check_compatibility(self, obj)

        # Make sure that all the new leaves have the same mutability of the object.
        obj.set_mutability(mutability=self.mutability())

        return obj

    def flatten(self) -> jtp.Vector:
        """
        Flatten the object into a 1D vector.

        Returns:
            A 1D vector containing the flattened object.
        """

        return self.flatten_fn()(self)

    @classmethod
    def flatten_fn(cls: type[Self]) -> Callable[[Self], jtp.Vector]:
        """
        Return a function to flatten the object into a 1D vector.

        Returns:
            A function to flatten the object into a 1D vector.
        """

        # `ravel_pytree` returns a (flat vector, unflatten function) pair.
        # Only the flat vector is of interest here.
        return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0]

    def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]:
        """
        Return a function to unflatten a 1D vector into the object.

        Returns:
            A function to unflatten a 1D vector into the object.

        Notes:
            Due to JAX internals, the function to unflatten a PyTree needs to be
            created from an existing instance of the PyTree.
        """

        # `ravel_pytree` returns a (flat vector, unflatten function) pair.
        # Only the unflatten function is of interest here.
        return jax.flatten_util.ravel_pytree(self)[1]