I recently tried equinox, a jax-based library for defining and managing neural networks, and I liked it. So in this blog post I will introduce equinox and demonstrate its use with a reinforcement learning example. Other jax-based NN libraries include haiku from DeepMind and flax from Google Research. The two are actually quite similar, because both designs are based on stax, a reference implementation by the jax developers. In other words, I would say that haiku and flax are libraries that add a PyTorch-like Module on top of stax. An equinox document briefly summarizes the approach of haiku and flax, calling it the ‘init-apply approach’. Let’s start this blog post by paraphrasing that document.
What jax can do for you, and what it doesn’t provide
So, what can jax do? Its webpage says that:
JAX is Autograd and XLA, brought together for high-performance numerical computing.
XLA is an intermediate language for deep learning, developed as a backend for TensorFlow. It is intended to produce optimized programs for vector accelerators such as GPUs, TPUs, and even the SIMD instructions of modern CPUs. Jax provides a NumPy-like library called jax.numpy as a front-end to XLA, which allows NumPy-like Python code to be vector-parallelized into fast GPU code via just-in-time compilation. So what is Autograd? This is also the name of a library previously developed by the jax developers, but you can think of it as autodifferentiation in general. Autodifferentiation is a technique to automatically derive gradients of functions in numerical computation. Jax supports autodifferentiation, which means that you can derive the partial derivatives of a function w.r.t. its inputs without explicitly writing the code to do so. Let’s try to compute the partial derivatives of \(f(x, y) = x^2 + y\):
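A minimal sketch of what this looks like with jax.grad (the original snippet is not shown here, so the variable names are mine):

```python
import jax

def f(x, y):
    return x ** 2 + y

# jax.grad differentiates w.r.t. the first argument by default;
# argnums selects a different one.
df_dx = jax.grad(f)             # ∂f/∂x
df_dy = jax.grad(f, argnums=1)  # ∂f/∂y

print(df_dx(3.0, 1.0))  # 6.0
print(df_dy(3.0, 1.0))  # 1.0
```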
Since \(\frac{\partial}{\partial x}f = 2x, \frac{\partial}{\partial y}f = 1\), the outputs look correct. Because jax computes partial derivatives automatically, you can use it directly for gradient descent to train your NN model. In addition, if you wrap the computation of taking this gradient and updating the model in jax.jit, it can accelerate the whole training procedure. The problem is that a deep neural network has far too many parameters to write a program like this, assigning a variable to every single parameter. This is a serious problem in practice: recent deep NNs often have so many parameters that they are impossible to manage by hand. So I want a system that efficiently manages all the parameters of my model with convenient abstractions.
NN parameter management by init-apply approach
In stax, haiku, and flax, this problem of parameter management is solved by the so-called “init-apply approach”, which can be summarized as:
the model has no parameters
instead, the model is just a set of two functions: init and apply
init takes an example input, initializes the parameters, and returns the initial parameters
apply takes input and parameters, and then compute the output of the model
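As a sketch of this idea, here is a stax-style linear layer written as a plain init/apply pair (the function names and shapes here are mine, not the actual stax API):

```python
import jax
import jax.numpy as jnp

def linear(output_size):
    def init(key, x):
        # Initialize parameters from an example input;
        # the "model" itself stores nothing.
        w = jax.random.normal(key, (x.shape[-1], output_size))
        b = jnp.zeros((output_size,))
        return w, b

    def apply(params, x):
        # All state flows in explicitly through params
        w, b = params
        return x @ w + b

    return init, apply

init, apply = linear(10)
params = init(jax.random.PRNGKey(0), jnp.zeros((100,)))
out = apply(params, jnp.ones((100,)))
print(out.shape)  # (10,)
```

Because the model is just a pair of pure functions and a params pytree, everything composes directly with jax.grad and jax.jit.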
So the APIs of flax and haiku are roughly as follows. While there are a number of differences (e.g., flax uses __call__ while haiku uses forward like PyTorch, just like how Chainer and PyTorch differ; flax employs dataclasses.dataclass for its Module class), their designs are quite similar.
class Linear(Module):
    def __call__(self, x):
        batch_size = x.shape[0]
        w = self.parameter("w", shape=(batch_size, self.output_size), init=self.w_init)
        b = self.parameter("b", shape=(1, self.output_size), init=self.b_init)
        return w @ x.T + b

model = Linear(output_size=10)
params = model.init(jnp.zeros((1, 100)))
result = model.apply(params, jnp.zeros((10, 100)))
In the above example, I used Module to represent the abstract class for defining an NN. Note that the method self.parameter in the above example is the counterpart of the parameter registration functions in flax and haiku. In the case of haiku, params is a dict like the following.
{"Linear/~/w": [...], "Linear/~/b": [...]}
This means that users need to know this naming scheme to access a specific parameter value. stax doesn’t have a Module feature; instead, it provides function combinators to combine multiple init and apply functions. Now let’s summarize the pros and cons of this approach.
Pros
No need to specify the shape of the input Array when initializing the Module
Parameters are initialized when init is called
Functions and data are separated
Module has no variables that can be changed, and is treated completely separately from its parameters
Cons
(Only haiku/flax) Module is redundant
Module only has “model settings” such as output dimensions
Direct access to parameter elements is cumbersome
For example, in haiku, each parameter element can be accessed with params["Linear/~/w"], so we need to understand this naming scheme for keys to check parameters
Not very object oriented
Because classes don’t have data
(Only haiku/flax) Parameter calls in Module need to be converted to functions that can be used by jax.grad.
For example, haiku.transform converts functions that include parameter calls with haiku.get_parameter into functions that take parameters as arguments
equinox Features
Equinox offers a more object-oriented (or PyTorch-like) interface compared to the init-apply approach. Take a look at the Quick example on the top page of the documentation.
Since eqx.Module uses dataclasses.dataclass internally, initialization methods such as __init__ are prepared in advance, but let’s override __init__ to initialize the parameters ourselves. Since \(\sqrt{d_\mathrm{attn}}\) is a constant, let’s make it a member variable as well. Let’s compute gradients:
Code
model = SelfAttention(4, 8, jax.random.PRNGKey(10))
jax.grad(lambda model, x: jnp.mean(model(x)))(model, jnp.ones((3, 4)))
The trouble here is that a partial derivative for sqrt_d_attn is also computed. Because an eqx.Module instance itself holds its parameters, partial derivatives are computed even for member variables that should be constants. Equinox solves this problem with eqx.partition and eqx.is_inexact_array, which separate a module instance into its floating-point jax.Array/numpy.ndarray leaves and all other member variables. Let’s try it.
We can see the struct is divided into one with sqrt_d_attn=None and one with sqrt_d_attn=2.8284270763397217. To summarize, when we want to compute gradients w.r.t. all parameters of a module that also has non-parameter member variables:
Split Module into parameters and others.
Create a function f(module, ...) of which you want the gradient.
Create a function that wraps f with separate arguments for parameters and non-parameters, say wrapped_f(params, others, ...).
Compute the gradient with jax.grad(wrapped_f)(params, others, ...).
It’s tiring. But the good news is that equinox.filter_grad can do this whole tedious procedure for us, and you can use it in most cases. Let’s try it.
It returned the gradient correctly. The constant sqrt_d_attn is exactly masked by None. So, if you want to have a constant in a Module, just use a type other than jax.Array or ndarray for now. If you want to keep a constant as a jax.Array, you can’t use filter_value_and_grad directly and you need to decompose the struct manually. So I recommend keeping constants as Python built-in types such as float. If you use jax.jit, they are simply compiled into the binary, so there is no overhead.
Try reinforcement learning
Environment
Now that we have introduced the features of equinox, let’s try reinforcement learning with it. Since we are using jax, and since the Gym API has changed a lot (it has now moved to gymnasium) and I can’t keep up with it at all, let’s use an environment written in jax. Here I use Maze from jumanji.
It looks easy to use. However, the default size of the video was too big, so I replaced the HTML tags to make it smaller. It seems to be usable in combination with vmap and jit for actual learning.
Let’s implement PPO.
So, let’s try to learn this environment. Here we will implement PPO, because it’s fast.
Input
The input is an array of shape 3x10x10, where the channels are 10x10 binary images of the wall positions, the agent position, and the goal position, respectively.
I use a simple network with a 2D convolution layer, followed by ReLU activations and two linear layers. The policy and value networks share the first two layers, and the policy is modeled by a categorical distribution. To keep things simple, I hardcode the input size as \(3 \times 10 \times 10\) and the number of actions as \(4\).
Code
from typing import NamedTuple

from jax.nn.initializers import orthogonal


class PPONetOutput(NamedTuple):
    policy_logits: jax.Array
    value: jax.Array


class SoftmaxPPONet(eqx.Module):
    torso: list
    value_head: eqx.nn.Linear
    policy_head: eqx.nn.Linear

    def __init__(self, key: jax.Array) -> None:
        key1, key2, key3, key4, key5 = jax.random.split(key, 5)
        # Common layers
        self.torso = [
            eqx.nn.Conv2d(3, 1, kernel_size=3, key=key1),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(64, 64, key=key2),
            jax.nn.relu,
        ]
        self.value_head = eqx.nn.Linear(64, 1, key=key3)
        policy_head = eqx.nn.Linear(64, 4, key=key4)
        # Use small values for policy initialization
        self.policy_head = eqx.tree_at(
            lambda linear: linear.weight,
            policy_head,
            orthogonal(scale=0.01)(key5, policy_head.weight.shape),
        )

    def __call__(self, x: jax.Array) -> PPONetOutput:
        for layer in self.torso:
            x = layer(x)
        value = self.value_head(x)
        policy_logits = self.policy_head(x)
        return PPONetOutput(policy_logits=policy_logits, value=value)

    def value(self, x: jax.Array) -> jax.Array:
        for layer in self.torso:
            x = layer(x)
        return self.value_head(x)
Rollout
In PPO implementations, it is common to collect a history of environment interactions of about 500~8000 steps and use it to update the network several times. Here we use jax.lax.scan to implement a rollout that is faster than a native Python for loop. While scan is super fast, its usage requires a bit of care. In particular, if the second output of the one-step function is results: list[Result], keep in mind that the final return value will be Result(member1=stack([m1 for m1 in results.member1]), ...). Also, since the arguments of exec_rollout contain a SoftmaxPPONet, which is an instance of eqx.Module, you can’t use jax.jit directly. Instead, you need eqx.filter_jit, which decomposes the module into parameters and non-parameters automatically. Actions are sampled from the categorical distribution obtained by applying softmax to the policy network output, and since Observation contains an action_mask that tells you the directions in which the agent can move, you can mask out invalid actions by setting their policy logits to -inf. In a simple environment you don’t need this, but mazes with many walls are difficult to solve without this masking.
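The stacking behavior of scan’s second output can be seen in this small sketch (Result here is a stand-in, not the post’s actual rollout type):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class Result(NamedTuple):
    state: jax.Array
    reward: jax.Array

def step(carry, _):
    # One "environment step": advance the carry, emit a Result
    new_carry = carry + 1
    return new_carry, Result(state=new_carry, reward=jnp.asarray(1.0))

final, results = jax.lax.scan(step, 0, None, length=5)
# scan stacks the per-step Results along a new leading axis
print(final)           # 5
print(results.state)   # [1 2 3 4 5]
print(results.reward)  # [1. 1. 1. 1. 1.]
```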
Let’s test it. The advantage of the jax environment is that you can easily vector-parallelize the environment with jax.vmap, so this time we will run it in 16 parallel. If you apply vmap to reset and give it 16 PRNGKeys, you will get 16 parallel State.
We can confirm that the inputs are processed 16 in parallel, and each member of Rollout has shape (n_steps, n_envs, ...).
Learning
Now that the data has been collected, it is time to write the code that updates the network. First, we compute the GAE (Generalized Advantage Estimation). Since it is unexpectedly a bottleneck, let’s speed it up with fori_loop.
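As a sketch, a GAE computation with jax.lax.fori_loop might look like this (the gae function and its signature are my own, not the post’s exact code):

```python
import jax
import jax.numpy as jnp

def gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation via a backward fori_loop.

    rewards, dones: shape (T,); values: shape (T + 1,), with the
    bootstrap value appended at the end.
    """
    T = rewards.shape[0]
    not_done = 1.0 - dones
    # TD errors: delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
    deltas = rewards + gamma * values[1:] * not_done - values[:-1]

    def body(i, adv):
        t = T - 1 - i  # iterate backwards over time
        a_t = deltas[t] + gamma * lam * not_done[t] * adv[t + 1]
        return adv.at[t].set(a_t)

    # adv has one trailing zero so adv[t + 1] is valid at t = T - 1
    adv = jax.lax.fori_loop(0, T, body, jnp.zeros(T + 1))
    return adv[:-1]
```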
Looks OK. Now we can sample mini-batches from the Batch we created and update the network with gradient descent to minimize the loss function. Here I use optax, which is the standard optimization library in the jax community. At this point, note the following three points.
Use eqx.filter_grad instead of jax.grad as explained in the previous sections
SoftmaxPPONet contains types that cannot be used as jax.jit arguments, such as jax.nn.relu, so use eqx.partition to split it up before passing it as a jax.lax.scan argument
Similarly, SoftmaxPPONet cannot be passed as-is to optax initialization/update functions, so break it down with eqx.partition or eqx.filter to exclude members other than jax.Array
Also, if you want to use jax.lax.scan for the mini-batch update loop, there are several ways to do it. What I do here, assuming mini-batch size \(N\), number of updates \(M\), number of update epochs \(K\), and overall batch size \(N \times M\), is:
Make \(K\) permutations of \(0, 1, 2, \ldots, NM - 1\).
Make \(K\) copies of Batch, each shuffled by one of the permutations we made.
Concatenate each member with jnp.concatenate and reshape it to an array of shape \(MK \times N \times \ldots\).
It’s a bit involved, and you may not need this much speed. If you are worried about memory usage, you could write the \(K\) loop in Python, but in this case the input is only \(3\times 10\times 10\), so it seems to be ok.
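The shuffling-and-reshaping described above can be sketched like this, shown on a single array member (shuffled_minibatches is a name I made up):

```python
import jax
import jax.numpy as jnp

def shuffled_minibatches(key, data, n_minibatch, n_epochs):
    """data: shape (N * M, ...) -> shape (M * K, N, ...) for jax.lax.scan."""
    total = data.shape[0]
    keys = jax.random.split(key, n_epochs)
    # One permutation of [0, NM) per epoch, concatenated
    indices = jnp.concatenate(
        [jax.random.permutation(k, total) for k in keys]
    )
    shuffled = data[indices]  # (K * N * M, ...)
    return shuffled.reshape(-1, n_minibatch, *data.shape[1:])

key = jax.random.PRNGKey(0)
data = jnp.arange(12.0)  # NM = 12
mb = shuffled_minibatches(key, data, n_minibatch=4, n_epochs=2)
print(mb.shape)  # (6, 4): M * K = 3 * 2 minibatches of size N = 4
```

Scanning over the leading axis of the result then performs all \(MK\) mini-batch updates in one compiled loop.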
So now that we have all the components, let’s start training. First, let’s try a simple maze with no walls. I copied and pasted jumanji’s default maze generator and rewrote it. In this environment, the reward is given only at the goal, so the average return per episode is simply calculated as \(\frac{\sum R}{\mathrm{Num.~episodes}}\), which I use as a progress indicator. I selected the hyperparameters based on my experience.
Mean episodic return: 0.40782122905027934
Mean episodic return: 0.967741935483871
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Elapsed time: 3.2s
About 80,000 steps were learned in a little over 3 seconds. That’s fast. Let’s see what the trained agent looks like.
As expected, this one looks easy enough. Next, let’s make it a little more difficult. The default maze, which is generated completely at random, is very slow, so let’s make our own maze. This maze samples the start and the goal from 3 different locations each with equal probability, so the agent has to solve a total of 9 combinations. Let’s train for 800,000 steps, 10 times as many as before.
Mean episodic return: 0.2812202097235462
Mean episodic return: 0.4077407740774077
Mean episodic return: 0.5992010652463382
Mean episodic return: 0.7285136501516684
Mean episodic return: 0.75
Mean episodic return: 0.8620564808110065
Mean episodic return: 0.9868290258449304
Mean episodic return: 0.9977788746298124
Mean episodic return: 0.9970540974825924
Elapsed time: 1.2e+01s
There is some hesitation, but it looks like we are getting to the goal.
Summary
In this blog post, I introduced equinox, a library for handling neural networks with jax, and showed an example of a fast PPO implementation using it. Through writing this post, I found eqx.filter_jit very useful. Now that I have gotten used to it, I feel a bit more frustration writing @partial(jax.jit, static_argnums=(2, 3)) every time. However, if you want to pass a Module to jax.lax.scan as an argument, you need to split it manually with eqx.partition, which was a bit troublesome.