# Source code for jaxsim.math.rotation

import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.typing as jtp

from .skew import Skew


class Rotation:
    """Factory helpers that build 3x3 rotation matrices."""

    @staticmethod
    def x(theta: jtp.Float) -> jtp.Matrix:
        """
        Generate a 3D rotation matrix around the X-axis.

        Args:
            theta (jtp.Float): Rotation angle in radians.

        Returns:
            jtp.Matrix: 3D rotation matrix.
        """
        return jaxlie.SO3.from_x_radians(theta=theta).as_matrix()

    @staticmethod
    def y(theta: jtp.Float) -> jtp.Matrix:
        """
        Generate a 3D rotation matrix around the Y-axis.

        Args:
            theta (jtp.Float): Rotation angle in radians.

        Returns:
            jtp.Matrix: 3D rotation matrix.
        """
        return jaxlie.SO3.from_y_radians(theta=theta).as_matrix()

    @staticmethod
    def z(theta: jtp.Float) -> jtp.Matrix:
        """
        Generate a 3D rotation matrix around the Z-axis.

        Args:
            theta (jtp.Float): Rotation angle in radians.

        Returns:
            jtp.Matrix: 3D rotation matrix.
        """
        return jaxlie.SO3.from_z_radians(theta=theta).as_matrix()

    @staticmethod
    def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
        """
        Generate a 3D rotation matrix from an axis-angle representation.

        The matrix is computed with Rodrigues' formula

            R = I + a(θ) K + b(θ) K @ K,

        where K is the skew-symmetric matrix of the (unnormalized) input
        vector, θ its Euclidean norm, a(θ) = sin(θ)/θ and
        b(θ) = (1 - cos(θ))/θ². The zero vector maps to the identity.

        Values and gradients stay finite at and near θ = 0, including under
        ``jax.jit``, ``jax.vmap`` and ``jax.grad``.

        Args:
            vector (jtp.Vector): Axis-angle representation as a 3D vector,
                i.e. the rotation axis scaled by the rotation angle in radians.
                Singleton dimensions are squeezed away.

        Returns:
            jtp.Matrix: 3D rotation matrix.
        """

        vector = jnp.asarray(vector).squeeze()

        # Use the squared norm instead of ``jnp.linalg.norm``: the latter
        # contains a sqrt whose derivative is infinite at the origin, which
        # makes gradients NaN for the zero vector.
        theta_squared = jnp.dot(vector, vector)

        # Near θ = 0 the closed forms below are 0/0. Switch to Taylor series
        # there and feed the closed forms a "safe" denominator (double-where
        # trick) so that the unselected branch never produces NaN/inf, which
        # would otherwise poison gradients and vmapped evaluations.
        is_small = theta_squared < 1e-6
        safe_theta_squared = jnp.where(is_small, 1.0, theta_squared)
        safe_theta = jnp.sqrt(safe_theta_squared)

        # a(θ) = sin(θ)/θ ≈ 1 - θ²/6 + θ⁴/120
        a = jnp.where(
            is_small,
            1.0 - theta_squared / 6.0 + theta_squared**2 / 120.0,
            jnp.sin(safe_theta) / safe_theta,
        )

        # b(θ) = (1 - cos(θ))/θ² ≈ 1/2 - θ²/24 + θ⁴/720.
        # The half-angle identity 1 - cos(θ) = 2 sin²(θ/2) avoids the
        # catastrophic cancellation of the direct form for small θ.
        b = jnp.where(
            is_small,
            0.5 - theta_squared / 24.0 + theta_squared**2 / 720.0,
            2.0 * jnp.sin(safe_theta / 2.0) ** 2 / safe_theta_squared,
        )

        # Pass a column vector, the same shape the previous implementation
        # handed to Skew.wedge.
        K = Skew.wedge(vector.reshape(3, 1))

        return jnp.eye(3) + a * K + b * (K @ K)