Source code for jaxsim.integrators.variable_step

import functools
from typing import Any, ClassVar, Generic

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self

import jax
import jax.flatten_util
import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static

from jaxsim import typing as jtp
from jaxsim.utils import Mutability

from .common import (
    ExplicitRungeKutta,
    ExplicitRungeKuttaSO3Mixin,
    NextState,
    PyTreeType,
    State,
    StateDerivative,
    SystemDynamics,
    Time,
    TimeStep,
)

# For robot dynamics, the following default tolerances are already pretty accurate.
# Users can either decrease them and pay the price of smaller Δt, or increase
# them and pay the price of less accurate dynamics.
RTOL_DEFAULT = 0.000_100  # 0.01%
ATOL_DEFAULT = 0.000_010  # 10μ

# Default parameters of Embedded RK schemes.
SAFETY_DEFAULT = 0.9
BETA_MIN_DEFAULT = 1.0 / 10
BETA_MAX_DEFAULT = 2.5
MAX_STEP_REJECTIONS_DEFAULT = 5


# =================
# Utility functions
# =================


@functools.partial(jax.jit, static_argnames=["f"])
def estimate_step_size(
    x0: jtp.PyTree,
    t0: Time,
    f: SystemDynamics,
    order: jtp.IntLike,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
) -> tuple[jtp.Float, jtp.PyTree]:
    r"""
    Guess a starting step size for a variable-step integrator.

    The guess is built from two candidates: a first one obtained by comparing
    the magnitudes of the scaled state and of its scaled derivative, and a second
    one obtained from the variation of the derivative after an explicit Euler
    step of the first candidate.

    Args:
        x0: The initial state.
        t0: The initial time.
        f: The state derivative function :math:`f(x, t)`.
        order:
            The order :math:`p` of an integrator with truncation error
            :math:`\mathcal{O}(\Delta t^{p+1})`.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.

    Returns:
        A tuple containing the computed initial step size and the state
        derivative :math:`\dot{x} = f(x_0, t_0)`.

    Note:
        Interested readers could find implementation details in:

        Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
        E. Hairer, S. P. Norsett G. Wanner.
    """

    def max_abs(tree: jtp.PyTree) -> jax.Array:
        # Infinity norm over all the leaves of the pytree, flattened to 1D.
        flat, _ = jax.flatten_util.ravel_pytree(pytree=tree)
        return jnp.linalg.norm(flat, ord=jnp.inf)

    def tolerance_scaled(leaf: jax.Array) -> jax.Array:
        # |leaf| divided by the mixed tolerance built from the leaf itself.
        magnitude = jnp.abs(leaf)
        return magnitude / (atol + magnitude * rtol)

    def scaled_difference(a: jax.Array, b: jax.Array) -> jax.Array:
        # |a - b| divided by the mixed tolerance built from the largest of the two.
        scale = atol + jnp.maximum(jnp.abs(a), jnp.abs(b)) * rtol
        return jnp.abs((a - b) / scale)

    # State derivative at the initial state.
    dx0 = f(x0, t0)[0]

    # Largest scaled component of the initial state and of its derivative.
    d0 = max_abs(jax.tree.map(tolerance_scaled, x0))
    d1 = max_abs(jax.tree.map(tolerance_scaled, dx0))

    # First guess of the initial step size.
    h0 = jnp.where(jnp.minimum(d0, d1) <= 1e-5, 1e-6, 0.01 * d0 / d1)

    # Explicit Euler step of size h0, and the derivative at the reached state.
    x1 = jax.tree.map(lambda x, dx: x + h0 * dx, x0, dx0)
    dx1 = f(x1, t0 + h0)[0]

    # Largest scaled variation of the derivative over the Euler step.
    d2 = max_abs(jax.tree.map(scaled_difference, dx0, dx1)) / h0

    # Second guess of the initial step size.
    d_max = jnp.maximum(d1, d2)
    h1 = jnp.where(
        d_max <= 1e-15,
        jnp.maximum(1e-6, h0 * 1e-3),
        (0.01 / d_max) ** (1.0 / (order + 1.0)),
    )

    # Final proposal. The derivative at the initial state is returned as well,
    # since the caller likely needs it again.
    return jnp.array(jnp.minimum(100.0 * h0, h1), dtype=float), dx0
@jax.jit
def compute_pytree_scale(
    x1: jtp.PyTree,
    x2: jtp.PyTree | None = None,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
) -> jtp.PyTree:
    """
    Build the leaf-wise tolerance scale ``atol + max(|x1|, |x2|) * rtol``.

    Args:
        x1: The first state (often the initial state).
        x2: The optional second state (often the final state).
            When omitted, a pytree of zeros shaped like ``x1`` is used.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.

    Returns:
        A pytree with the same structure of the state containing the scaling
        factors.
    """

    # Fall back to an all-zero second pytree when none is given.
    if x2 is None:
        x2 = jax.tree.map(jnp.zeros_like, x1)

    def leaf_scale(first: jax.Array, second: jax.Array) -> jax.Array:
        # Mixed absolute/relative tolerance using the largest magnitude.
        return atol + jnp.maximum(jnp.abs(first), jnp.abs(second)) * rtol

    return jax.tree.map(leaf_scale, x1, x2)
# `norm_ord` selects the norm through Python-level comparisons inside
# `jnp.linalg.norm`, so it must be a compile-time constant rather than a tracer.
@functools.partial(jax.jit, static_argnames=["norm_ord"])
def local_error_estimation(
    xf: jtp.PyTree,
    xf_estimate: jtp.PyTree | None = None,
    x0: jtp.PyTree | None = None,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
    norm_ord: jtp.IntLike | jtp.FloatLike = jnp.inf,
) -> jtp.Float:
    """
    Estimate the local integration error, often used in Embedded RK schemes.

    The error is the norm of the leaf-wise difference between ``xf`` and
    ``xf_estimate``, divided by the tolerance scale computed from ``xf`` and ``x0``.

    Args:
        xf: The final state, often computed with the most accurate integrator.
        xf_estimate:
            The estimated final state, often computed with the less accurate
            integrator. If missing, it is initialized to zero.
        x0:
            The initial state to compute the scaling factors. If missing, it is
            initialized to zero.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.
        norm_ord:
            The norm to use to compute the error. Default is the infinity norm.
            It is treated as a static (hashable) argument of the jitted function.

    Returns:
        The estimated local integration error.
    """

    # Helper to flatten a pytree to a 1D vector.
    def flatten(pytree) -> jax.Array:
        return jax.flatten_util.ravel_pytree(pytree=pytree)[0]

    # Compute the scale considering the initial and final states.
    scale = compute_pytree_scale(x1=xf, x2=x0, rtol=rtol, atol=atol)

    # Consider a zero estimated final state, if not given.
    if xf_estimate is None:
        xf_estimate = jax.tree.map(jnp.zeros_like, xf)

    # Leaf-wise scaled difference between the solution and its estimate.
    def estimate_error(
        leaf: jax.Array, leaf_estimate: jax.Array, leaf_scale: jax.Array
    ) -> jax.Array:
        return jnp.abs(leaf - leaf_estimate) / leaf_scale

    error_estimate = jax.tree.map(estimate_error, xf, xf_estimate, scale)

    # Reduce the error pytree to a scalar with the selected norm
    # (by default, the largest element).
    return jnp.linalg.norm(flatten(error_estimate), ord=norm_ord)
# ================================
# Embedded Runge-Kutta integrators
# ================================
@jax_dataclasses.pytree_dataclass
class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Explicit Runge-Kutta integrator with an embedded estimate and adaptive Δt.

    A call integrates from ``t0`` to ``t0 + dt`` with one or more internal
    sub-steps. Each sub-step computes two solutions from different rows of
    ``b.T``; their scaled difference is the local error that drives the
    acceptance of the sub-step and the size of the next one.
    """

    # Define the row of the integration output corresponding to the solution estimate.
    # This is the row of b.T that produces the state used e.g. by embedded methods to
    # implement the adaptive timestep logic.
    row_index_of_solution_estimate: ClassVar[int | None] = None

    # Bounds of the adaptive Δt.
    dt_max: Static[jtp.FloatLike] = jnp.inf
    dt_min: Static[jtp.FloatLike] = -jnp.inf

    # Tolerances used to scale the two states corresponding to the high-order solution
    # and the low-order estimate during the computation of the local integration error.
    rtol: Static[jtp.FloatLike] = RTOL_DEFAULT
    atol: Static[jtp.FloatLike] = ATOL_DEFAULT

    # Parameters of the adaptive timestep logic.
    # Refer to Eq. (4.13) pag. 168 of Hairer93.
    safety: Static[jtp.FloatLike] = SAFETY_DEFAULT
    beta_max: Static[jtp.FloatLike] = BETA_MAX_DEFAULT
    beta_min: Static[jtp.FloatLike] = BETA_MIN_DEFAULT

    # Maximum number of rejected steps when the Δt needs to be reduced.
    max_step_rejections: Static[jtp.IntLike] = MAX_STEP_REJECTIONS_DEFAULT

    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep | None = None,
        *,
        include_dynamics_aux_dict: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Initialize the dictionary of integrator parameters.

        Args:
            x0: The initial state.
            t0: The initial time.
            dt: Ignored. A fixed placeholder step is forwarded to the parent class.
            include_dynamics_aux_dict: Forwarded unchanged to the parent ``init``.
            **kwargs: Additional arguments forwarded to the parent ``init``.

        Returns:
            The dictionary returned by the parent ``init``.
        """

        # In these type of integrators, it's not relevant picking a meaningful dt.
        # We just need to execute __call__ once to initialize the dictionary of params.
        return super().init(
            x0=x0,
            t0=t0,
            dt=0.001,
            include_dynamics_aux_dict=include_dynamics_aux_dict,
            **kwargs,
        )

    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Integrate the dynamics from ``t0`` to ``t0 + dt`` with adaptive sub-steps.

        Args:
            x0: The state at the initial time.
            t0: The initial time.
            dt: The length of the integration horizon.
            **kwargs: Additional arguments forwarded to the system dynamics.

        Returns:
            The state reached by the last accepted sub-step. The updated
            parameters (including the next ``dt0``) are stored in ``self.params``.
        """

        # This method is called differently in three stages:
        #
        # 1. During initialization, to allocate a dummy params dictionary.
        # 2. During the first step, to compute the initial valid params dictionary.
        # 3. After the first step, to compute the next state and the next valid params.
        #
        # Stage 1 produces a zero-filled dummy dictionary.
        # Stage 2 receives a dummy dictionary and produces valid parameters that can be
        # fed to later steps.
        # Stage 3 corresponds to any consecutive step after the first one. It can re-use
        # data (like for FSAL) from previous steps.
        #
        integrator_init = self.params.get(self.InitializingKey, jnp.array(False))
        integrator_first_step = self.params.get(self.AfterInitKey, jnp.array(False))

        # Close f over optional kwargs.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

        # Define the final time.
        tf = t0 + dt

        # Initialize solution orders.
        # The step-size controller uses the lowest order of the embedded pair.
        p = self.order_of_solution
        p_hat = self.order_of_solution_estimate
        q = jnp.minimum(p, p_hat)

        # In Stage 1 and 2, estimate from scratch dt0 and dxdt0.
        # In Stage 3, dt0 is taken from the previous step. If the integrator supports
        # FSAL, dxdt0 is taken from the previous step. Otherwise, it is computed by
        # evaluating the dynamics.
        self.params["dt0"], self.params["dxdt0"] = jax.lax.cond(
            pred=jnp.logical_or("dt0" not in self.params, integrator_first_step),
            true_fun=lambda params: estimate_step_size(
                x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
            ),
            false_fun=lambda params: (
                params.get("dt0", jnp.array(0).astype(float)),
                # Evaluate the dynamics only when no stored derivative exists.
                (
                    self.params["dxdt0"]
                    if "dxdt0" in self.params
                    else f(x0, t0)[0]
                ),
            ),
            operand=self.params,
        )

        # If the integrator does not support FSAL, it is useless to store dxdt0.
        if not self.has_fsal:
            _ = self.params.pop("dxdt0")

        # Clip the estimated initial step size to the given bounds, if necessary.
        # Both bounds are enforced, as done for the Δt proposed by later sub-steps.
        self.params["dt0"] = jnp.clip(self.params["dt0"], self.dt_min, self.dt_max)

        # =========================================================
        # While loop to reach tf from t0 using an adaptive timestep
        # =========================================================

        # Initialize the carry of the while loop:
        # (state, time, params, number of discarded steps, break flag).
        Carry = tuple[Any, ...]
        carry0: Carry = (
            x0,
            jnp.array(t0).astype(float),
            self.params,
            jnp.array(0, dtype=int),
            jnp.array(False).astype(bool),
        )

        def while_loop_cond(carry: Carry) -> jtp.Bool:
            _, _, _, _, break_loop = carry
            return jnp.logical_not(break_loop)

        # Each loop is an integration step with variable Δt.
        # Depending on the integration error, the step could be discarded and the
        # while body ran again from the same (x0, t0) but with a smaller Δt.
        # We run these loops until the final time tf is reached.
        def while_loop_body(carry: Carry) -> Carry:
            # Unpack the carry.
            x0, t0, params, discarded_steps, _ = carry

            # Take care of the final adaptive step.
            # We want the final Δt to let us reach tf exactly.
            # Then we can exit the while loop.
            Δt0 = params["dt0"]
            Δt0 = jnp.where(t0 + Δt0 < tf, Δt0, tf - t0)
            break_loop = jnp.where(t0 + Δt0 < tf, False, True)

            # Run the underlying explicit RK integrator.
            # The output z contains multiple solutions (depending on the rows of b.T).
            with self.editable(validate=True) as integrator:
                integrator.params = params
                z = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
                params_next = integrator.params

            # Extract the high-order solution xf and the low-order estimate x̂f.
            xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
            x̂f = jax.tree.map(lambda l: l[self.row_index_of_solution_estimate], z)

            # Calculate the local integration error.
            local_error = local_error_estimation(
                x0=x0, xf=xf, xf_estimate=x̂f, rtol=self.rtol, atol=self.atol
            )

            # Shrink the Δt every time by the safety factor (even when accepted).
            # The β parameters define the bounds of the timestep update factor.
            safety = jnp.clip(self.safety, 0.0, 1.0)
            β_min = jnp.maximum(0.0, self.beta_min)
            β_max = jnp.maximum(β_min, self.beta_max)

            # Compute the next Δt from the desired integration error.
            # The computed integration step is accepted if error <= 1.0,
            # otherwise it is rejected.
            #
            # In case of rejection, Δt_next is always smaller than Δt0.
            # In case of acceptance, Δt_next could either be larger than Δt0,
            # or slightly smaller than Δt0 depending on the safety factor.
            Δt_next = Δt0 * jnp.clip(
                safety * jnp.power(1 / local_error, 1 / (q + 1)),
                β_min,
                β_max,
            )

            def accept_step():
                # Use Δt_next in the next while loop.
                # If it is the last one, and Δt0 was clipped, return the initial Δt0.
                params_next_accepted = params_next | dict(
                    dt0=jnp.clip(
                        jax.lax.select(
                            pred=break_loop,
                            on_true=params["dt0"],
                            on_false=Δt_next,
                        ),
                        self.dt_min,
                        self.dt_max,
                    )
                )

                # Start the next while loop from the final state.
                x0_next = xf

                # Advance the starting time of the next adaptive step.
                t0_next = t0 + Δt0

                # Signal that the final time has been reached.
                break_loop_next = t0 + Δt0 >= tf

                return (
                    x0_next,
                    t0_next,
                    break_loop_next,
                    params_next_accepted,
                    jnp.array(0, dtype=int),
                )

            def reject_step():
                # Get back the original params, this time with a reduced Δt.
                # A new dictionary is built so that the params of the carry,
                # also read by accept_step, are never modified in place.
                params_next_rejected = params | dict(
                    dt0=jnp.clip(Δt_next, self.dt_min, self.dt_max)
                )

                return (
                    x0,
                    t0,
                    False,
                    params_next_rejected,
                    discarded_steps + 1,
                )

            # Decide whether to accept or reject the step.
            # The step is accepted also when it cannot be retried: too many
            # rejections, a proposed Δt below the lower bound, or initialization.
            (
                x0_next,
                t0_next,
                break_loop,
                params_next,
                discarded_steps,
            ) = jax.lax.cond(
                pred=jnp.array(
                    [
                        discarded_steps >= self.max_step_rejections,
                        local_error <= 1.0,
                        Δt_next < self.dt_min,
                        integrator_init,
                    ]
                ).any(),
                true_fun=accept_step,
                false_fun=reject_step,
            )

            return (
                x0_next,
                t0_next,
                params_next,
                discarded_steps,
                break_loop,
            )

        # Integrate with adaptive step until tf is reached.
        (
            xf,
            _,
            params_tf,
            _,
            _,
        ) = jax.lax.while_loop(
            cond_fun=while_loop_cond,
            body_fun=while_loop_body,
            init_val=carry0,
        )

        # Store the parameters.
        # They will be returned to the caller in a functional way in the step method.
        with self.mutable_context(mutability=Mutability.MUTABLE):
            self.params = params_tf

        return xf

    @property
    def order_of_solution(self) -> int:
        """The order of the row of ``b.T`` producing the solution."""
        return self.order_of_bT_rows[self.row_index_of_solution]

    @property
    def order_of_solution_estimate(self) -> int:
        """The order of the row of ``b.T`` producing the solution estimate."""
        return self.order_of_bT_rows[self.row_index_of_solution_estimate]

    @classmethod
    def build(
        cls: type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        fsal_enabled_if_supported: jtp.BoolLike = True,
        dt_max: jtp.FloatLike = jnp.inf,
        dt_min: jtp.FloatLike = -jnp.inf,
        rtol: jtp.FloatLike = RTOL_DEFAULT,
        atol: jtp.FloatLike = ATOL_DEFAULT,
        safety: jtp.FloatLike = SAFETY_DEFAULT,
        beta_max: jtp.FloatLike = BETA_MAX_DEFAULT,
        beta_min: jtp.FloatLike = BETA_MIN_DEFAULT,
        max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
        **kwargs,
    ) -> Self:
        """
        Build the integrator after validating the index of the solution estimate.

        Args:
            dynamics: The system dynamics to integrate.
            fsal_enabled_if_supported: Forwarded (as ``bool``) to the parent builder.
            dt_max: The upper bound of the adaptive Δt.
            dt_min: The lower bound of the adaptive Δt.
            rtol: The relative tolerance used to scale the local error.
            atol: The absolute tolerance used to scale the local error.
            safety: The safety factor applied to the proposed Δt.
            beta_max: The upper bound of the Δt update factor.
            beta_min: The lower bound of the Δt update factor.
            max_step_rejections: The maximum number of consecutive rejected steps.
            **kwargs: Additional arguments forwarded to the parent builder.

        Returns:
            The integrator object.

        Raises:
            ValueError: If the index of the solution estimate exceeds the rows of
                ``b.T``.
        """

        # Check that b.T has enough rows based on the configured index of the
        # solution estimate. This is necessary for embedded methods.
        if (
            cls.row_index_of_solution_estimate is not None
            and cls.row_index_of_solution_estimate >= cls.b.T.shape[0]
        ):
            msg = "The index of the solution estimate ({}-th row of `b.T`) "
            msg += "is out of range ({})."
            raise ValueError(
                msg.format(cls.row_index_of_solution_estimate, cls.b.T.shape[0])
            )

        # Cast the scalar options to plain Python types before forwarding them.
        integrator = super().build(
            # Integrator:
            dynamics=dynamics,
            # ExplicitRungeKutta:
            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
            # EmbeddedRungeKutta:
            dt_max=float(dt_max),
            dt_min=float(dt_min),
            rtol=float(rtol),
            atol=float(atol),
            safety=float(safety),
            beta_max=float(beta_max),
            beta_min=float(beta_min),
            max_step_rejections=int(max_step_rejections),
            **kwargs,
        )

        return integrator
@jax_dataclasses.pytree_dataclass
class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
    """
    Embedded pair with a 2nd-order solution and a 1st-order estimate.

    The Butcher tableau below is combined with ``ExplicitRungeKuttaSO3Mixin``.
    """

    # Stage coefficients.
    A: ClassVar[jtp.Matrix] = jnp.array(
        [
            [0, 0],
            [1, 0],
        ],
        dtype=float,
    )

    # Output weights, stored transposed: each row of b.T is one solution.
    b: ClassVar[jtp.Matrix] = jnp.array(
        [
            [1 / 2, 1 / 2],
            [1, 0],
        ],
        dtype=float,
    ).T

    # Stage times.
    c: ClassVar[jtp.Vector] = jnp.array([0, 1], dtype=float)

    # Rows of b.T holding the solution and its estimate, and their orders.
    row_index_of_solution: ClassVar[int] = 0
    row_index_of_solution_estimate: ClassVar[int | None] = 1
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2, 1)
@jax_dataclasses.pytree_dataclass
class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
    """
    Embedded pair with a 3rd-order solution and a 2nd-order estimate.

    The Butcher tableau below is combined with ``ExplicitRungeKuttaSO3Mixin``.
    """

    # Stage coefficients.
    A: ClassVar[jtp.Matrix] = jnp.array(
        [
            [0, 0, 0, 0],
            [1 / 2, 0, 0, 0],
            [0, 3 / 4, 0, 0],
            [2 / 9, 1 / 3, 4 / 9, 0],
        ],
        dtype=float,
    )

    # Output weights, stored transposed: each row of b.T is one solution.
    b: ClassVar[jtp.Matrix] = jnp.array(
        [
            [2 / 9, 1 / 3, 4 / 9, 0],
            [7 / 24, 1 / 4, 1 / 3, 1 / 8],
        ],
        dtype=float,
    ).T

    # Stage times.
    c: ClassVar[jtp.Vector] = jnp.array([0, 1 / 2, 3 / 4, 1], dtype=float)

    # Rows of b.T holding the solution and its estimate, and their orders.
    row_index_of_solution: ClassVar[int] = 0
    row_index_of_solution_estimate: ClassVar[int | None] = 1
    order_of_bT_rows: ClassVar[tuple[int, ...]] = (3, 2)