最近equinoxというjaxベースの深層学習モデルを定義するライブラリを使ってみたのですが、これが中々いいと思ったので紹介ついでに強化学習してみます。他のjaxベースのライブラリにはDeepmindのhaikuやGoogle Researchのflaxがあります。この2つのライブラリは実際のところあまり変わりはありません。というのも、jaxには試験的に書かれたstaxという深層学習ライブラリのリファレンス実装があり、haikuもflaxもstaxをベースにオブジェクト志向的なModuleを採用したものだからです。あるいは、haikuやflaxは「staxをPyTorchっぽくしたもの」と言ってもいいかもしれません。equinoxのドキュメントにあるCompatibility with init-apply librariesというページでは、これらのライブラリのやり方を「init-applyアプローチ」と呼んで軽く説明しています。これについてざっと見てみましょう。
from typing import NamedTuplefrom jax.nn.initializers import orthogonalclass PPONetOutput(NamedTuple): policy_logits: jax.Array value: jax.Arrayclass SoftmaxPPONet(eqx.Module): torso: list value_head: eqx.nn.Linear policy_head: eqx.nn.Lineardef__init__(self, key: jax.Array) ->None: key1, key2, key3, key4, key5 = jax.random.split(key, 5)# Common layersself.torso = [ eqx.nn.Conv2d(3, 1, kernel_size=3, key=key1), jax.nn.relu, jnp.ravel, eqx.nn.Linear(64, 64, key=key2), jax.nn.relu, ]self.value_head = eqx.nn.Linear(64, 1, key=key3) policy_head = eqx.nn.Linear(64, 4, key=key4)# Use small value for policy initializationself.policy_head = eqx.tree_at(lambda linear: linear.weight, policy_head, orthogonal(scale=0.01)(key5, policy_head.weight.shape), )def__call__(self, x: jax.Array) -> PPONetOutput:for layer inself.torso: x = layer(x) value =self.value_head(x) policy_logits =self.policy_head(x)return PPONetOutput(policy_logits=policy_logits, value=value)def value(self, x: jax.Array) -> jax.Array:for layer inself.torso: x = layer(x)returnself.value_head(x)
ロールアウト
PPOの実装では1000~8000ステップ程度環境で行動した履歴を集めてそれを使って何度かネットワークを更新するのが普通です。ここではjax.lax.scanを使って、Pythonのループを使うより高速なロールアウトを実装します。scanの速度面での恩恵は大きいですが、使い方には少し注意が必要です。特に、1ステップ進める関数の二番目の出力をresults: list[Result] とすると、最終的に返ってくるのがResult(member1=stack([m1 for m1 in results.member1]), ...)になることは把握しておきましょう。 また、exec_rolloutの引数がeqx.ModuleのインスタンスであるSoftmaxPPONetを含んでいるので、eqx.filter_jitでjitしてあげるとうまいことjitできない値を除外してjitしてくれます。 行動は方策ネットワークの出力をsoftmaxしてサンプルしますが、Observationに迷路で移動可能な方向を教えてくれるaction_maskが入っているので、これを使って取れない行動には-infをかけてマスクしておきましょう。簡単な環境ならこれをやらなくても大丈夫だと思うのですが、壁を多いときはこれがないとけっこう難しいです。
Mean episodic return: 0.40782122905027934
Mean episodic return: 0.967741935483871
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Mean episodic return: 1.0
Elapsed time: 3.2s
Mean episodic return: 0.2812202097235462
Mean episodic return: 0.4077407740774077
Mean episodic return: 0.5992010652463382
Mean episodic return: 0.7285136501516684
Mean episodic return: 0.75
Mean episodic return: 0.8620564808110065
Mean episodic return: 0.9868290258449304
Mean episodic return: 0.9977788746298124
Mean episodic return: 0.9970540974825924
Elapsed time: 1.2e+01s