equinoxで強化学習してみる

ja
RL
deep
Published

June 9, 2023

最近equinoxというjaxベースの深層学習モデルを定義するライブラリを使ってみたのですが、これが中々いいと思ったので紹介ついでに強化学習してみます。他のjaxベースのライブラリにはDeepmindのhaikuやGoogle Researchのflaxがあります。この2つのライブラリは実際のところあまり変わりはありません。というのも、jaxには試験的に書かれたstaxという深層学習ライブラリのリファレンス実装があり、haikuもflaxもstaxをベースにオブジェクト志向的なModuleを採用したものだからです。あるいは、haikuやflaxは「staxをPyTorchっぽくしたもの」と言ってもいいかもしれません。equinoxのドキュメントにあるCompatibility with init-apply librariesというページでは、これらのライブラリのやり方を「init-applyアプローチ」と呼んで軽く説明しています。これについてざっと見てみましょう。

Jaxにできることとできないこと

そもそも、jaxというのは何をしてくれるライブラリなのでしょうか。ホームページのトップにはこう書かれています。

JAX is Autograd and XLA, brought together for high-performance numerical computing.

XLAというのは、Tensorflowのバックエンドとして開発された深層学習用の中間言語で、CPU/GPU/TPU用に数値計算コードを最適化してくるものです。深層学習に求められる並列化の性質から、特にSIMD演算/ベクトル並列化に特化しています。Jaxは、jax.numpyというNumPyに似せたライブラリをXLAのフロントエンドとして提供することで、「NumPyコードをベクトル並列化された高速なGPU用コードに実行時コンパイルすること」を可能にしています。では、Autogradというのはなんでしょうか?これは、jaxの開発者が以前に開発していたライブラリの名前でもありますが、自動微分全般のことを指すと考えていいでしょう。自分で勾配逆伝播のコードを書かなくても、jaxでは損失関数の各パラメタにおいての偏微分を勝手に計算してくれます。試しに、\(f(x, y) = x^2 + y\)の偏微分を計算してみましょう。

Code
import jax
import jax.numpy as jnp

def f(x: jax.Array, y: jax.Array) -> jax.Array:
    return jnp.sum(x ** 2 + y)

x = jnp.array([3, 4, 5, 6, 7], dtype=jnp.float32)
y = jnp.array([5, 2, 5, 7, 2], dtype=jnp.float32)
jax.grad(f, argnums=(0, 1))(x, y)
(Array([ 6.,  8., 10., 12., 14.], dtype=float32),
 Array([1., 1., 1., 1., 1.], dtype=float32))

\(\frac{\partial}{\partial x}f = 2x, \frac{\partial}{\partial y}f = 1\)なので、正しく計算されているようですね。こんな感じで、jaxは自動的に偏微分を計算してくれるので、これをそのまま勾配降下法に使ってモデルを学習させることができます。深層ニューラルネットワークを学習させる際も、このgradを使えば全部のパラメタについて偏微分を効率的に計算してくれるので、それを使って学習できます。ついでに、この勾配をとる計算やモデルを更新する計算をjax.jitにつっこめば、高速に計算してくれます。 ここで問題になるのは、深層ニューラルネットワークの場合パラメタが多すぎるので、こんな風にいちいち全てのパラメタについて変数を割り当ててプログラムを書いていたら大変すぎる、ということです。大変すぎる以外に何か問題はあるのかというと、特にないと思います。コードの再利用性くらいでしょうか。しかし、まあ面倒なものは面倒ですから、パラメタを管理してくれる仕組みがほしいなあ、と思うわけですね。

init-applyアプローチによるパラメタ管理

stax, haiku, flaxでは、この「パラメタ管理の問題」を、「init-applyアプローチ」により解決しています。このアプローチは以下のようにまとめられます。

  1. モデルはパラメータを持たない
  2. モデルはinitapplyの2つの関数を持つ
  • initは、入力例を受け取ってパラメタを初期化し、最初のパラメータを返す
  • applyは、入力とパラメタを受け取り、モデルの計算結果を返す

なので、flaxやhaikuのAPIは以下のような感じになります。flaxでは__call__を使うがhaikuはPyTorchと同じforwardを使う、flaxのModuleは dataclasses.dataclassデコレータにより定義されたクラスと同じような性質を持つなどの違いがありますが、まあそれくらいで、大して違いはないです。以下僕がflaxとhaikuの間をとって書いた適当な疑似コードです。ニューラルネットワークを表すクラスとして、PyTorch風に「Module」という名前を使っています。

class Linear(Module):
    def __call__(self, x):
        batch_size = x.shape[0]
        w = self.parameter(output, shape=(batch_size, self.output_size) init=self.w_init)
        b = self.parameter(output, shape=(1, self.output_size) init=self.w_init)
        return w @ x.T + b

model = Linear(output_size=10)
params = model.init(jnp.zeros(1, 100))
result = model.apply(params, jnp.zeros(10, 100))

こんな感じですかね。なお、上疑似コード中のself.parameterというメソッドはflaxhaikuにある「パラメータをクラスに登録する機能」のことです。この機能により、各パラメタを値としてもつdictinitにより返すことができます。haikuの場合は、paramsは以下のようなdictになっています。

{
    "Linear/~/w": [...],
    "Linear/~/b": [...],
}

staxはただのreference implementationなのでこのような機能がなく、かわりに、数のレイヤーを組み合わせるコンビネーターを提供しています。ではこのアプローチにはどのようなメリット、デメリットがあるでしょうか。

メリット

  • Moduleを初期化する際に入力されるArrayのshapeを指定しなくていい
    • パラメタはinitが呼ばれた際に初期化される
  • 関数とデータを分離できる
    • Moduleは変更可能な変数を持たず、パラメタと完全に別に扱われる
  • initapplyjitvmapなどの関数デコレータを適用するコードが自然に書ける

デメリット

  • Moduleは冗長
    • Moduleは「出力の次元」などの「モデルに関する設定」を持っているだけ
  • パラメタの要素に直接アクセスするのが面倒
    • 例えばhaikuならparams["Linear/~/w"]のようにしてパラメタの各要素にアクセスできるが、複雑なクラスだとdictの鍵の名前が長くなりわかりにくい
  • あまりオブジェクト指向的ではない
  • (haiku/flaxに特有) Module内でのパラメタの呼び出しをjax.gradが使えるような関数に変換する必要がある
    • 例えばhaiku.transformは、haiku.get_parameterによるパラメタ呼び出しを含む関数を、パラメタを引数としてとる関数に変換する

こんなところでしょうか。

equinoxの特徴

equinoxの特徴は、init-applyアプローチと異なり、「よりオブジェクト指向的な」(または、PyTorchに近い)インターフェースを志向している点にあります。 ドキュメントのトップページにあるQuick exampleを見てみましょう。

Code
import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias
    
@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)
print("Model weight?", model.weight)
print("Grad?", grads)
Model weight? [[0.59902626 0.2172144 ]
 [0.660603   0.03266738]
 [1.2164948  1.1940813 ]]
Grad? Linear(weight=f32[3,2], bias=f32[3])

ここで、equinoxの最大の特徴は

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

というコードに表れていますが、「Moduleがパラメタを直接持つ」という点です。init-applyアプローチと比べると、これはどのような利点・欠点があるでしょうか?

メリット

  • わかりやすい
  • デバッグしやすい
    • flax/haikuは一度transformしないとデバッグできない
    • net = Module(...); net(input)のように短い行数でテストコードが書ける
  • パラメタを直接操作するのが簡単

デメリット

  • モデルを初期化する時に、入力する特徴量の次元が必要
  • gradを使う際に、パラメタとその他の変数を分離する必要がある

ここで、最後の「gradを使う際に、パラメタとその他の変数を分離する必要がある」というのは、どういう意味でしょうか?例えば、self-attentionを計算する以下のようなModuleを考えます。

Code
class SelfAttention(eqx.Module):
    q: eqx.nn.Linear
    k: eqx.nn.Linear
    v: eqx.nn.Linear
    sqrt_d_attn: float

    def __init__(self, d_in: int, d_attn: int, key: jax.Array) -> None:
        q_key, k_key, v_key = jax.random.split(key, 3)
        self.q = eqx.nn.Linear(d_in, d_attn, key=q_key)
        self.k = eqx.nn.Linear(d_in, d_attn, key=k_key)
        self.v = eqx.nn.Linear(d_in, d_attn, key=v_key)
        self.sqrt_d_attn = float(jnp.sqrt(d_attn))

    def __call__(self, e: jax.Array) -> jax.Array:
        q = jax.vmap(self.q)(e)
        k = jax.vmap(self.k)(e)
        alpha = jax.nn.softmax(q.T @ k / self.sqrt_d_attn, axis=-1)
        return jax.vmap(self.v)(e) @ alpha.T

eqx.Moduleは内部でdataclasses.dataclassを使うので、__init__等の初期化メソッドはあらかじめ用意されていますが、これをオーバーライドしてパラメタの初期化に使います。 \(\sqrt{d_\mathrm{attn}}\)は定数なので、これもメンバ変数にしてしまいましょう。勾配を計算してみます。

Code
model = SelfAttention(4, 8, jax.random.PRNGKey(10))
jax.grad(lambda model, x: jnp.mean(model(x)))(model, jnp.ones((3, 4)))
SelfAttention(
  q=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  k=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  v=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  sqrt_d_attn=f32[]
)

ここで困ったことに、sqrt_d_attnでの偏微分も計算されてしまいました。eqx.Moduleそのものがパラメタを持つことによって、定数であるようなメンバ変数に対しても偏微分が計算されてしまいます。この問題を、equinoxでは、eqx.partitioneqx.is_inexact_arrayを使って、「32bit浮動小数点のjax.Arrayまたはnumpy.ndarray」と「その他のメンバ変数」を分離することにより解決しています。やってみましょう。

Code
eqx.partition(model, eqx.is_inexact_array)
(SelfAttention(
   q=Linear(
     weight=f32[8,4],
     bias=f32[8],
     in_features=4,
     out_features=8,
     use_bias=True
   ),
   k=Linear(
     weight=f32[8,4],
     bias=f32[8],
     in_features=4,
     out_features=8,
     use_bias=True
   ),
   v=Linear(
     weight=f32[8,4],
     bias=f32[8],
     in_features=4,
     out_features=8,
     use_bias=True
   ),
   sqrt_d_attn=None
 ),
 SelfAttention(
   q=Linear(weight=None, bias=None, in_features=4, out_features=8, use_bias=True),
   k=Linear(weight=None, bias=None, in_features=4, out_features=8, use_bias=True),
   v=Linear(weight=None, bias=None, in_features=4, out_features=8, use_bias=True),
   sqrt_d_attn=2.8284270763397217
 ))

sqrt_d_attn=Noneのものと、全てのパラメタがNoneでsqrt_d_attn=2.8284270763397217のものに分離されています。なので、定数を持つModuleに対して勾配を計算したい場合は

  1. Moduleをパラメタとそれ以外に分割
  2. 勾配を求めたい関数f(module, ...)をラップする関数g(params, others, ..)みたいなものを作る
  3. jax.grad(g)(params, others, ...)で勾配を計算

という流れになります。面倒ですね。 長々説明したのですが、これを全部やってくれるのが、equinox.filter_gradです。基本何も考えずにこれを使えばいいです。やってみましょう。

Code
eqx.filter_grad(lambda model, x: jnp.mean(model(x)))(model, jnp.ones((3, 4)))
SelfAttention(
  q=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  k=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  v=Linear(
    weight=f32[8,4],
    bias=f32[8],
    in_features=4,
    out_features=8,
    use_bias=True
  ),
  sqrt_d_attn=None
)

勾配を返してくれました。定数のsqrt_d_attnはきっちりNoneでマスクされています。なので、Moduleで定数を持ちたかったらとりあえずjax.Arrayndarray以外の型にしておけばいいです。じゃあ32bit浮動小数点型のjax.Arrayを定数として持ちたかったらどうすればいいんだというと、filter_value_and_gradを直接使えずめちゃくちゃ面倒になるので、避けたほうが良さげです。ただPython組み込みのfloatboolもjitコンパイルしてしまえばただの定数になるので、パフォーマンス的には気にする必要はないです。

強化学習してみる

環境

一通りequinoxの特徴を紹介したところで、これを使って強化学習してみます。せっかくjaxを使っているのとgymのAPIが変わりまくった上にgymnasiumに変わって全然ついていけないので、jax製の環境を使ってみましょう。ここではjumanjiというライブラリのMazeを使ってみます。

Code
import jumanji
from jumanji.wrappers import AutoResetWrapper
from IPython.display import HTML

env = jumanji.make("Maze-v0")
env = AutoResetWrapper(env)
n_actions = env.action_spec().num_values
key, *keys = jax.random.split(jax.random.PRNGKey(20230720), 11)
state, timestep = env.reset(key)
states = [state]
for key in keys:
    action = jax.random.randint(key=key, minval=0, maxval=n_actions, shape=())
    state, timestep = env.step(state, action)
    states.append(state)
anim = env.animate(states)
HTML(anim.to_html5_video().replace('="1000"', '="640"'))  # Change video size

使いやすそうですね。ただデフォルトだと動画がハチャメチャなサイズだったので適当にHTMLのタグを書き換えて小さくしてます。 実際に学習するときにはvmapjitと組み合わせて使えるようです。

PPOを実装してみる

ということで、この環境を学習してみましょう。ここでは定番かつ学習が高速なPPOを実装してみます。

入力

それぞれ壁の位置、エージェントの位置、ゴールの位置をそれぞれ10x10のバイナリ画像で表現し、3x10x10の配列として入力します。

Code
from jumanji.environments.routing.maze.types import Observation, State

def obs_to_image(obs: Observation) -> jax.Array:
    walls = obs.walls.astype(jnp.float32)
    agent = jnp.zeros_like(walls).at[obs.agent_position].set(1.0)
    target = jnp.zeros_like(walls).at[obs.target_position].set(1.0)
    return jnp.stack([walls, agent, target])

ネットワーク

たたみこんでからReLU + 線形レイヤを2回というシンプルな構成にします。途中までは価値関数と方策は共通でいいでしょう。方策はカテゴリカル分布とします。面倒なので、入力サイズ\(3 \times 10 \times 10\)、行動数\(4\)として入出力サイズをベタ書きしてしまいます。

Code
from typing import NamedTuple

from jax.nn.initializers import orthogonal


class PPONetOutput(NamedTuple):
    policy_logits: jax.Array
    value: jax.Array


class SoftmaxPPONet(eqx.Module):
    torso: list
    value_head: eqx.nn.Linear
    policy_head: eqx.nn.Linear

    def __init__(self, key: jax.Array) -> None:
        key1, key2, key3, key4, key5 = jax.random.split(key, 5)
        # Common layers
        self.torso = [
            eqx.nn.Conv2d(3, 1, kernel_size=3, key=key1),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(64, 64, key=key2),
            jax.nn.relu,
        ]
        self.value_head = eqx.nn.Linear(64, 1, key=key3)
        policy_head = eqx.nn.Linear(64, 4, key=key4)
        # Use small value for policy initialization
        self.policy_head = eqx.tree_at(
            lambda linear: linear.weight,
            policy_head,
            orthogonal(scale=0.01)(key5, policy_head.weight.shape),
        )

    def __call__(self, x: jax.Array) -> PPONetOutput:
        for layer in self.torso:
            x = layer(x)
        value = self.value_head(x)
        policy_logits = self.policy_head(x)
        return PPONetOutput(policy_logits=policy_logits, value=value)

    def value(self, x: jax.Array) -> jax.Array:
        for layer in self.torso:
            x = layer(x)
        return self.value_head(x)

ロールアウト

PPOの実装では1000~8000ステップ程度環境で行動した履歴を集めてそれを使って何度かネットワークを更新するのが普通です。ここではjax.lax.scanを使って、Pythonのループを使うより高速なロールアウトを実装します。scanの速度面での恩恵は大きいですが、使い方には少し注意が必要です。特に、1ステップ進める関数の二番目の出力をresults: list[Result] とすると、最終的に返ってくるのがResult(member1=stack([m1 for m1 in results.member1]), ...)になることは把握しておきましょう。 また、exec_rolloutの引数がeqx.ModuleのインスタンスであるSoftmaxPPONetを含んでいるので、eqx.filter_jitでjitしてあげるとうまいことjitできない値を除外してjitしてくれます。 行動は方策ネットワークの出力をsoftmaxしてサンプルしますが、Observationに迷路で移動可能な方向を教えてくれるaction_maskが入っているので、これを使って取れない行動には-infをかけてマスクしておきましょう。簡単な環境ならこれをやらなくても大丈夫だと思うのですが、壁を多いときはこれがないとけっこう難しいです。

Code
import chex


@chex.dataclass
class Rollout:
    """Rollout buffer that stores the entire history of one rollout"""

    observations: jax.Array
    actions: jax.Array
    action_masks: jax.Array
    rewards: jax.Array
    terminations: jax.Array
    values: jax.Array
    policy_logits: jax.Array


def mask_logits(policy_logits: jax.Array, action_mask: jax.Array) -> jax.Array:
    return jax.lax.select(
        action_mask,
        policy_logits,
        jnp.ones_like(policy_logits) * -jnp.inf,
    )


vmapped_obs2i = jax.vmap(obs_to_image)


@eqx.filter_jit
def exec_rollout(
    initial_state: State,
    initial_obs: Observation,
    env: jumanji.Environment,
    network: SoftmaxPPONet,
    prng_key: jax.Array,
    n_rollout_steps: int,
) -> tuple[State, Rollout, Observation, jax.Array]:
    def step_rollout(
        carried: tuple[State, Observation],
        key: jax.Array,
    ) -> tuple[tuple[State, jax.Array], Rollout]:
        state_t, obs_t = carried
        obs_image = vmapped_obs2i(obs_t)
        net_out = jax.vmap(network)(obs_image)
        masked_logits = mask_logits(net_out.policy_logits, obs_t.action_mask)
        actions = jax.random.categorical(key, masked_logits, axis=-1)
        state_t1, timestep = jax.vmap(env.step)(state_t, actions)
        rollout = Rollout(
            observations=obs_image,
            actions=actions,
            action_masks=obs_t.action_mask,
            rewards=timestep.reward,
            terminations=1.0 - timestep.discount,
            values=net_out.value,
            policy_logits=masked_logits,
        )
        return (state_t1, timestep.observation), rollout

    (state, obs), rollout = jax.lax.scan(
        step_rollout,
        (initial_state, initial_obs),
        jax.random.split(prng_key, n_rollout_steps),
    )
    next_value = jax.vmap(network.value)(vmapped_obs2i(obs))
    return state, rollout, obs, next_value

テストしてみましょう。jax.vmapで簡単に環境をベクトル並列化できるのがjax製環境の利点なので、今回は16並列で動かしてみます。resetvmapしてPRNGKeyを16個突っ込むと勝手に16並列のStateがでてきます。

Code
key, net_key, reset_key, rollout_key = jax.random.split(key, 4)
pponet = SoftmaxPPONet(net_key)
initial_state, initial_timestep = jax.vmap(env.reset)(jax.random.split(reset_key, 16))
next_state, rollout, next_obs, next_value = exec_rollout(
    initial_state,
    initial_timestep.observation,
    env,
    pponet,
    rollout_key,
    512,
)
rollout.rewards.shape
(512, 16)

入力が16並列になっていること、Rolloutの各メンバがステップ数x環境数x…の大きさになっていることが確認できました。

学習

データが集められたので後はネットワークを更新するコードを書きましょう。まずGAEを計算します。意外とボトルネックになるのでfori_loopで高速化しておきましょう。

Code
@chex.dataclass(frozen=True, mappable_dataclass=False)
class Batch:
    """Batch for PPO, indexable to get a minibatch."""

    observations: jax.Array
    action_masks: jax.Array
    onehot_actions: jax.Array
    rewards: jax.Array
    advantages: jax.Array
    value_targets: jax.Array
    log_action_probs: jax.Array

    def __getitem__(self, idx: jax.Array):
        return self.__class__(  # type: ignore
            observations=self.observations[idx],
            action_masks=self.action_masks[idx],
            onehot_actions=self.onehot_actions[idx],
            rewards=self.rewards[idx],
            advantages=self.advantages[idx],
            value_targets=self.value_targets[idx],
            log_action_probs=self.log_action_probs[idx],
        )


def compute_gae(
    r_t: jax.Array,
    discount_t: jax.Array,
    values: jax.Array,
    lambda_: float = 0.95,
) -> jax.Array:
    """Efficiently compute generalized advantage estimator (GAE)"""

    gamma_lambda_t = discount_t * lambda_
    delta_t = r_t + discount_t * values[1:] - values[:-1]
    n = delta_t.shape[0]

    def update(i: int, advantage_t: jax.Array) -> jax.Array:
        t = n - i - 1
        adv_t = delta_t[t] + gamma_lambda_t[t] * advantage_t[t + 1]
        return advantage_t.at[t].set(adv_t)

    advantage_t = jax.lax.fori_loop(0, n, update, jnp.zeros_like(values))
    return advantage_t[:-1]


@eqx.filter_jit
def make_batch(
    rollout: Rollout,
    next_value: jax.Array,
    gamma: float,
    gae_lambda: float,
) -> Batch:
    all_values = jnp.concatenate(
        [jnp.squeeze(rollout.values), next_value.reshape(1, -1)]
    )
    advantages = compute_gae(
        rollout.rewards,
        # Set γ = 0 when the episode terminates
        (1.0 - rollout.terminations) * gamma,
        all_values,
        gae_lambda,
    )
    value_targets = advantages + all_values[:-1]
    onehot_actions = jax.nn.one_hot(rollout.actions, 4)
    _, _, *obs_shape = rollout.observations.shape
    log_action_probs = jnp.sum(
        jax.nn.log_softmax(rollout.policy_logits) * onehot_actions,
        axis=-1,
    )
    return Batch(
        observations=rollout.observations.reshape(-1, *obs_shape),
        action_masks=rollout.action_masks.reshape(-1, 4),
        onehot_actions=onehot_actions.reshape(-1, 4),
        rewards=rollout.rewards.ravel(),
        advantages=advantages.ravel(),
        value_targets=value_targets.ravel(),
        log_action_probs=log_action_probs.ravel(),
    )
Code
batch = make_batch(rollout, next_value, 0.99, 0.95)
batch.advantages.shape, batch.onehot_actions.shape, batch.log_action_probs.shape
((8192,), (8192, 4), (8192,))

\(512 \times 16 = 8192\)なので大丈夫そうですね。あとはこれで作ったBatchからミニバッチをサンプルして、損失関数を最小化するように勾配降下で更新します。jax界隈では定番のoptaxを使いましょう。この時、以下の3点に注意します。

  • 前節で説明したように、jax.gradのかわりにeqx.filter_gradを使う
  • SoftmaxPPONetjax.nn.reluなどjax.jitの引数として使えない型を持っているので、jax.lax.scanの引数にする前にeqx.partitionで分解する
  • 同様に、SoftmaxPPONetはそのままoptaxの初期化・アップデート関数の引数にできないので、eqx.partitionで分解するかeqx.filterjax.Array以外のメンバを除外しておく

また、ミニバッチ更新のループでjax.lax.scanを使いたい場合、いくつか方法があると思うのですが、ここではミニバッチサイズ\(N\)、更新回数\(M\)、更新エポック数\(K\)回、全体のバッチサイズ\(N \times M\)として、

  1. \(0, 1, 2, ..., NM - 1\)の順列を\(K\)個作る
  2. バッチの各要素を1で作った順列のもと並び替えたものを\(K\)個作る
  3. 各メンバをjnp.concatenateでくっつけて大きさ\(MK \times N \times ...\)の配列にreshapeする

という方法を使いました。ちょっと面倒ですし、正直ここまで高速化しなくてもいいかもしれませんね。メモリ使用量が不安な場合は\(K\)のループをPythonで書くのもアリかなと思いますが、今回は入力が\(3\times 10\times 10\)なので大丈夫そうですね。

Code
import optax


def loss_function(
    network: SoftmaxPPONet,
    batch: Batch,
    ppo_clip_eps: float,
) -> jax.Array:
    net_out = jax.vmap(network)(batch.observations)
    # Policy loss
    log_pi = jax.nn.log_softmax(
        jax.lax.select(
            batch.action_masks,
            net_out.policy_logits,
            jnp.ones_like(net_out.policy_logits * -jnp.inf),
        )
    )
    log_action_probs = jnp.sum(log_pi * batch.onehot_actions, axis=-1)
    policy_ratio = jnp.exp(log_action_probs - batch.log_action_probs)
    clipped_ratio = jnp.clip(policy_ratio, 1.0 - ppo_clip_eps, 1.0 + ppo_clip_eps)
    clipped_objective = jnp.fmin(
        policy_ratio * batch.advantages,
        clipped_ratio * batch.advantages,
    )
    policy_loss = -jnp.mean(clipped_objective)
    # Value loss
    value_loss = jnp.mean(0.5 * (net_out.value - batch.value_targets) ** 2)
    # Entropy regularization
    entropy = jnp.mean(-jnp.exp(log_pi) * log_pi)
    return policy_loss + value_loss - 0.01 * entropy


vmapped_permutation = jax.vmap(jax.random.permutation, in_axes=(0, None), out_axes=0)


@eqx.filter_jit
def update_network(
    batch: Batch,
    network: SoftmaxPPONet,
    optax_update: optax.TransformUpdateFn,
    opt_state: optax.OptState,
    prng_key: jax.Array,
    minibatch_size: int,
    n_epochs: int,
    ppo_clip_eps: float,
) -> tuple[optax.OptState, SoftmaxPPONet]:
    # Prepare update function
    dynamic_net, static_net = eqx.partition(network, eqx.is_array)

    def update_once(
        carried: tuple[optax.OptState, SoftmaxPPONet],
        batch: Batch,
    ) -> tuple[tuple[optax.OptState, SoftmaxPPONet], None]:
        opt_state, dynamic_net = carried
        network = eqx.combine(dynamic_net, static_net)
        grad = eqx.filter_grad(loss_function)(network, batch, ppo_clip_eps)
        updates, new_opt_state = optax_update(grad, opt_state)
        dynamic_net = optax.apply_updates(dynamic_net, updates)
        return (new_opt_state, dynamic_net), None

    # Prepare minibatches
    batch_size = batch.observations.shape[0]
    permutations = vmapped_permutation(jax.random.split(prng_key, n_epochs), batch_size)
    minibatches = jax.tree_map(
        # Here, x's shape is [batch_size, ...]
        lambda x: x[permutations].reshape(-1, minibatch_size, *x.shape[1:]),
        batch,
    )
    # Update network n_epochs x n_minibatches times
    (opt_state, updated_dynet), _ = jax.lax.scan(
        update_once,
        (opt_state, dynamic_net),
        minibatches,
    )
    return opt_state, eqx.combine(updated_dynet, static_net)

というわけで、部品が全部できたので学習を回してみましょう。まず何も壁がない簡単な迷路で試してみます。簡単な迷路を作りましょう。 junmanjiのデフォルトの迷路ジェネレーターをコピペして適当に書き換えました。また、この環境ではゴールでのみ報酬が与えられるので、単純に報酬和/終了したエピソードの数 でエピソードあたりの平均リターンが求められます。これを訓練状況の指標として出力しておきましょう。ハイパーパラメータはわりと勘で決めました。

Code
from jumanji.environments.routing.maze.generator import Generator
from jumanji.environments.routing.maze.types import Position, State


class TestGenerator(Generator):
    def __init__(self) -> None:
        super().__init__(num_rows=10, num_cols=10)

    def __call__(self, key: chex.PRNGKey) -> State:
        walls = jnp.zeros((10, 10), dtype=bool)
        agent_position = Position(row=0, col=0)
        target_position = Position(row=9, col=9)

        # Build the state.
        return State(
            agent_position=agent_position,
            target_position=target_position,
            walls=walls,
            action_mask=None,
            key=key,
            step_count=jnp.array(0, jnp.int32),
        )
Code
def run_training(
    key: jax.Array,
    adam_lr: float = 3e-4,
    adam_eps: float = 1e-7,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    n_optim_epochs: int = 10,
    minibatch_size: int = 1024,
    n_agents: int = 16,
    n_rollout_steps: int = 512,
    n_total_steps: int = 16 * 512 * 100,
    ppo_clip_eps: float = 0.2,
    **env_kwargs,
) -> SoftmaxPPONet:
    key, net_key, reset_key = jax.random.split(key, 3)
    pponet = SoftmaxPPONet(net_key)
    env = AutoResetWrapper(jumanji.make("Maze-v0", **env_kwargs))
    adam_init, adam_update = optax.adam(adam_lr, eps=adam_eps)
    opt_state = adam_init(eqx.filter(pponet, eqx.is_array))
    env_state, timestep = jax.vmap(env.reset)(jax.random.split(reset_key, 16))
    obs = timestep.observation

    n_loop = n_total_steps // (n_agents * n_rollout_steps)
    return_reporting_interval = 1 if n_loop < 10 else n_loop // 10
    n_episodes, reward_sum = 0.0, 0.0
    for i in range(n_loop):
        key, rollout_key, update_key = jax.random.split(key, 3)
        env_state, rollout, obs, next_value = exec_rollout(
            env_state,
            obs,
            env,
            pponet,
            rollout_key,
            n_rollout_steps,
        )
        batch = make_batch(rollout, next_value, gamma, gae_lambda)
        opt_state, pponet = update_network(
            batch,
            pponet,
            adam_update,
            opt_state,
            update_key,
            minibatch_size,
            n_optim_epochs,
            ppo_clip_eps,
        )
        n_episodes += jnp.sum(rollout.terminations).item()
        reward_sum += jnp.sum(rollout.rewards).item()
        if i > 0 and (i % return_reporting_interval == 0):
            print(f"Mean episodic return: {reward_sum / n_episodes}")
            n_episodes = 0.0
            reward_sum = 0.0
    return pponet
Code
import datetime

started = datetime.datetime.now()
key, training_key = jax.random.split(key)
trained_net = run_training(training_key, n_total_steps=16 * 512 * 10, generator=TestGenerator())
elapsed = datetime.datetime.now() - started
print(f"Elapsed time: {elapsed.total_seconds():.2}s")
Mean episodic return: 0.40782122905027934
Mean episodic return: 0.967741935483871
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Elapsed time: 3.2s

約8万ステップが3秒ちょっとで学習できました。速いですね。学習したエージェントがどんなもんか見てみましょう。

Code
@eqx.filter_jit
def visualization_rollout(
    key: jax.random.PRNGKey,
    pponet: SoftmaxPPONet,
    env: jumanji.Environment,
    n_steps: int,
) -> list[State]:
    def step_rollout(
        carried: tuple[State, Observation],
        key: jax.Array,
    ) -> tuple[tuple[State, jax.Array], State]:
        state_t, obs_t = carried
        obs_image = obs_to_image(obs_t)
        net_out = pponet(obs_image)
        action, _ = sample_action(key, net_out.policy_logits, obs_t.action_mask)
        state_t1, timestep = env.step(state_t, action)
        return (state_t1, timestep.observation), state_t1

    initial_state, timestep = env.reset(key)
    _, states = jax.lax.scan(
        step_rollout,
        (initial_state, timestep.observation),
        jax.random.split(key, n_steps),
    )
    leaves, treedef = jax.tree_util.tree_flatten(states)
    return [initial_state] + [treedef.unflatten(leaf) for leaf in zip(*leaves)]
Code
env = AutoResetWrapper(jumanji.make("Maze-v0", generator=TestGenerator()))
key, eval_key = jax.random.split(key)
states = visualization_rollout(eval_key, trained_net, env, 40)
anim = env.animate(states)
HTML(anim.to_html5_video().replace('="1000"', '="640"'))

さすがにこのくらいなら簡単そうですね。次はもう少し難しくしてみたいところです。デフォルトの完全ランダム生成される迷路はやたら遅かったので、適当に自作してみましょう。スタート、ゴールをそれぞれ3箇所から同確率でサンプルし、合計9通りの組み合わせを解きます。まあこのくらいならたぶんいけるでしょう。10倍の80万ステップ学習させてみます。

Code
class MedDifficultyGenerator(Generator):
    WALLS = [
        [0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
        [0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 0, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 1, 0, 1, 1, 0, 0, 0, 0],
        [1, 0, 1, 0, 1, 1, 0, 1, 1, 1],
        [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
        [1, 1, 1, 0, 0, 1, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    ]
    def __init__(self) -> None:
        super().__init__(num_rows=10, num_cols=10)

    def __call__(self, key: chex.PRNGKey) -> State:
        key, config_key = jax.random.split(key)
        walls = jnp.array(self.WALLS).astype(bool)
        agent_cfg, target_cfg = jax.random.randint(config_key, (2,), 0, 2)
        agent_position = jax.lax.switch(
            agent_cfg,
            [
                lambda: Position(row=0, col=0),
                lambda: Position(row=9, col=0),
                lambda: Position(row=0, col=9),
            ]
        )
        target_position = jax.lax.switch(
            target_cfg,
            [
                lambda: Position(row=3, col=9),
                lambda: Position(row=7, col=8),
                lambda: Position(row=7, col=0),
            ]
        )
        # Build the state.
        return State(
            agent_position=agent_position,
            target_position=target_position,
            walls=walls,
            action_mask=None,
            key=key,
            step_count=jnp.array(0, jnp.int32),
        )
Code
import datetime

started = datetime.datetime.now()
key, training_key = jax.random.split(key)
trained_net = run_training(
    training_key,
    n_total_steps=16 * 512 * 100,
    generator=MedDifficultyGenerator(),
)
elapsed = datetime.datetime.now() - started
print(f"Elapsed time: {elapsed.total_seconds():.2}s")
Mean episodic return: 0.2812202097235462
Mean episodic return: 0.4077407740774077
Mean episodic return: 0.5992010652463382
Mean episodic return: 0.7285136501516684
Mean episodic return: 0.75
Mean episodic return: 0.8620564808110065
Mean episodic return: 0.9868290258449304
Mean episodic return: 0.9977788746298124
Mean episodic return: 0.9970540974825924
Elapsed time: 1.2e+01s
Code
env = AutoResetWrapper(jumanji.make("Maze-v0", generator=MedDifficultyGenerator()))
key, eval_key = jax.random.split(key)
states = visualization_rollout(eval_key, trained_net, env, 100)
anim = env.animate(states)
HTML(anim.to_html5_video().replace('="1000"', '="640"'))

やや逡巡が見られますがゴールには行けてるっぽいですね。

まとめ

このブログではjaxでニューラルネットワークを扱うためのライブラリであるequinoxを紹介し、またequinoxを使った高速なPPOの実装例を示しました。haikuやflaxを使っている人には何が嬉しいのかというのが理解してもらえたのではないかと思います。やはりeqx.filter_jitが便利ですね。これに慣れたら@partial(jax.jit, static_argnums=(2, 3))とかいちいち書いてられないです。 ただjax.lax.scanの引数にしたい場合はeqx.partitionで手動で分割する必要がありちょっと面倒だなと思いました。