A Gentle Introduction to Deep Reinforcement Learning in JAX | by Ryan Pégoud | Nov, 2023


Implementing the update function for DQN is slightly more complex, so let's break it down:

  • First, the _loss_fn function implements the squared error described previously for a single experience.
  • Then, _batch_loss_fn acts as a wrapper for _loss_fn and decorates it with vmap, applying the loss function to a batch of experiences. We then return the average error for this batch.
  • Finally, update acts as a final layer on top of our loss function, computing its gradient with respect to the online network parameters (the target network parameters and the batch of experiences are treated as fixed inputs). We then use Optax (a JAX library commonly used for optimization) to perform an optimizer step and update the online parameters, as sketched after this list.
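
To make this more concrete, here is a minimal sketch of how these three pieces could fit together. The DQNAgent class name, its constructor arguments, and the discount attribute are illustrative assumptions rather than the exact code from the accompanying notebook:

from functools import partial

import haiku as hk
import jax.numpy as jnp
import optax
from jax import grad, jit, vmap


class DQNAgent:
    def __init__(self, model: hk.Transformed, optimizer: optax.GradientTransformation, discount: float):
        self.model = model
        self.optimizer = optimizer
        self.discount = discount

    @partial(jit, static_argnums=(0,))
    def update(self, online_net_params, target_net_params, optimizer_state,
               states, actions, rewards, next_states, dones):
        def _loss_fn(online_net_params, target_net_params,
                     state, action, reward, next_state, done):
            # TD target: r + gamma * max_a' Q_target(s', a'), masked on terminal states
            target = reward + (1 - done) * self.discount * jnp.max(
                self.model.apply(target_net_params, None, next_state)
            )
            prediction = self.model.apply(online_net_params, None, state)[action]
            return jnp.square(target - prediction)

        def _batch_loss_fn(online_net_params, target_net_params,
                           states, actions, rewards, next_states, dones):
            # vmap applies the single-experience loss over the whole batch
            batched_loss = vmap(_loss_fn, in_axes=(None, None, 0, 0, 0, 0, 0))
            return jnp.mean(batched_loss(online_net_params, target_net_params,
                                         states, actions, rewards, next_states, dones))

        # gradient with respect to the online parameters only (the first argument)
        grads = grad(_batch_loss_fn)(online_net_params, target_net_params,
                                     states, actions, rewards, next_states, dones)
        updates, optimizer_state = self.optimizer.update(grads, optimizer_state)
        online_net_params = optax.apply_updates(online_net_params, updates)
        return online_net_params, optimizer_state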

Notice that, similarly to the replay buffer, the model and optimizer are pure functions operating on an external state: the parameters and optimizer state are passed in and returned explicitly rather than stored internally. The following line serves as a good illustration of this principle:

updates, optimizer_state = optimizer.update(grads, optimizer_state)

This also explains why we can use a single model for both the online and target networks, as the parameters are stored and updated externally.

# target network predictions
self.model.apply(target_net_params, None, state)
# online network predictions
self.model.apply(online_net_params, None, state)

For context, the model we use in this article is a multi-layer perceptron defined as follows:

import haiku as hk
import jax.numpy as jnp
from jax import random, vmap

N_ACTIONS = 2
NEURONS_PER_LAYER = [64, 64, 64, N_ACTIONS]
# create one PRNG key per network from consecutive seeds
online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)

@hk.transform
def model(x):
    # simple multi-layer perceptron
    mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)
    return mlp(x)

online_net_params = model.init(online_key, jnp.zeros((STATE_SHAPE,)))
target_net_params = model.init(target_key, jnp.zeros((STATE_SHAPE,)))

prediction = model.apply(online_net_params, None, state)

Replay Buffer

Now, let's take a step back and look more closely at replay buffers. They are widely used in reinforcement learning for a variety of reasons:

  • Generalization: By sampling from the replay buffer, we break the correlation between consecutive experiences by mixing up their order. This way, we avoid overfitting to specific sequences of experiences.
  • Diversity: As the sampling is not limited to recent experiences, we generally observe a lower variance in updates and prevent overfitting to the latest experiences.
  • Increased sample efficiency: Each experience can be sampled multiple times from the buffer, enabling the model to learn more from individual experiences.

Finally, we can use several sampling schemes for our replay buffer:

  • Uniform sampling: Experiences are sampled uniformly at random. This type of sampling is straightforward to implement and lets the model learn from experiences independently of the timestep at which they were collected.
  • Prioritized sampling: This category includes different algorithms such as Prioritized Experience Replay (“PER”, Schaul et al. 2015) or Gradient Experience Replay (“GER”, Lahire et al., 2022). These methods attempt to prioritize the selection of experiences according to some metric related to their “learning potential” (the amplitude of the TD error for PER and the norm of the experience’s gradient for GER).

For the sake of simplicity, we’ll implement a uniform replay buffer in this article. However, I plan to cover prioritized sampling extensively in the future.

As promised, the uniform replay buffer is quite easy to implement; however, there are a few subtleties related to the use of JAX and functional programming. As always, we have to work with pure functions that are devoid of side effects. In other words, we are not allowed to define the buffer as a class instance with a variable internal state.

Instead, we initialize a buffer_state dictionary that maps keys to empty arrays with predefined shapes, as JAX requires constant-sized arrays when jit-compiling code to XLA.

buffer_state = {
    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.float32),
    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
We will use a UniformReplayBuffer class to interact with the buffer state (a sketch follows the list below). This class has two methods:

  • add: Unwraps an experience tuple and maps its components to a specific index. idx = idx % self.buffer_size ensures that when the buffer is full, adding new experiences overwrites older ones.
  • sample: Samples a sequence of random indexes from a uniform distribution. The sequence length is set by batch_size, while the indexes range over [0, current_buffer_size-1]. This ensures that we do not sample empty array slots while the buffer is not yet full. Finally, we use JAX's vmap in combination with tree_map to return a batch of experiences.
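
Here is a minimal sketch of such a class, assuming the buffer_state dictionary defined above; the exact method signatures may differ from the accompanying notebook:

from functools import partial

from jax import jit, random, vmap
from jax.tree_util import tree_map


class UniformReplayBuffer:
    def __init__(self, buffer_size: int, batch_size: int):
        self.buffer_size = buffer_size
        self.batch_size = batch_size

    @partial(jit, static_argnums=(0,))
    def add(self, buffer_state, experience, idx):
        state, action, reward, next_state, done = experience
        idx = idx % self.buffer_size  # overwrite the oldest experiences once the buffer is full
        return {
            "states": buffer_state["states"].at[idx].set(state),
            "actions": buffer_state["actions"].at[idx].set(action),
            "rewards": buffer_state["rewards"].at[idx].set(reward),
            "next_states": buffer_state["next_states"].at[idx].set(next_state),
            "dones": buffer_state["dones"].at[idx].set(done),
        }

    @partial(jit, static_argnums=(0,))
    def sample(self, key, buffer_state, current_buffer_size):
        # draw batch_size indexes uniformly in [0, current_buffer_size - 1]
        key, subkey = random.split(key)
        indexes = random.randint(subkey, shape=(self.batch_size,),
                                 minval=0, maxval=current_buffer_size)
        # gather one experience per sampled index, for every array in the buffer
        experiences = tree_map(lambda arr: vmap(lambda i: arr[i])(indexes), buffer_state)
        return experiences, key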

Now that our DQN agent is ready for training, we'll quickly implement a vectorized CartPole environment using the same framework introduced in an earlier article. CartPole is a classic control environment with a continuous observation space, which makes it a relevant test bed for our DQN.

Visual representation of the CartPole Environment (credits and documentation: OpenAI Gymnasium, MIT license)

The process is quite straightforward: we reuse most of OpenAI's Gymnasium implementation while making sure we use JAX arrays and lax control flow instead of Python or NumPy alternatives. For instance:

# Python implementation
force = self.force_mag if action == 1 else -self.force_mag
# Jax implementation
force = lax.select(jnp.all(action) == 1, self.force_mag, -self.force_mag)

# Python
costheta, sintheta = math.cos(theta), math.sin(theta)
# Jax
cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)

# Python
if not terminated:
    reward = 1.0
    ...
else:
    reward = 0.0
# Jax
reward = jnp.float32(jnp.invert(done))
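
Putting these conversions together, here is a condensed, illustrative sketch of a pure, jit-compatible step function. The constants mirror Gymnasium's CartPole implementation; reset logic, truncation, and random-key handling are omitted for brevity:

from functools import partial

import jax.numpy as jnp
from jax import jit, lax


class CartPole:
    def __init__(self):
        # physical constants, copied from Gymnasium's CartPole implementation
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.theta_threshold_radians = 12 * 2 * jnp.pi / 360
        self.x_threshold = 2.4

    @partial(jit, static_argnums=(0,))
    def step(self, state, action):
        x, x_dot, theta, theta_dot = state
        force = lax.select(jnp.all(action) == 1, self.force_mag, -self.force_mag)
        cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)

        # dynamics, translated one-to-one from the NumPy implementation
        temp = (force + self.polemass_length * theta_dot**2 * sin_theta) / self.total_mass
        theta_acc = (self.gravity * sin_theta - cos_theta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * cos_theta**2 / self.total_mass)
        )
        x_acc = temp - self.polemass_length * theta_acc * cos_theta / self.total_mass

        # Euler integration
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * x_acc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * theta_acc
        new_state = jnp.array([x, x_dot, theta, theta_dot])

        # episode ends when the cart leaves the track or the pole falls too far
        done = jnp.logical_or(jnp.abs(x) > self.x_threshold,
                              jnp.abs(theta) > self.theta_threshold_radians)
        reward = jnp.float32(jnp.invert(done))
        return new_state, reward, done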

For the sake of brevity, the full environment code is available here:

The last part of our implementation of DQN is the training loop (also called rollout). As mentioned in previous articles, we have to respect a specific format in order to take advantage of JAX’s speed.

The rollout function might appear daunting at first, but most of its complexity is purely syntactic, as we've already covered most of the building blocks. Here's a pseudo-code walkthrough, followed by a JAX sketch of the loop:

1. Initialization:
* Create empty arrays that will store the states, actions, rewards
and done flags for each timestep. Initialize the networks and optimizer
with dummy arrays.
* Wrap all the initialized objects in a val tuple

2. Training loop (repeat for i steps):
* Unpack the val tuple
* (Optional) Decay epsilon using a decay function
* Take an action depending on the state and model parameters
* Perform an environment step and observe the next state, reward
and done flag
* Create an experience tuple (state, action, reward, new_state, done)
and add it to the replay buffer
* Sample a batch of experiences depending on the current buffer size
(i.e. sample only from the slots that have already been filled)
* Update the model parameters using the experience batch
* Every N steps, update the target network's weights
(set target_params = online_params)
* Store the experience's values for the current episode and return
the updated `val` tuple
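
Below is an illustrative JAX sketch of this loop based on lax.fori_loop. The env, agent, and buffer objects and their method signatures (reset, step, act) stand in for the components described above and are assumptions rather than the notebook's exact API:

import jax.numpy as jnp
from jax import lax, random


def rollout(key, env, agent, buffer, online_net_params, target_net_params,
            optimizer_state, buffer_state, n_steps, target_update_freq):
    # preallocated arrays storing one scalar per timestep
    all_actions = jnp.zeros(n_steps, dtype=jnp.int32)
    all_rewards = jnp.zeros(n_steps, dtype=jnp.float32)
    all_dones = jnp.zeros(n_steps, dtype=jnp.bool_)

    key, reset_key = random.split(key)
    state = env.reset(reset_key)  # assumed to return the initial observation
    init_val = (key, state, online_net_params, target_net_params,
                optimizer_state, buffer_state, all_actions, all_rewards, all_dones)

    def _step(i, val):
        (key, state, online_net_params, target_net_params,
         optimizer_state, buffer_state, all_actions, all_rewards, all_dones) = val

        # epsilon-greedy action selection with an illustrative linear decay
        key, action_key, buffer_key = random.split(key, 3)
        epsilon = jnp.clip(1.0 - i / n_steps, 0.05, 1.0)
        action = agent.act(action_key, online_net_params, state, epsilon)

        # environment step (the environment is assumed to auto-reset when done)
        next_state, reward, done = env.step(state, action)

        # store the transition in the replay buffer
        experience = (state, action, reward, next_state, done)
        buffer_state = buffer.add(buffer_state, experience, i)

        # sample a batch from the filled portion of the buffer and update the online network
        current_buffer_size = jnp.minimum(i + 1, buffer.buffer_size)
        batch, _ = buffer.sample(buffer_key, buffer_state, current_buffer_size)
        online_net_params, optimizer_state = agent.update(
            online_net_params, target_net_params, optimizer_state,
            batch["states"], batch["actions"], batch["rewards"],
            batch["next_states"], batch["dones"],
        )

        # every `target_update_freq` steps, copy the online weights into the target network
        target_net_params = lax.cond(
            i % target_update_freq == 0,
            lambda _: online_net_params,
            lambda _: target_net_params,
            None,
        )

        # log the current transition and return the updated `val` tuple
        all_actions = all_actions.at[i].set(action)
        all_rewards = all_rewards.at[i].set(reward)
        all_dones = all_dones.at[i].set(done)
        return (key, next_state, online_net_params, target_net_params,
                optimizer_state, buffer_state, all_actions, all_rewards, all_dones)

    return lax.fori_loop(0, n_steps, _step, init_val)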

We can now run DQN for 20,000 steps and observe its performance. After around 45 episodes, the agent achieves decent performance, consistently balancing the pole for more than 100 steps.

The green bars indicate that the agent managed to balance the pole for more than 200 steps, solving the environment. Notably, the agent set its record on the 51st episode, with 393 steps.

Performance report for DQN (made by the author)

The 20,000 training steps were executed in just over a second, at a rate of 15,807 steps per second (on a single CPU)!

This performance hints at JAX's impressive scaling capabilities, allowing practitioners to run large-scale parallelized experiments with minimal hardware requirements.

Running for 20,000 iterations: 100%|██████████| 20000/20000 [00:01<00:00, 15807.81it/s]

We’ll take a closer look at parallelized rollout procedures to run statistically significant experiments and hyperparameter searches in a future article!

In the meantime, feel free to reproduce the experiment and dabble with hyperparameters using this notebook:

As always, thanks for reading this far! I hope this article provided a decent introduction to Deep RL in JAX. Should you have any questions or feedback related to the content of this article, make sure to let me know; I'm always happy to have a little chat 😉

Until next time 👋
