Try reinforcement learning with equinox

ja
RL
deep
Published

July 23, 2023

I recently tried equinox, a jax-based library for defining and managing neural nets, and I liked it. So in this blog post I will introduce equinox and demonstrate its use with a reinforcement learning example. Other jax-based NN libraries include haiku from DeepMind and flax from Google Research. The two are actually quite similar, because both designs are based on stax, a reference implementation by the jax developers. In other words, I would say that haiku and flax are libraries that add a PyTorch-like Module on top of stax. An equinox documentation page briefly summarizes the approach of haiku and flax, calling it the ‘init-apply approach’. Let’s start this blog post by paraphrasing that document.

What jax can do for you, and what it doesn’t provide

So, what can jax do? Its webpage says that:

JAX is Autograd and XLA, brought together for high-performance numerical computing.

XLA is an intermediate language for deep learning, originally developed as a backend for Tensorflow. It is intended to produce optimized programs for vector accelerators such as GPUs, TPUs, and even the SIMD instructions of modern CPUs. Jax provides a NumPy-like library called jax.numpy as a front end to XLA, which allows NumPy-like Python code to be vector-parallelized into fast GPU code via just-in-time (JIT) compilation. So what is Autograd? It is also the name of a library previously developed by the jax developers, but here you can think of it as automatic differentiation in general. Automatic differentiation is a technique to automatically derive gradients of functions in numerical computation. Jax supports autodiff, which means that you can obtain the partial derivatives of a function w.r.t. its inputs without explicitly writing the code to do so. Let’s try to compute the partial derivatives of \(f(x, y) = x^2 + y\):

Code
import jax
import jax.numpy as jnp

def f(x: jax.Array, y: jax.Array) -> jax.Array:
    return jnp.sum(x ** 2 + y)

x = jnp.array([3, 4, 5, 6, 7], dtype=jnp.float32)
y = jnp.array([5, 2, 5, 7, 2], dtype=jnp.float32)
jax.grad(f, argnums=(0, 1))(x, y)
(Array([ 6.,  8., 10., 12., 14.], dtype=float32),
 Array([1., 1., 1., 1., 1.], dtype=float32))

Since \(\frac{\partial}{\partial x}f = 2x, \frac{\partial}{\partial y}f = 1\), the outputs look correct. Because jax computes partial derivatives automatically, you can use it directly for gradient descent to train your NN model. In addition, if you wrap the computation that takes this gradient and updates the model in jax.jit, the whole training step is accelerated. The problem is that a deep neural network has far too many parameters to write a program like this, assigning a variable to every single parameter. This is a surprisingly serious issue, because recent deep NNs often have more parameters than can be managed by hand. So I want a system that efficiently manages all of my model’s parameters with convenient abstractions.
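
For example, a single jitted gradient-descent step on the toy function above might look like this (a minimal sketch of my own, reusing f, x, and y from the previous cell):

# One jitted gradient-descent step that updates x to reduce f(x, y)
@jax.jit
def sgd_step(x: jax.Array, y: jax.Array, lr: float = 0.1) -> jax.Array:
    grad_x = jax.grad(f, argnums=0)(x, y)
    return x - lr * grad_x

x_new = sgd_step(x, y)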

NN parameter management by init-apply approach

In stax, haiku, and flax, this parameter-management problem is solved by the so-called “init-apply approach”, which can be summarized as:

  1. the model itself has no parameters
  2. instead, the model is just a pair of functions: init and apply
  • init takes an example input, initializes the parameters, and returns the initial parameters
  • apply takes an input and the parameters, and computes the output of the model

The APIs of flax and haiku then look roughly like the following pseudocode. While there are a number of differences (e.g., flax uses __call__ while haiku uses forward like PyTorch, much like how Chainer and PyTorch differ, and flax employs dataclasses.dataclass for its Module class), their designs are quite similar.

class Linear(Module):
    def __call__(self, x):
        input_size = x.shape[1]
        w = self.parameter("w", shape=(input_size, self.output_size), init=self.w_init)
        b = self.parameter("b", shape=(1, self.output_size), init=self.b_init)
        return x @ w + b

model = Linear(output_size=10)
params = model.init(jnp.zeros((1, 100)))
result = model.apply(params, jnp.zeros((10, 100)))

In the example above, I used Module to stand for an abstract NN-definition class. Note that the method self.parameter stands in for the parameter-registration functions of flax and haiku. In the case of haiku, params is a dict like the following.

{
    "Linear/~/w": [...],
    "Linear/~/b": [...],
}

This means users need to know the naming scheme to access a specific parameter value. stax has no Module feature at all; instead, it provides function combinators that combine multiple init/apply pairs.
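
As a rough illustration of those combinators (my own sketch, not from the post), here is what jax.example_libraries.stax looks like:

import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# serial combines the (init, apply) pairs of each layer into a single pair
init_fn, apply_fn = stax.serial(stax.Dense(64), stax.Relu, stax.Dense(10))

# init takes a PRNG key and an example input shape, and returns the initial parameters
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 100))
# apply takes the parameters and an input, and computes the output
y = apply_fn(params, jnp.ones((8, 100)))

With that picture in mind, let’s summarize the pros and cons of this approach.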

Pros

  • No need to specify the shape of the input Array when initializing the Module
    • Parameters are initialized when init is called
  • Functions and data are separated
    • Module has no variables that can be changed, and is treated completely separately from its parameters

Cons

  • (Only haiku/flax) Module is redundant
    • Module only has “model settings” such as output dimensions
  • Direct access to parameter elements is cumbersome
    • For example, in haiku, each parameter element can be accessed with params["Linear/~/w"], so we need to understand this naming scheme for keys to check parameters
  • Not very object oriented
    • Because classes don’t have data
  • (Only haiku/flax) Parameter calls inside a Module need to be transformed into functions that jax.grad can consume
    • For example, haiku.transform converts a function that registers parameters via haiku.get_parameter into functions that take the parameters as arguments (see the sketch after this list)
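
As another rough sketch (mine, not from the post), hk.transform turns a function that builds modules and registers parameters internally into an (init, apply) pair:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # hk.Linear registers its parameters via hk.get_parameter internally
    return hk.Linear(10)(x)

forward_t = hk.transform(forward)
params = forward_t.init(jax.random.PRNGKey(0), jnp.zeros((1, 100)))
y = forward_t.apply(params, None, jnp.zeros((8, 100)))  # rng is unused here, so None is fine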

equinox Features

Equinox offers a more object-oriented (or PyTorch-like) interface compared to the init-apply approach. Take a look at the quick example on the top page of its documentation.

Code
import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias
    
@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)
print("Model weight?", model.weight)
print("Grad?", grads)
Model weight? [[0.59902626 0.2172144 ]
 [0.660603   0.03266738]
 [1.2164948  1.1940813 ]]
Grad? Linear(weight=f32[3,2], bias=f32[3])

Here, the most significant feature of equinox is that the Module directly holds its parameters, as shown in:

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

Compared to the init-apply approach, what pros and cons does this approach have?

Pros

  • Easy to understand
  • Easy to debug
    • In flax and haiku, we need to transform the module for debugging
    • We can write test codes in a few lines, like net = Module(...); net(input)
  • Easy to manually manipulate parameters

Cons

  • Modules require the input dimensions at initialization
  • Modules need to be separated into NN parameters and other constants before applying grad, jit, scan, or other jax transformations

What does the last point, Modules need to be separated into NN parameters and other constants, mean? Let’s consider the Self-Attention example below:

Code
class SelfAttention(eqx.Module):
    q: eqx.nn.Linear
    k: eqx.nn.Linear
    v: eqx.nn.Linear
    sqrt_d_attn: float

    def __init__(self, d_in: int, d_attn: int, key: jax.Array) -> None:
        q_key, k_key, v_key = jax.random.split(key, 3)
        self.q = eqx.nn.Linear(d_in, d_attn, key=q_key)
        self.k = eqx.nn.Linear(d_in, d_attn, key=k_key)
        self.v = eqx.nn.Linear(d_in, d_attn, key=v_key)
        self.sqrt_d_attn = float(jnp.sqrt(d_attn))

    def __call__(self, e: jax.Array) -> jax.Array:
        # e: (seq_len, d_in); q, k, v: (seq_len, d_attn)
        q = jax.vmap(self.q)(e)
        k = jax.vmap(self.k)(e)
        v = jax.vmap(self.v)(e)
        # Attention weights between sequence positions: (seq_len, seq_len)
        alpha = jax.nn.softmax(q @ k.T / self.sqrt_d_attn, axis=-1)
        return alpha @ v

Because eqx.Module uses dataclasses.dataclass internally, methods such as __init__ are generated automatically, but here we override __init__ to initialize the parameters ourselves. Since \(\sqrt{d_\mathrm{attn}}\) is a constant, we also store it as a member variable. Let’s compute gradients:

Code
model = SelfAttention(4, 8, jax.random.PRNGKey(10))
jax.grad(lambda model, x: jnp.mean(model(x)))(model, jnp.ones((3, 4)))
SelfAttention(
  q=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  k=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  v=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  sqrt_d_attn=f32[]
)

The trouble here is that a partial derivative w.r.t. sqrt_d_attn is also computed. Because the eqx.Module itself holds its parameters, partial derivatives are computed even for member variables that should really be constants. Equinox solves this with eqx.partition and eqx.is_inexact_array, which separate a module instance into its “floating-point jax.Array or numpy.ndarray members” and “all other member variables”. Let’s try it.

Code
eqx.partition(model, eqx.is_inexact_array)
(SelfAttention(
   q=Linear(
     weight=f32[8,4],
     bias=f32[8],
     in_features=4,
     out_features=8,
     use_bias=True
   ),
   k=Linear(
     weight=f32[8,4],
     bias=f32[8],
     in_features=4,
     out_features=8,
     use_bias=True
   ),
   v=Linear(
     weight=f32[8,4],
     bias=f32[8],
     in_features=4,
     out_features=8,
     use_bias=True
   ),
   sqrt_d_attn=None
 ),
 SelfAttention(
   q=Linear(weight=None, bias=None, in_features=4, out_features=8, use_bias=True),
   k=Linear(weight=None, bias=None, in_features=4, out_features=8, use_bias=True),
   v=Linear(weight=None, bias=None, in_features=4, out_features=8, use_bias=True),
   sqrt_d_attn=2.8284270763397217
 ))

We can see the struct is divided into one with sqrt_d_attn=None and one with sqrt_d_attn=2.8284270763397217. To summarize, when we want to compute gradients w.r.t. all parameters of a module that also has non-parameter member variables, we need to (a minimal sketch follows the list):

  1. Split the Module into parameters and everything else.
  2. Create the function f(module, ...) whose gradient you want.
  3. Create a function that wraps f and takes the parameters and non-parameters as separate arguments, say wrapped_f(params, others, ...).
  4. Compute the gradient with jax.grad(wrapped_f)(params, others, ...).
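
Here is a minimal sketch of that manual procedure (my own illustration, reusing the SelfAttention model defined above):

# 1. Split the module into floating-point arrays (params) and everything else
params, others = eqx.partition(model, eqx.is_inexact_array)

# 2. The function whose gradient we want
def f(module: SelfAttention, x: jax.Array) -> jax.Array:
    return jnp.mean(module(x))

# 3. Wrap it so that parameters and non-parameters are separate arguments
def wrapped_f(params, others, x):
    module = eqx.combine(params, others)  # rebuild the full module
    return f(module, x)

# 4. Differentiate w.r.t. the parameter half only
grads = jax.grad(wrapped_f)(params, others, jnp.ones((3, 4)))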

It’s tedious. The good news is that eqx.filter_grad does all of this for us, and it works for most cases. Let’s try it.

Code
eqx.filter_grad(lambda model, x: jnp.mean(model(x)))(model, jnp.ones((3, 4)))
SelfAttention(
  q=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  k=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  v=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  sqrt_d_attn=None
)

It returned the gradient correctly, and the constant sqrt_d_attn is masked by None. So, if you want a constant in a Module, for now just use a type other than jax.Array or ndarray. If you store a constant as a jax.Array, you can’t use eqx.filter_value_and_grad directly and have to decompose the struct manually. So I recommend keeping constants as Python built-in types such as float; under jax.jit they are simply baked into the compiled code, so there is no overhead.

Try reinforcement learning

Environment

Now that we have introduced the features of equinox, let’s try reinforcement learning with it. Since we are using jax anyway, and since the Gym API has changed a lot (it has now moved to gymnasium) and is hard to keep up with, let’s use an environment implemented in jax. Here I use Maze from jumanji.

Code
import jumanji
from jumanji.wrappers import AutoResetWrapper
from IPython.display import HTML

env = jumanji.make("Maze-v0")
env = AutoResetWrapper(env)
n_actions = env.action_spec().num_values
key, *keys = jax.random.split(jax.random.PRNGKey(20230720), 11)
state, timestep = env.reset(key)
states = [state]
for key in keys:
    action = jax.random.randint(key=key, minval=0, maxval=n_actions, shape=())
    state, timestep = env.step(state, action)
    states.append(state)
anim = env.animate(states)
HTML(anim.to_html5_video().replace('="1000"', '="640"'))  # Change video size

It looks easy to use. However, the default video size was too big, so I replaced the HTML attribute to make it smaller. It seems usable in combination with vmap and jit for actual training.

Let’s implement PPO

So, let’s train an agent on this environment. Here I implement PPO, because it’s fast.

Input

The input is an array of shape \(3 \times 10 \times 10\): three 10x10 binary images encoding the wall positions, the agent position, and the goal position, respectively.

Code
from jumanji.environments.routing.maze.types import Observation, State

def obs_to_image(obs: Observation) -> jax.Array:
    walls = obs.walls.astype(jnp.float32)
    agent = jnp.zeros_like(walls).at[obs.agent_position].set(1.0)
    target = jnp.zeros_like(walls).at[obs.target_position].set(1.0)
    return jnp.stack([walls, agent, target])

Network

I use a simple network with a 2D convolution layer followed by a ReLU activation and two linear layers. The policy and value networks share the first two layers, and the policy is modeled as a categorical distribution. To keep things simple, I hard-code the input size as \(3 \times 10 \times 10\) and the number of actions as \(4\).

Code
from typing import NamedTuple

from jax.nn.initializers import orthogonal


class PPONetOutput(NamedTuple):
    policy_logits: jax.Array
    value: jax.Array


class SoftmaxPPONet(eqx.Module):
    torso: list
    value_head: eqx.nn.Linear
    policy_head: eqx.nn.Linear

    def __init__(self, key: jax.Array) -> None:
        key1, key2, key3, key4, key5 = jax.random.split(key, 5)
        # Common layers
        self.torso = [
            eqx.nn.Conv2d(3, 1, kernel_size=3, key=key1),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(64, 64, key=key2),
            jax.nn.relu,
        ]
        self.value_head = eqx.nn.Linear(64, 1, key=key3)
        policy_head = eqx.nn.Linear(64, 4, key=key4)
        # Use small value for policy initialization
        self.policy_head = eqx.tree_at(
            lambda linear: linear.weight,
            policy_head,
            orthogonal(scale=0.01)(key5, policy_head.weight.shape),
        )

    def __call__(self, x: jax.Array) -> PPONetOutput:
        for layer in self.torso:
            x = layer(x)
        value = self.value_head(x)
        policy_logits = self.policy_head(x)
        return PPONetOutput(policy_logits=policy_logits, value=value)

    def value(self, x: jax.Array) -> jax.Array:
        for layer in self.torso:
            x = layer(x)
        return self.value_head(x)

Rollout

In PPO implementations, it is common to collect a history of about 500–8000 steps of environment interaction and use it to update the network several times. Here we use jax.lax.scan to implement a rollout that is much faster than a native Python for loop. While scan is super fast, it needs a bit of care: if the second output of the function that advances one step is a Result, the final return value of scan is the stacked version, i.e., Result(member1=stack([member1 from each step]), ...). Also, since the arguments of exec_rollout include a SoftmaxPPONet, which is an instance of eqx.Module, you can’t apply jax.jit directly; instead, use eqx.filter_jit, which decomposes the module into parameters and non-parameters automatically. Actions are sampled from the categorical distribution obtained by applying softmax to the policy network’s output, but since Observation contains an action_mask that tells you which directions you can move in, we mask out unavailable actions by setting their policy logits to -inf. In a simple environment you don’t need this, but mazes with many walls are difficult to solve without the masking.
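
To see what that stacking looks like, here is a tiny self-contained example (mine, not part of the rollout code): scanning a function whose second output is a NamedTuple yields one NamedTuple whose members gain a leading time axis.

from typing import NamedTuple

class StepResult(NamedTuple):
    doubled: jax.Array
    squared: jax.Array

def double_and_square(carry: jax.Array, x: jax.Array) -> tuple[jax.Array, StepResult]:
    # The carry accumulates a running sum; the second output is recorded every step
    return carry + x, StepResult(doubled=2.0 * x, squared=x**2)

_, stacked = jax.lax.scan(double_and_square, jnp.zeros(()), jnp.arange(5, dtype=jnp.float32))
print(stacked.doubled.shape, stacked.squared.shape)  # (5,) (5,)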

Code
import chex


@chex.dataclass
class Rollout:
    """Rollout buffer that stores the entire history of one rollout"""

    observations: jax.Array
    actions: jax.Array
    action_masks: jax.Array
    rewards: jax.Array
    terminations: jax.Array
    values: jax.Array
    policy_logits: jax.Array


def mask_logits(policy_logits: jax.Array, action_mask: jax.Array) -> jax.Array:
    return jax.lax.select(
        action_mask,
        policy_logits,
        jnp.ones_like(policy_logits) * -jnp.inf,
    )


vmapped_obs2i = jax.vmap(obs_to_image)


@eqx.filter_jit
def exec_rollout(
    initial_state: State,
    initial_obs: Observation,
    env: jumanji.Environment,
    network: SoftmaxPPONet,
    prng_key: jax.Array,
    n_rollout_steps: int,
) -> tuple[State, Rollout, Observation, jax.Array]:
    def step_rollout(
        carried: tuple[State, Observation],
        key: jax.Array,
    ) -> tuple[tuple[State, jax.Array], Rollout]:
        state_t, obs_t = carried
        obs_image = vmapped_obs2i(obs_t)
        net_out = jax.vmap(network)(obs_image)
        masked_logits = mask_logits(net_out.policy_logits, obs_t.action_mask)
        actions = jax.random.categorical(key, masked_logits, axis=-1)
        state_t1, timestep = jax.vmap(env.step)(state_t, actions)
        rollout = Rollout(
            observations=obs_image,
            actions=actions,
            action_masks=obs_t.action_mask,
            rewards=timestep.reward,
            terminations=1.0 - timestep.discount,
            values=net_out.value,
            policy_logits=masked_logits,
        )
        return (state_t1, timestep.observation), rollout

    (state, obs), rollout = jax.lax.scan(
        step_rollout,
        (initial_state, initial_obs),
        jax.random.split(prng_key, n_rollout_steps),
    )
    next_value = jax.vmap(network.value)(vmapped_obs2i(obs))
    return state, rollout, obs, next_value

Let’s test it. An advantage of a jax-native environment is that it can easily be vector-parallelized with jax.vmap, so here we run 16 environments in parallel. Applying vmap to reset and passing it 16 PRNGKeys gives a batched State covering 16 environments.

Code
key, net_key, reset_key, rollout_key = jax.random.split(key, 4)
pponet = SoftmaxPPONet(net_key)
initial_state, initial_timestep = jax.vmap(env.reset)(jax.random.split(reset_key, 16))
next_state, rollout, next_obs, next_value = exec_rollout(
    initial_state,
    initial_timestep.observation,
    env,
    pponet,
    rollout_key,
    512,
)
rollout.rewards.shape
(512, 16)

We can confirm that the 16 environments run in parallel and that each member of Rollout has shape (n_steps, n_envs, ...).

Learning

Now that the data has been collected, it is time to write the code that updates the network. First, we compute the generalized advantage estimator (GAE). Since this is unexpectedly a bottleneck, let’s speed it up with jax.lax.fori_loop.
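
For reference, this is the standard GAE recursion that the code below implements, where discount_t corresponds to \(\gamma_t\) (it becomes \(0\) when the episode terminates):

\[
\delta_t = r_t + \gamma_t V(s_{t+1}) - V(s_t), \qquad
\hat{A}_t = \delta_t + \gamma_t \lambda \hat{A}_{t+1}.
\]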

Code
@chex.dataclass(frozen=True, mappable_dataclass=False)
class Batch:
    """Batch for PPO, indexable to get a minibatch."""

    observations: jax.Array
    action_masks: jax.Array
    onehot_actions: jax.Array
    rewards: jax.Array
    advantages: jax.Array
    value_targets: jax.Array
    log_action_probs: jax.Array

    def __getitem__(self, idx: jax.Array):
        return self.__class__(  # type: ignore
            observations=self.observations[idx],
            action_masks=self.action_masks[idx],
            onehot_actions=self.onehot_actions[idx],
            rewards=self.rewards[idx],
            advantages=self.advantages[idx],
            value_targets=self.value_targets[idx],
            log_action_probs=self.log_action_probs[idx],
        )


def compute_gae(
    r_t: jax.Array,
    discount_t: jax.Array,
    values: jax.Array,
    lambda_: float = 0.95,
) -> jax.Array:
    """Efficiently compute generalized advantage estimator (GAE)"""

    gamma_lambda_t = discount_t * lambda_
    delta_t = r_t + discount_t * values[1:] - values[:-1]
    n = delta_t.shape[0]

    def update(i: int, advantage_t: jax.Array) -> jax.Array:
        t = n - i - 1
        adv_t = delta_t[t] + gamma_lambda_t[t] * advantage_t[t + 1]
        return advantage_t.at[t].set(adv_t)

    advantage_t = jax.lax.fori_loop(0, n, update, jnp.zeros_like(values))
    return advantage_t[:-1]


@eqx.filter_jit
def make_batch(
    rollout: Rollout,
    next_value: jax.Array,
    gamma: float,
    gae_lambda: float,
) -> Batch:
    all_values = jnp.concatenate(
        [jnp.squeeze(rollout.values), next_value.reshape(1, -1)]
    )
    advantages = compute_gae(
        rollout.rewards,
        # Set γ = 0 when the episode terminates
        (1.0 - rollout.terminations) * gamma,
        all_values,
        gae_lambda,
    )
    value_targets = advantages + all_values[:-1]
    onehot_actions = jax.nn.one_hot(rollout.actions, 4)
    _, _, *obs_shape = rollout.observations.shape
    log_action_probs = jnp.sum(
        jax.nn.log_softmax(rollout.policy_logits) * onehot_actions,
        axis=-1,
    )
    return Batch(
        observations=rollout.observations.reshape(-1, *obs_shape),
        action_masks=rollout.action_masks.reshape(-1, 4),
        onehot_actions=onehot_actions.reshape(-1, 4),
        rewards=rollout.rewards.ravel(),
        advantages=advantages.ravel(),
        value_targets=value_targets.ravel(),
        log_action_probs=log_action_probs.ravel(),
    )
Code
batch = make_batch(rollout, next_value, 0.99, 0.95)
batch.advantages.shape, batch.onehot_actions.shape, batch.log_action_probs.shape
((8192,), (8192, 4), (8192,))

Looks OK. Now we can sample mini-batches from the Batch we just created and update the network by gradient descent to minimize the loss function. Here I use optax, the de facto standard optimization library in the jax community. At this point, note the following three points.

  • Use eqx.filter_grad instead of jax.grad, as explained in the previous sections
  • SoftmaxPPONet contains members that cannot be used as jax.jit arguments, such as jax.nn.relu, so break it up with eqx.partition before passing it through jax.lax.scan
  • Similarly, SoftmaxPPONet cannot be passed as-is to the optax initialization/update functions, so break it down with eqx.partition or eqx.filter to exclude members other than jax.Array

Also, if you want to use jax.lax.scan for the mini-batch update loop, there are several ways to do it. What I do here, assuming a mini-batch size of \(N\), \(M\) mini-batch updates per epoch, \(K\) update epochs, and an overall batch size of \(N \times M\), is:

  1. Make \(K\) permutations of \(0, 1, 2, \dots, NM - 1\).
  2. Make \(K\) copies of Batch, each shuffled by one of those permutations.
  3. Concatenate each member with jnp.concatenate and reshape it into an array of shape \(MK \times N \times \dots\).

It’s a bit convoluted, and you may not need this much speed. If you are worried about memory usage, you could write the \(K\) loop in Python, but since the input here is only \(3 \times 10 \times 10\), it seems to be fine.
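
To make the indexing concrete, here is a tiny shape check (my own sketch, with assumed sizes \(N = 4\), \(M = 3\), \(K = 2\)):

K, N, M = 2, 4, 3
data = jnp.arange(N * M)  # stands in for one member of Batch
keys = jax.random.split(jax.random.PRNGKey(0), K)
permutations = jax.vmap(jax.random.permutation, in_axes=(0, None))(keys, N * M)
minibatches = data[permutations].reshape(-1, N)
print(permutations.shape, minibatches.shape)  # (2, 12) (6, 4), i.e. (K, NM) and (MK, N)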

Code
import optax


def loss_function(
    network: SoftmaxPPONet,
    batch: Batch,
    ppo_clip_eps: float,
) -> jax.Array:
    net_out = jax.vmap(network)(batch.observations)
    # Policy loss
    # Mask out unavailable actions. Filling with -inf here would make the
    # entropy term below NaN (0 * -inf for masked actions), so use a large
    # finite negative value instead.
    log_pi = jax.nn.log_softmax(
        jax.lax.select(
            batch.action_masks,
            net_out.policy_logits,
            jnp.full_like(net_out.policy_logits, -1e9),
        )
    )
    log_action_probs = jnp.sum(log_pi * batch.onehot_actions, axis=-1)
    policy_ratio = jnp.exp(log_action_probs - batch.log_action_probs)
    clipped_ratio = jnp.clip(policy_ratio, 1.0 - ppo_clip_eps, 1.0 + ppo_clip_eps)
    clipped_objective = jnp.fmin(
        policy_ratio * batch.advantages,
        clipped_ratio * batch.advantages,
    )
    policy_loss = -jnp.mean(clipped_objective)
    # Value loss
    value_loss = jnp.mean(0.5 * (net_out.value - batch.value_targets) ** 2)
    # Entropy regularization
    entropy = jnp.mean(-jnp.exp(log_pi) * log_pi)
    return policy_loss + value_loss - 0.01 * entropy


vmapped_permutation = jax.vmap(jax.random.permutation, in_axes=(0, None), out_axes=0)


@eqx.filter_jit
def update_network(
    batch: Batch,
    network: SoftmaxPPONet,
    optax_update: optax.TransformUpdateFn,
    opt_state: optax.OptState,
    prng_key: jax.Array,
    minibatch_size: int,
    n_epochs: int,
    ppo_clip_eps: float,
) -> tuple[optax.OptState, SoftmaxPPONet]:
    # Prepare update function
    dynamic_net, static_net = eqx.partition(network, eqx.is_array)

    def update_once(
        carried: tuple[optax.OptState, SoftmaxPPONet],
        batch: Batch,
    ) -> tuple[tuple[optax.OptState, SoftmaxPPONet], None]:
        opt_state, dynamic_net = carried
        network = eqx.combine(dynamic_net, static_net)
        grad = eqx.filter_grad(loss_function)(network, batch, ppo_clip_eps)
        updates, new_opt_state = optax_update(grad, opt_state)
        dynamic_net = optax.apply_updates(dynamic_net, updates)
        return (new_opt_state, dynamic_net), None

    # Prepare minibatches
    batch_size = batch.observations.shape[0]
    permutations = vmapped_permutation(jax.random.split(prng_key, n_epochs), batch_size)
    minibatches = jax.tree_map(
        # Here, x's shape is [batch_size, ...]
        lambda x: x[permutations].reshape(-1, minibatch_size, *x.shape[1:]),
        batch,
    )
    # Update network n_epochs x n_minibatches times
    (opt_state, updated_dynet), _ = jax.lax.scan(
        update_once,
        (opt_state, dynamic_net),
        minibatches,
    )
    return opt_state, eqx.combine(updated_dynet, static_net)

Now that we have all the components, let’s start training. First, let’s try a simple maze with no walls. I copied jumanji’s default maze generator and rewrote it. In this environment the reward is given only at the goal, so the average return per episode is simply \(\frac{\sum R}{\mathrm{Num.~episodes}}\), which I use as a progress indicator. I chose the hyperparameters based on experience.

Code
from jumanji.environments.routing.maze.generator import Generator
from jumanji.environments.routing.maze.types import Position, State


class TestGenerator(Generator):
    def __init__(self) -> None:
        super().__init__(num_rows=10, num_cols=10)

    def __call__(self, key: chex.PRNGKey) -> State:
        walls = jnp.zeros((10, 10), dtype=bool)
        agent_position = Position(row=0, col=0)
        target_position = Position(row=9, col=9)

        # Build the state.
        return State(
            agent_position=agent_position,
            target_position=target_position,
            walls=walls,
            action_mask=None,
            key=key,
            step_count=jnp.array(0, jnp.int32),
        )
Code
def run_training(
    key: jax.Array,
    adam_lr: float = 3e-4,
    adam_eps: float = 1e-7,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    n_optim_epochs: int = 10,
    minibatch_size: int = 1024,
    n_agents: int = 16,
    n_rollout_steps: int = 512,
    n_total_steps: int = 16 * 512 * 100,
    ppo_clip_eps: float = 0.2,
    **env_kwargs,
) -> SoftmaxPPONet:
    key, net_key, reset_key = jax.random.split(key, 3)
    pponet = SoftmaxPPONet(net_key)
    env = AutoResetWrapper(jumanji.make("Maze-v0", **env_kwargs))
    adam_init, adam_update = optax.adam(adam_lr, eps=adam_eps)
    opt_state = adam_init(eqx.filter(pponet, eqx.is_array))
    env_state, timestep = jax.vmap(env.reset)(jax.random.split(reset_key, n_agents))
    obs = timestep.observation

    n_loop = n_total_steps // (n_agents * n_rollout_steps)
    return_reporting_interval = 1 if n_loop < 10 else n_loop // 10
    n_episodes, reward_sum = 0.0, 0.0
    for i in range(n_loop):
        key, rollout_key, update_key = jax.random.split(key, 3)
        env_state, rollout, obs, next_value = exec_rollout(
            env_state,
            obs,
            env,
            pponet,
            rollout_key,
            n_rollout_steps,
        )
        batch = make_batch(rollout, next_value, gamma, gae_lambda)
        opt_state, pponet = update_network(
            batch,
            pponet,
            adam_update,
            opt_state,
            update_key,
            minibatch_size,
            n_optim_epochs,
            ppo_clip_eps,
        )
        n_episodes += jnp.sum(rollout.terminations).item()
        reward_sum += jnp.sum(rollout.rewards).item()
        if i > 0 and (i % return_reporting_interval == 0):
            print(f"Mean episodic return: {reward_sum / n_episodes}")
            n_episodes = 0.0
            reward_sum = 0.0
    return pponet
Code
import datetime

started = datetime.datetime.now()
key, training_key = jax.random.split(key)
trained_net = run_training(training_key, n_total_steps=16 * 512 * 10, generator=TestGenerator())
elapsed = datetime.datetime.now() - started
print(f"Elapsed time: {elapsed.total_seconds():.2}s")
Mean episodic return: 0.40782122905027934
Mean episodic return: 0.967741935483871
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Elapsed time: 3.2s

Training for about 80,000 steps took a little over 3 seconds. That’s fast. Let’s see what the trained agent looks like.

Code
@eqx.filter_jit
def visualization_rollout(
    key: jax.random.PRNGKey,
    pponet: SoftmaxPPONet,
    env: jumanji.Environment,
    n_steps: int,
) -> list[State]:
    def step_rollout(
        carried: tuple[State, Observation],
        key: jax.Array,
    ) -> tuple[tuple[State, jax.Array], State]:
        state_t, obs_t = carried
        obs_image = obs_to_image(obs_t)
        net_out = pponet(obs_image)
        masked_logits = mask_logits(net_out.policy_logits, obs_t.action_mask)
        action = jax.random.categorical(key, masked_logits)
        state_t1, timestep = env.step(state_t, action)
        return (state_t1, timestep.observation), state_t1

    initial_state, timestep = env.reset(key)
    _, states = jax.lax.scan(
        step_rollout,
        (initial_state, timestep.observation),
        jax.random.split(key, n_steps),
    )
    leaves, treedef = jax.tree_util.tree_flatten(states)
    return [initial_state] + [treedef.unflatten(leaf) for leaf in zip(*leaves)]
Code
env = AutoResetWrapper(jumanji.make("Maze-v0", generator=TestGenerator()))
key, eval_key = jax.random.split(key)
states = visualization_rollout(eval_key, trained_net, env, 40)
anim = env.animate(states)
HTML(anim.to_html5_video().replace('="1000"', '="640"'))

As expected, this one looks easy enough. Next, let’s make it a little more difficult. The default generator, which builds a completely random maze every episode, is very slow to learn, so let’s make our own maze. This maze samples the start and the goal uniformly from 3 locations each, so there are 9 combinations in total to solve. Let’s train on it for about 800,000 steps, 10 times as many as in the previous run.

Code
class MedDifficultyGenerator(Generator):
    WALLS = [
        [0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
        [0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 0, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 1, 0, 1, 1, 0, 0, 0, 0],
        [1, 0, 1, 0, 1, 1, 0, 1, 1, 1],
        [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
        [1, 1, 1, 0, 0, 1, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    ]
    def __init__(self) -> None:
        super().__init__(num_rows=10, num_cols=10)

    def __call__(self, key: chex.PRNGKey) -> State:
        key, config_key = jax.random.split(key)
        walls = jnp.array(self.WALLS).astype(bool)
        # Sample one of 3 start positions and one of 3 goal positions (maxval is exclusive)
        agent_cfg, target_cfg = jax.random.randint(config_key, (2,), 0, 3)
        agent_position = jax.lax.switch(
            agent_cfg,
            [
                lambda: Position(row=0, col=0),
                lambda: Position(row=9, col=0),
                lambda: Position(row=0, col=9),
            ]
        )
        target_position = jax.lax.switch(
            target_cfg,
            [
                lambda: Position(row=3, col=9),
                lambda: Position(row=7, col=8),
                lambda: Position(row=7, col=0),
            ]
        )
        # Build the state.
        return State(
            agent_position=agent_position,
            target_position=target_position,
            walls=walls,
            action_mask=None,
            key=key,
            step_count=jnp.array(0, jnp.int32),
        )
Code
import datetime

started = datetime.datetime.now()
key, training_key = jax.random.split(key)
trained_net = run_training(
    training_key,
    n_total_steps=16 * 512 * 100,
    generator=MedDifficultyGenerator(),
)
elapsed = datetime.datetime.now() - started
print(f"Elapsed time: {elapsed.total_seconds():.2}s")
Mean episodic return: 0.2812202097235462
Mean episodic return: 0.4077407740774077
Mean episodic return: 0.5992010652463382
Mean episodic return: 0.7285136501516684
Mean episodic return: 0.75
Mean episodic return: 0.8620564808110065
Mean episodic return: 0.9868290258449304
Mean episodic return: 0.9977788746298124
Mean episodic return: 0.9970540974825924
Elapsed time: 1.2e+01s
Code
env = AutoResetWrapper(jumanji.make("Maze-v0", generator=MedDifficultyGenerator()))
key, eval_key = jax.random.split(key)
states = visualization_rollout(eval_key, trained_net, env, 100)
anim = env.animate(states)
HTML(anim.to_html5_video().replace('="1000"', '="640"'))

There is some hesitation, but it looks like the agent reaches the goal.

Summary

In this post, I introduced equinox, a library for handling neural networks with jax, and showed an example of a fast PPO implementation built on it. Through writing this post, I found eqx.filter_jit very useful; now that I’m used to it, I get a bit frustrated whenever I have to write @partial(jax.jit, static_argnums=(2, 3)) instead. However, passing a Module through jax.lax.scan still requires splitting it manually with eqx.partition, which was a bit troublesome.