※ This article is translated from the Japanese version, with some improvements to the code.
Speeding up simulation with the power of GPUs and TPUs is a hot topic in the reinforcement learning community, though LLMs and RLHF are becoming ‘the thing’ these days. Anyway, speeding up simulation is quite fun. NVIDIA Isaac Sim is quite solid, but if you want to speed up the entire RL pipeline, JAX-based libraries such as brax are useful.
You don’t know JAX? It’s a NumPy-like library, but its code can be compiled into fast vectorized machine code that runs really fast on GPUs and TPUs. It’s even fast on CPUs. I have written a blog post (Japanese only) about brax before, but the current version of brax is further improved. It allows you to choose a more accurate method, and it will be tightly integrated with MuJoCo XLA.
However, I recently wanted a simple 2-D physics simulation and tried brax. My impression was that not only is brax complete overkill for my use case, but its API also makes it quite difficult to set up a new environment. Loading MJCF is easy, but that was the only easy way to set things up. I wanted to build my environment with a few lines of Python code, as in game-physics engines such as pymunk. So, anyway, I tried to make my own.
Move things by semi-implicit Euler
Dealing with contacts is almost always the most difficult part of physics simulation. So, for now, let’s forget about that. Let’s just make things move. Let \(x, v\) denote the position and velocity, and \(\dot{x} = \frac{dx}{dt}, \dot{v} = \frac{dv}{dt}\) be their time derivatives. From the second law, \(F(t) = m \frac{d^2 x}{dt^2}(t)\). Then, for a sufficiently small \(\Delta t\), the discrete-time update of \(x\) and \(v\) is given by:
\[
\begin{align*}
x_{t + 1} &= x_t + v_t \Delta t \\
v_{t + 1} &= v_t + \frac{F}{m} \Delta t .
\end{align*}
\]
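This explicit Euler update tends to be unstable and is rarely used. The simple trick I use is semi-implicit Euler, which just flips the order of the updates so that the position is advanced with the new velocity:
\[
\begin{align*}
v_{t + 1} &= v_t + \frac{F}{m} \Delta t \\
x_{t + 1} &= x_t + v_{t + 1} \Delta t .
\end{align*}
\]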
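As a rough illustration, one semi-implicit Euler step might look like the following in JAX. This is only a minimal sketch, not the code from the original post: the function name `semi_implicit_euler`, its arguments, and the gravity example are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def semi_implicit_euler(x, v, force, mass, dt):
    """One step: update the velocity first, then move with the NEW velocity."""
    v_next = v + (force / mass) * dt
    x_next = x + v_next * dt
    return x_next, v_next


# Hypothetical example: a unit-mass circle starting at rest under gravity.
x = jnp.array([0.0, 1.0])
v = jnp.array([0.0, 0.0])
gravity_force = jnp.array([0.0, -9.8])
x, v = semi_implicit_euler(x, v, gravity_force, 1.0, 0.1)
```

Swapping the two lines inside the function, so that `x_next` uses the old `v`, would give the explicit Euler update instead.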
Next, let’s implement collision detection. It’s pretty easy because we have only circles now. We just need to store some information, such as the contact location and the overlap, to resolve collisions later.
The other thing I need to care about is vectorization. Say, this is naive Python code to check contacts:
for i in range(N):
    for j in range(i + 1, N):
        check_contact(objects[i], objects[j])
It has an \(O(N^2)\) loop! However, you might be smart enough to point out that the operations in this loop have no dependencies on each other, so the order of computation doesn’t matter. This is where jax.vmap comes in: it vectorizes those operations. So, let’s unroll this loop by hand, make all the pairs up front, and then call check_contact on every pair at once.
How to make pairs? There are various ways, but the simplest one I found is triu_indices, a function that returns the indices of the upper triangle of a matrix, with an optional offset from the diagonal.
Objects are overlapping since I haven’t implemented collision resolution yet, but the detected collisions look correct.
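The pairing idea can be sketched as below. This is a hedged example rather than the original implementation: the signature of `check_contact`, the array layout (`positions`, `radii`), and the small epsilon guarding the division are all assumptions. `jnp.triu_indices(N, k=1)` skips the diagonal, so each unordered pair appears exactly once.

```python
import jax
import jax.numpy as jnp


def check_contact(pos1, r1, pos2, r2):
    """Circle-circle test for ONE pair; overlap > 0 means the circles touch."""
    diff = pos2 - pos1
    dist = jnp.linalg.norm(diff)
    overlap = r1 + r2 - dist
    # Unit normal pointing from circle 1 to circle 2 (epsilon avoids 0 / 0).
    normal = diff / jnp.maximum(dist, 1e-6)
    return overlap, normal


# Hypothetical scene: three unit circles, only the first two overlap.
positions = jnp.array([[0.0, 0.0], [1.5, 0.0], [5.0, 5.0]])
radii = jnp.array([1.0, 1.0, 1.0])

N = positions.shape[0]
# Indices (i, j) with i < j, i.e. the strictly upper triangle.
idx1, idx2 = jnp.triu_indices(N, k=1)

# vmap maps the single-pair function over all pairs at once.
overlap, normal = jax.vmap(check_contact)(
    positions[idx1], radii[idx1], positions[idx2], radii[idx2]
)
in_contact = overlap > 0.0
```

The number of pairs is still \(N(N-1)/2\), but they are now evaluated as one batched operation instead of a Python loop.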
What to do after collision detection
I’m not quite familiar with physics actually, but in the real world, collisions are resolved automatically by the interactions of tiny molecules. In a game physics engine, however, we want to get results faster by enforcing physically accurate constraints, such as:
Speed changes by collision
No overlap between objects
So, I’m just solving these constraints. There are many ways to solve them, but here I employ a common approach called Sequential Impulse, which is used in famous physics engines including Chipmunk and Box2D. Other approaches include the recently trending position-based dynamics, but sequential impulse caught my eye because of its simplicity.
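Anyway, let me explain how the sequential impulse method resolves collisions. First, let’s assume a collision model in which a collision produces an impulse, and let the generated impulse be \(\mathbf{p}\). I don’t use angular velocity here for simplicity. Let the velocity and mass of object 1 be \(\mathbf{v}_1\) and \(m_1\), and the velocity and mass of object 2 be \(\mathbf{v}_2\) and \(m_2\), respectively. Then the velocities after the impulse is applied are
\[
\begin{align*}
\mathbf{v}_1^{\mathrm{new}} &= \mathbf{v}_1^{\mathrm{old}} - \frac{\mathbf{p}}{m_1} \\
\mathbf{v}_2^{\mathrm{new}} &= \mathbf{v}_2^{\mathrm{old}} + \frac{\mathbf{p}}{m_2} .
\end{align*}
\]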
Here, because \(\mathbf{p}\) points along the collision normal \(\mathbf{n}\), we can write \(\mathbf{p} = p\mathbf{n}\). Hence, we just need to compute \(p\). Let the relative velocity at the collision point be \(\Delta \mathbf{v} = \mathbf{v}_2 - \mathbf{v}_1\). Combining the above two equalities with the condition that the relative velocity along the normal vanishes after the collision, \(\Delta \mathbf{v}^{\mathrm{new}} \cdot \mathbf{n} = 0\), we get \(p = \frac{-\Delta \mathbf{v}^{\mathrm{old}}\cdot \mathbf{n}}{\frac{1}{m_1} + \frac{1}{m_2}}\). Rotation makes this a bit more complicated, but the core equation is just this.
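To make the formula concrete, here is a minimal sketch of the impulse computation for one pair of circles without rotation. The function name `apply_impulse` and the example values are hypothetical, the normal is assumed to be a unit vector pointing from object 1 to object 2, and clamping \(p\) to be non-negative (so that objects already separating are left alone) is an extra assumption that is not derived above.

```python
import jax.numpy as jnp


def apply_impulse(v1, m1, v2, m2, normal):
    """Cancel the relative velocity along `normal` for one colliding pair."""
    dv = v2 - v1  # relative velocity before the impulse
    p = -jnp.dot(dv, normal) / (1.0 / m1 + 1.0 / m2)
    # Assumption: only push objects apart, never pull them together.
    p = jnp.maximum(p, 0.0)
    v1_new = v1 - p * normal / m1
    v2_new = v2 + p * normal / m2
    return v1_new, v2_new, p


# Hypothetical head-on collision: a light object hits a heavier one.
v1, m1 = jnp.array([1.0, 0.0]), 1.0
v2, m2 = jnp.array([-1.0, 0.0]), 3.0
normal = jnp.array([1.0, 0.0])  # unit vector from object 1 to object 2

v1_new, v2_new, p = apply_impulse(v1, m1, v2, m2, normal)
```

In this example both objects end up with the same velocity along the normal, so the relative normal velocity is zero, which is the constraint the formula enforces.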