JAXsim Showcase: PD Controller

JAXsim Showcase: PD Controller

First, we install the necessary packages and import them.

Open In Colab
# @title Imports and setup
# Install the required packages (only when running on Google Colab) and import
# the standard modules used throughout the notebook.
from IPython.display import clear_output
import os
import sys

# True when the notebook is executed inside Google Colab.
IS_COLAB = "google.colab" in sys.modules

# Install jaxsim (with the visualization extras) and Gazebo's SDFormat library.
if IS_COLAB:
    !{sys.executable} -m pip install -qU jaxsim[viz]
    # Register the OSRF (Gazebo) apt repository and its signing key, then install libsdformat.
    !apt install -qq lsb-release wget gnupg
    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
    !echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
    !apt -qq update
    !apt install -qq --no-install-recommends libsdformat13 gz-tools2

    # Install dependencies for visualization on Colab and ReadTheDocs
    # (presumably needed by the MUJOCO_GL="osmesa" rendering backend selected later).
    # NOTE(review): unlike the apt calls above, this one has no -qq/-y flag;
    # confirm it never stops at an interactive confirmation prompt.
    !sudo apt update
    !apt install libosmesa6-dev
    # Hide the long installation logs from the cell output.
    clear_output()


import jax
import jax.numpy as jnp
from jaxsim import logging

# Let jaxsim emit INFO-level messages, then report which JAX devices are in use.
logging.set_logging_level(logging.LoggingLevel.INFO)
available_devices = jax.devices()
logging.info(f"Running on {available_devices}")
jaxsim[3372] INFO Running on [CpuDevice(id=0)]

We will use a simple cartpole model for this example. The cartpole is a 2D model with a cart that can move horizontally and a pole that is hinged to the cart and can rotate around it. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. In the classical cartpole problem the only control input is the horizontal force applied to the cart; in this example, however, the PD controller applies a generalized force to both joints (the cart and the pole).

# @title Fetch the URDF file
import requests

url = "https://raw.githubusercontent.com/ami-iit/jaxsim/main/examples/assets/cartpole.urdf"

# Download the cartpole URDF description. Fail loudly if this does not work:
# every following cell needs `model_urdf_string`, and silently continuing would
# only surface later as a confusing NameError.
try:
    # The timeout prevents the cell from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    # Turn HTTP error statuses (4xx/5xx) into exceptions.
    response.raise_for_status()
except requests.RequestException as err:
    logging.error("Failed to fetch data")
    raise RuntimeError(f"Failed to fetch the cartpole URDF from {url}") from err

model_urdf_string = response.text

JAXsim offers a simple functional API to interact with the simulation in a memory-efficient way. Four main elements are used to define a simulation:

  • model: an object that defines the dynamics of the system.

  • data: an object that contains the state of the system.

  • integrator: an object that defines the integration method.

  • integrator_state: an object that contains the state of the integrator.

import jaxsim.api as js
from jaxsim import integrators

# Simulation time step and total simulated duration (s).
dt = 0.01
integration_time = 5.0
# Round instead of truncating: the ratio of two floats can land just below the
# intended integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would then
# silently drop the last step.
num_steps = round(integration_time / dt)

# Build the model from the URDF string fetched above.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_urdf_string, is_urdf=True
)
# Data object holding the state of the model (built with the library defaults).
data = js.data.JaxSimModelData.build(model=model)
# Fixed-step 4th-order Runge-Kutta integrator wrapping the system dynamics of
# the model (the SO3 variant presumably integrates the base orientation on SO(3)).
integrator = integrators.fixed_step.RungeKutta4SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model,
        data=data,
        system_dynamics=js.ode.system_dynamics,
    ),
)
# Initialize the integrator from the current state at t0 = 0.
integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)
jaxsim[3372] INFO Combining the pose of base link 'rail' with the pose of joint 'world_to_rail'
jaxsim[3372] INFO The kinematic graph doesn't need to be reduced

Let’s reset the cartpole to a random configuration, sampling each joint position uniformly in [-1, 1].

# Sample one position per joint, uniformly in [-1, 1]. The fixed seed makes the
# initial configuration reproducible across runs.
seed_key = jax.random.PRNGKey(0)
random_positions = jax.random.uniform(
    key=seed_key, shape=(model.dofs(),), minval=-1.0, maxval=1.0
)

# `reset_joint_positions` returns the updated data object, which replaces `data`.
data = data.reset_joint_positions(positions=random_positions)

The visualization is done with the MuJoCo package, so that the animations can be rendered easily also on Google Colab. If you are not interested in the animation, just execute this cell without worrying about its details.

# @title Set up MuJoCo renderer
# Select the OSMesa (software) OpenGL backend. This is set before importing the
# jaxsim.mujoco helpers, presumably because the backend is read at import time.
os.environ["MUJOCO_GL"] = "osmesa"

from jaxsim.mujoco import MujocoModelHelper, MujocoVideoRecorder
from jaxsim.mujoco.loaders import UrdfToMjcf, MujocoCamera

# Convert the URDF the JaxSim model was built from into MJCF, adding a camera
# placed at a fixed distance/azimuth/elevation from its target point.
mjcf_string, assets = UrdfToMjcf.convert(
    urdf=model.built_from,
    cameras=MujocoCamera.build_from_target_view(
        camera_name="cartpole_camera",
        # NOTE(review): assumes joint 0 is the cart (rail) joint, so the camera
        # looks at the initial cart position — confirm against the URDF.
        lookat=jnp.array([0.0, data.joint_positions()[0], 1.2]),
        distance=3,
        azimut=150,
        elevation=-10,
    ),
)
mj_model_helper = MujocoModelHelper.build_from_xml(
    mjcf_description=mjcf_string, assets=assets
)

# Create the video recorder. One frame is recorded per simulation step, so the
# frame rate is derived from the simulation time step instead of being
# hard-coded (rounded, since 1 / dt is not always an exact integer in floats).
recorder = MujocoVideoRecorder(
    model=mj_model_helper.model,
    data=mj_model_helper.data,
    fps=round(1 / dt),
    width=320 * 2,
    height=240 * 2,
)

Let’s see how the model behaves when not controlled:

import mediapy as media

# Arguments of `js.model.step` that do not change between steps. No joint
# forces and no link forces are passed: the cartpole is not controlled.
passive_step_kwargs = {
    "dt": dt,
    "model": model,
    "integrator": integrator,
    "joint_forces": None,
    "link_forces": None,
}

for _step in range(num_steps):
    # Advance the simulation by one time step.
    data, integrator_state = js.model.step(
        data=data, integrator_state=integrator_state, **passive_step_kwargs
    )

    # Mirror the new joint configuration onto the MuJoCo model and render it.
    mj_model_helper.set_joint_positions(
        positions=data.joint_positions(), joint_names=model.joint_names()
    )
    recorder.record_frame(camera_name="cartpole_camera")

# Play back the recorded frames at the simulation rate, then drop them so the
# next recording starts empty.
media.show_video(recorder.frames, fps=1 / dt)
recorder.frames = []

Let’s now define the PD controller. We will use the following equations:

\[\begin{aligned} \mathbf{M}\ddot{s} + \underbrace{\mathbf{C}\dot{s} + \mathbf{G}}_{\mathbf{H}} &= \tau \\ \tau &= \mathbf{H} - \mathbf{K}_p(s - s_d) - \mathbf{K}_d(\dot{s} - \dot{s}_d) \end{aligned} \tag{1}\]

where \(\mathbf{M}\) is the mass matrix, \(\mathbf{C}\) is the Coriolis matrix, \(\mathbf{G}\) is the gravity vector, \(\mathbf{H} = \mathbf{C}\dot{s} + \mathbf{G}\) is the vector of bias forces (Coriolis and gravity terms) that the controller compensates, \(\mathbf{K}_p\) is the proportional gain matrix, \(\mathbf{K}_d\) is the derivative gain matrix, \(s\) is the joint position vector, \(\dot{s}\) is the joint velocity vector, \(\ddot{s}\) is the joint acceleration vector, and \(s_d\) and \(\dot{s}_d\) are the desired joint position and velocity vectors, respectively.

# Define the PD gains (scalar gains, i.e. K_p = KP * I and K_d = KD * I).
KP = 10.0
KD = 6.0


def pd_controller(
    data: js.data.JaxSimModelData,
    q_d: jax.Array,
    q_dot_d: jax.Array,
    *,
    kp: float | None = None,
    kd: float | None = None,
) -> jax.Array:
    """Compute PD joint torques with bias-force (Coriolis + gravity) compensation.

    Implements ``tau = H + kp * (q_d - q) + kd * (q_dot_d - q_dot)``, which is
    Eq. (1) ``tau = H - K_p (s - s_d) - K_d (s_dot - s_dot_d)`` with scalar gains.

    Args:
        data: Current data (state) of the module-level ``model``.
        q_d: Desired joint positions.
        q_dot_d: Desired joint velocities.
        kp: Proportional gain. Defaults to the module-level ``KP``, read at call time.
        kd: Derivative gain. Defaults to the module-level ``KD``, read at call time.

    Returns:
        The joint torques to apply to ``model``.
    """
    # Resolve the gains at call time, so re-assigning KP/KD in the notebook
    # still affects the controller exactly as before.
    kp = KP if kp is None else kp
    kd = KD if kd is None else kd

    # Bias forces H = C(q, q_dot) q_dot + G(q) (Coriolis AND gravity, not only
    # gravity). The first 6 entries belong to the free-floating base and are
    # dropped, keeping only the joint part.
    H = js.model.free_floating_bias_forces(model=model, data=data)[6:]

    q = data.joint_positions()
    q_dot = data.joint_velocities()

    return H + kp * (q_d - q) + kd * (q_dot_d - q_dot)

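Now we can use the pd_controller function to compute the joint torques to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set both the desired position q_d and the desired velocity q_dot_d to zero. Note that this simulation continues from the state reached at the end of the uncontrolled simulation above.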

# Desired configuration: every joint at zero position with zero velocity.
# The references are sized from the model instead of hard-coding two DoFs, and
# they are built once because they do not change over time.
q_d = jnp.zeros(model.dofs())
q_dot_d = jnp.zeros(model.dofs())

# NOTE: `data` and `integrator_state` carry over from the uncontrolled
# simulation above, so the controller starts from the state reached there.
for _ in range(num_steps):
    # PD torques with bias-force compensation for the current state.
    control_torques = pd_controller(data=data, q_d=q_d, q_dot_d=q_dot_d)

    # Advance the simulation by one step, applying the control torques.
    data, integrator_state = js.model.step(
        dt=dt,
        model=model,
        data=data,
        integrator=integrator,
        integrator_state=integrator_state,
        joint_forces=control_torques,
        link_forces=None,
    )

    # Mirror the new joint configuration onto the MuJoCo model and render it.
    mj_model_helper.set_joint_positions(
        positions=data.joint_positions(), joint_names=model.joint_names()
    )

    recorder.record_frame(camera_name="cartpole_camera")

# Play back the recorded frames at the simulation rate, then clear them.
media.show_video(recorder.frames, fps=1 / dt)
recorder.frames = []