# Source code for jaxsim.integrators.fixed_step

from typing import ClassVar, Generic

import jax.numpy as jnp
import jax_dataclasses

import jaxsim.api as js
import jaxsim.typing as jtp

from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType

# Time derivatives of the ODE state are represented with the very same class as
# the state itself: this name is a plain alias of ``js.ode_data.ODEState``.
ODEStateDerivative = js.ode_data.ODEState

# =====================================================
# Explicit Runge-Kutta integrators operating on PyTrees
# =====================================================


@jax_dataclasses.pytree_dataclass
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """Forward (explicit) Euler integrator operating on generic PyTrees.

    Single-stage explicit Runge-Kutta scheme defined by the Butcher tableau::

        0 | 0
        --+--
          | 1
    """

    # Runge-Kutta matrix, shape (1, 1).
    A: ClassVar[jtp.Matrix] = jnp.atleast_2d(0).astype(float)

    # Weights stored as a column, shape (1, 1) after the transpose.
    b: ClassVar[jtp.Matrix] = jnp.atleast_2d(1).astype(float).transpose()

    # Nodes, shape (1,).
    c: ClassVar[jtp.Vector] = jnp.atleast_1d(0).astype(float)

    # Row of ``b.T`` selected as the solution (the only one available here).
    row_index_of_solution: ClassVar[int] = 0

    # Order associated with each row of ``b.T``: first order.
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)


@jax_dataclasses.pytree_dataclass
class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """Heun's second-order explicit Runge-Kutta integrator on generic PyTrees.

    Two-stage scheme defined by the Butcher tableau::

        0 | 0    0
        1 | 1    0
        --+---------
          | 1/2  1/2
    """

    # Runge-Kutta matrix, shape (2, 2).
    A: ClassVar[jtp.Matrix] = jnp.array(
        [
            [0, 0],
            [1, 0],
        ]
    ).astype(float)

    # Weights stored as a column, shape (2, 1) after the transpose.
    b: ClassVar[jtp.Matrix] = (
        jnp.atleast_2d(
            jnp.array([1 / 2, 1 / 2]),
        )
        .astype(float)
        .transpose()
    )

    # Nodes, shape (2,).
    c: ClassVar[jtp.Vector] = jnp.array(
        [0, 1],
    ).astype(float)

    # Row of ``b.T`` selected as the solution (the only one available here).
    row_index_of_solution: ClassVar[int] = 0

    # Order associated with each row of ``b.T``: second order.
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)


@jax_dataclasses.pytree_dataclass
class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """Classic fourth-order explicit Runge-Kutta integrator on generic PyTrees.

    Four-stage scheme defined by the Butcher tableau::

        0   | 0    0    0    0
        1/2 | 1/2  0    0    0
        1/2 | 0    1/2  0    0
        1   | 0    0    1    0
        ----+-------------------
            | 1/6  1/3  1/3  1/6
    """

    # Runge-Kutta matrix, shape (4, 4).
    A: ClassVar[jtp.Matrix] = jnp.array(
        [
            [0, 0, 0, 0],
            [1 / 2, 0, 0, 0],
            [0, 1 / 2, 0, 0],
            [0, 0, 1, 0],
        ]
    ).astype(float)

    # Weights stored as a column, shape (4, 1) after the transpose.
    b: ClassVar[jtp.Matrix] = (
        jnp.atleast_2d(
            jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
        )
        .astype(float)
        .transpose()
    )

    # Nodes, shape (4,).
    c: ClassVar[jtp.Vector] = jnp.array(
        [0, 1 / 2, 1 / 2, 1],
    ).astype(float)

    # Row of ``b.T`` selected as the solution (the only one available here).
    row_index_of_solution: ClassVar[int] = 0

    # Order associated with each row of ``b.T``: fourth order.
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)

# ===============================================================================
# Explicit Runge-Kutta integrators operating on ODEState and integrating on SO(3)
# ===============================================================================
@jax_dataclasses.pytree_dataclass
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
    """Forward Euler specialized on ``ODEState`` and integrating on SO(3).

    Reuses the ``ForwardEuler`` Butcher tableau unchanged. The SO(3)-specific
    handling is provided by ``ExplicitRungeKuttaSO3Mixin``, which is listed
    first so that its methods take precedence in the MRO.
    """

    pass


@jax_dataclasses.pytree_dataclass
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
    """Heun's second-order scheme specialized on ``ODEState`` and integrating on SO(3).

    Reuses the ``Heun2`` Butcher tableau unchanged. The SO(3)-specific
    handling is provided by ``ExplicitRungeKuttaSO3Mixin``, which is listed
    first so that its methods take precedence in the MRO.
    """

    pass


@jax_dataclasses.pytree_dataclass
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
    """Classic RK4 specialized on ``ODEState`` and integrating on SO(3).

    Reuses the ``RungeKutta4`` Butcher tableau unchanged. The SO(3)-specific
    handling is provided by ``ExplicitRungeKuttaSO3Mixin``, which is listed
    first so that its methods take precedence in the MRO.
    """

    pass