from matplotlib.patches import Rectangle
@chex.dataclass
class Capsule(Shape):
    """Capsule shape: the fields inherited from ``Shape`` plus a length and a radius."""

    # Length of the capsule's centre line; `_length_to_points` turns it into
    # the two end points at x = -length / 2 and x = +length / 2.
    length: jax.Array
    # Radius; `_capsule_to_circle_impl` adds it to the circle radius when
    # computing penetration and uses it to offset the capsule-side contact point.
    radius: jax.Array
@chex.dataclass
class Segment(Shape):
    """Line-segment shape: the fields inherited from ``Shape`` plus a length."""

    # Length of the segment.
    length: jax.Array

    def to_capsule(self) -> Capsule:
        """Return a ``Capsule`` carrying the same fields and a radius of zero.

        ``_segment_to_circle`` uses this so that segments can be fed to the
        capsule/circle contact routine.
        """
        # Fields copied over unchanged from this segment.
        shared_names = ("mass", "moment", "elasticity", "friction", "rgba", "length")
        shared = {name: getattr(self, name) for name in shared_names}
        # The radius is a zero array with the same shape and dtype as `length`.
        return Capsule(radius=jnp.zeros_like(self.length), **shared)
def _length_to_points(length: jax.Array) -> tuple[jax.Array, jax.Array]:
a = jnp.stack((length * -0.5, length * 0.0), axis=-1)
b = jnp.stack((length * 0.5, length * 0.0), axis=-1)
return a, b
@jax.vmap
def _capsule_to_circle_impl(
    a: Capsule,
    b: Circle,
    a_pos: Position,
    b_pos: Position,
    isactive: jax.Array,
) -> Contact:
    """Compute the contact between one capsule and one circle.

    The function is wrapped in ``jax.vmap``, so callers pass batched
    arguments and the body below sees a single capsule/circle pair.

    Args:
        a: Capsule shape (``length``, ``radius``, ``elasticity``, ``friction``
            are read).
        b: Circle shape (``radius``, ``elasticity``, ``friction`` are read).
        a_pos: Position of the capsule.
        b_pos: Position of the circle.
        isactive: Flag for the pair; when false the penetration is forced
            to -1.

    Returns:
        A ``Contact`` whose position and normal are passed through
        ``a_pos.transform`` and whose elasticity and friction are the
        averages of the two shapes' values.
    """
    # Move the circle centre into the capsule's coordinates.
    center = a_pos.inv_transform(b_pos.xy)
    start, end = _length_to_points(a.length)
    axis_vec = end - start
    # Projections of the centre onto the capsule axis, measured from each end.
    from_start = jnp.dot(center - start, axis_vec)
    to_end = jnp.dot(end - center, axis_vec)
    axis_sq = jnp.sum(jnp.square(axis_vec), axis=-1, keepdims=True)

    # Closest point on the capsule's centre line:
    #   from_start < 0 -> the centre is beyond `start`, use `start`
    #   to_end < 0     -> the centre is beyond `end`, use `end`
    #   otherwise      -> the projection of the centre onto the line
    projected = start + axis_vec * from_start / axis_sq
    nearest_end = jax.lax.select(from_start < 0.0, start, end)
    within = jnp.logical_and(from_start >= 0.0, to_end >= 0.0)
    closest = jax.lax.select(within, projected, nearest_end)

    # `normalize` yields two values, used here as the direction from the
    # capsule to the circle and their distance (presumably a unit vector and
    # its norm; confirm against the helper).
    a2b_normal, dist = normalize(center - closest)
    depth = a.radius + b.radius - dist
    contact_on_a = closest + a2b_normal * a.radius
    contact_on_b = center - a2b_normal * b.radius
    midpoint = a_pos.transform((contact_on_a + contact_on_b) * 0.5)

    # Transform the normal with the translation zeroed out, so only the
    # non-translational part of `a_pos` is applied.
    no_offset = a_pos.replace(xy=jnp.zeros_like(b_pos.xy))
    normal_out = no_offset.transform(a2b_normal)

    # Inactive pairs report a negative penetration.
    depth = jnp.where(isactive, depth, -jnp.ones_like(depth))
    return Contact(
        pos=midpoint,
        normal=normal_out,
        penetration=depth,
        elasticity=(a.elasticity + b.elasticity) * 0.5,
        friction=(a.friction + b.friction) * 0.5,
    )
@chex.dataclass
class ShapeDict:
    """Optional batches of shapes, one entry per shape kind."""

    circle: Circle | None = None
    segment: Segment | None = None
    capsule: Capsule | None = None

    def concat(self) -> Shape:
        """Concatenate every present shape batch into one ``Shape``.

        Each non-``None`` entry is converted with ``to_shape()`` and the
        results are joined leaf by leaf along axis 0, in the order given by
        ``self.values()``.
        """
        shapes = [s.to_shape() for s in self.values() if s is not None]
        # `jax.tree_map` is a deprecated alias that newer JAX releases no
        # longer provide; `jax.tree_util.tree_map` is the stable spelling.
        return jax.tree_util.tree_map(
            lambda *leaves: jnp.concatenate(leaves, axis=0),
            *shapes,
        )
@chex.dataclass
class StateDict:
    """Optional batches of ``State``, one entry per shape kind."""

    circle: State | None = None
    segment: State | None = None
    capsule: State | None = None

    def concat(self) -> State:
        """Concatenate every present state batch into one ``State``.

        Non-``None`` entries are joined leaf by leaf along axis 0, in the
        order given by ``self.values()``.
        """
        states = [s for s in self.values() if s is not None]
        # `jax.tree_map` is a deprecated alias that newer JAX releases no
        # longer provide; `jax.tree_util.tree_map` is the stable spelling.
        return jax.tree_util.tree_map(
            lambda *leaves: jnp.concatenate(leaves, axis=0),
            *states,
        )

    def offset(self, key: str) -> int:
        """Return the index at which ``key``'s rows start in ``concat()``.

        This is the summed batch size of every present entry that comes
        before ``key`` in ``self.items()``.

        Raises:
            RuntimeError: If ``key`` is not one of this dataclass's entries.
        """
        total = 0
        for k, state in self.items():
            if k == key:
                return total
            if state is not None:
                total += state.p.batch_size()
        raise RuntimeError("Unreachable")

    def _get(self, name: str, state: State) -> State | None:
        """Slice the rows belonging to ``name`` out of a concatenated state.

        Returns ``None`` when this dict has no entry for ``name``.
        """
        if self[name] is None:
            return None
        start = self.offset(name)
        end = start + self[name].p.batch_size()
        return state.get_slice(jnp.arange(start, end))

    def update(self, statec: State) -> Self:
        """Split a concatenated state back into a new dict of per-kind states.

        Batch sizes and the set of present entries are taken from ``self``;
        the values are taken from ``statec``.
        """
        circle = self._get("circle", statec)
        segment = self._get("segment", statec)
        capsule = self._get("capsule", statec)
        return self.__class__(circle=circle, segment=segment, capsule=capsule)
# Signature shared by the pairwise contact routines registered in
# `_CONTACT_FUNCTIONS`: they are called as `fn(shaped, stated)` and return the
# contacts together with the two paired shape batches.
ContactFn = Callable[[ShapeDict, StateDict], tuple[Contact, Shape, Shape]]
def _pair_outer(x: jax.Array, reps: int) -> jax.Array:
return jnp.repeat(x, reps, axis=0, total_repeat_length=x.shape[0] * reps)
def _pair_inner(x: jax.Array, reps: int) -> jax.Array:
return jnp.tile(x, (reps,) + (1,) * (x.ndim - 1))
def generate_pairs(x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Return two arrays that together enumerate every (x[i], y[j]) pair.

    The first output repeats each entry of ``x`` once per entry of ``y``; the
    second tiles ``y`` once per entry of ``x``. Both have a leading dimension
    of ``len(x) * len(y)``, ordered with ``x`` as the slow index.
    """
    n_x = x.shape[0]
    n_y = y.shape[0]
    x_side = _pair_outer(x, n_y)
    y_side = _pair_inner(y, n_x)
    return x_side, y_side
def _circle_to_circle(
    shaped: ShapeDict,
    stated: StateDict,
) -> tuple[Contact, Circle, Circle]:
    """Compute contacts between the circles and themselves.

    Shapes, positions and activity flags are all expanded with
    ``generate_self_pairs``; a pair counts as active only when both of its
    circles are active.

    Returns:
        The contacts and the two paired circle batches.
    """
    first, second = tree_map2(generate_self_pairs, shaped.circle)
    first_pos, second_pos = tree_map2(generate_self_pairs, stated.circle.p)
    first_active, second_active = generate_self_pairs(stated.circle.is_active)
    both_active = jnp.logical_and(first_active, second_active)
    contacts = _circle_to_circle_impl(
        first,
        second,
        first_pos,
        second_pos,
        both_active,
    )
    return contacts, first, second
def _capsule_to_circle(
    shaped: ShapeDict,
    stated: StateDict,
) -> tuple[Contact, Capsule, Circle]:
    """Compute contacts for every (capsule, circle) combination.

    Capsules are the slow index (each repeated once per circle) and circles
    the fast index (the whole batch tiled once per capsule), matching the
    ordering produced by ``generate_pairs``. A pair counts as active only
    when both bodies are active.

    Returns:
        The contacts and the paired capsule and circle batches.
    """
    n_capsule = shaped.capsule.mass.shape[0]
    n_circle = shaped.circle.mass.shape[0]
    # `jax.tree_map` is a deprecated alias that newer JAX releases no longer
    # provide; `jax.tree_util.tree_map` is the stable spelling.
    capsule = jax.tree_util.tree_map(
        functools.partial(_pair_outer, reps=n_circle),
        shaped.capsule,
    )
    circle = jax.tree_util.tree_map(
        functools.partial(_pair_inner, reps=n_capsule),
        shaped.circle,
    )
    pos1, pos2 = tree_map2(generate_pairs, stated.capsule.p, stated.circle.p)
    is_active = jnp.logical_and(
        *generate_pairs(stated.capsule.is_active, stated.circle.is_active)
    )
    contacts = _capsule_to_circle_impl(
        capsule,
        circle,
        pos1,
        pos2,
        is_active,
    )
    return contacts, capsule, circle
def _segment_to_circle(
    shaped: ShapeDict,
    stated: StateDict,
) -> tuple[Contact, Segment, Circle]:
    """Compute contacts for every (segment, circle) combination.

    Segments are the slow index (each repeated once per circle) and circles
    the fast index (the whole batch tiled once per segment), matching the
    ordering produced by ``generate_pairs``. The paired segments are
    converted with ``to_capsule()`` and passed to the capsule/circle
    routine. A pair counts as active only when both bodies are active.

    Returns:
        The contacts and the paired segment and circle batches.
    """
    n_segment = shaped.segment.mass.shape[0]
    n_circle = shaped.circle.mass.shape[0]
    # `jax.tree_map` is a deprecated alias that newer JAX releases no longer
    # provide; `jax.tree_util.tree_map` is the stable spelling.
    segment = jax.tree_util.tree_map(
        functools.partial(_pair_outer, reps=n_circle),
        shaped.segment,
    )
    circle = jax.tree_util.tree_map(
        functools.partial(_pair_inner, reps=n_segment),
        shaped.circle,
    )
    pos1, pos2 = tree_map2(generate_pairs, stated.segment.p, stated.circle.p)
    is_active = jnp.logical_and(
        *generate_pairs(stated.segment.is_active, stated.circle.is_active)
    )
    contacts = _capsule_to_circle_impl(
        segment.to_capsule(),
        circle,
        pos1,
        pos2,
        is_active,
    )
    return contacts, segment, circle
# Registry of pairwise contact routines, keyed by the (first, second) shape-kind
# names. `ExtendedSpace.check_contacts` and `n_possible_contacts` iterate it in
# insertion order and use the names to look up entries in `StateDict` /
# `ShapeDict`, so the order here fixes the order of the concatenated contacts.
_CONTACT_FUNCTIONS = {
    ("circle", "circle"): _circle_to_circle,
    ("capsule", "circle"): _capsule_to_circle,
    ("segment", "circle"): _segment_to_circle,
}
@chex.dataclass
class ContactWithMetadata:
    """A batch of contacts plus the shapes and body indices of each pair."""

    contact: Contact
    shape1: Shape
    shape2: Shape
    # Index of the first / second body of each contact; `check_contacts`
    # offsets them with `StateDict.offset` before storing them here.
    outer_index: jax.Array
    inner_index: jax.Array

    def gather_p_or_v(
        self,
        outer: _PositionLike,
        inner: _PositionLike,
        orig: _PositionLike,
    ) -> _PositionLike:
        """Sum per-contact values back onto the bodies they belong to.

        Args:
            outer: Per-contact values for the first body of each pair.
            inner: Per-contact values for the second body of each pair.
            orig: Per-body value; only its class and the shapes/dtypes of
                its ``xy`` and ``angle`` are used.

        Returns:
            A new object of ``orig``'s class in which every body holds the
            sum of the ``outer`` and ``inner`` entries that index it (zero
            for bodies no contact refers to).
        """

        def scatter_add(template: jax.Array, index: jax.Array, values: jax.Array):
            # Start from zeros shaped like `template` and accumulate `values`
            # at `index`; repeated indices add up.
            return jnp.zeros_like(template).at[index].add(values)

        xy_from_outer = scatter_add(orig.xy, self.outer_index, outer.xy)
        angle_from_outer = scatter_add(orig.angle, self.outer_index, outer.angle)
        xy_from_inner = scatter_add(orig.xy, self.inner_index, inner.xy)
        angle_from_inner = scatter_add(orig.angle, self.inner_index, inner.angle)
        return orig.__class__(
            angle=angle_from_outer + angle_from_inner,
            xy=xy_from_outer + xy_from_inner,
        )
@chex.dataclass
class ExtendedSpace:
    """Simulation settings together with the shapes being simulated."""

    gravity: jax.Array
    shaped: ShapeDict
    dt: jax.Array | float = 0.1
    linear_damping: jax.Array | float = 0.95
    angular_damping: jax.Array | float = 0.95
    # Passed to `correct_position` by `solve_constraints`, together with
    # `linear_slop` and `max_linear_correction`.
    bias_factor: jax.Array | float = 0.2
    # Trip counts of the velocity and position loops in `solve_constraints`.
    n_velocity_iter: int = 8
    n_position_iter: int = 2
    linear_slop: jax.Array | float = 0.005
    max_linear_correction: jax.Array | float = 0.2
    allowed_penetration: jax.Array | float = 0.005
    bounce_threshold: float = 1.0

    def check_contacts(self, stated: StateDict) -> ContactWithMetadata:
        """Run every applicable contact routine and concatenate the results.

        A routine registered in ``_CONTACT_FUNCTIONS`` runs when both of its
        shape kinds are present in ``stated``. The pair indices stored with
        each result are shifted by ``stated.offset(...)`` so that they index
        into the concatenated state.

        Args:
            stated: Current per-kind states.

        Returns:
            One ``ContactWithMetadata`` holding all contacts, joined along
            axis 0 in registry order.
        """
        contacts = []
        for (n1, n2), fn in _CONTACT_FUNCTIONS.items():
            if stated[n1] is None or stated[n2] is None:
                continue
            contact, shape1, shape2 = fn(self.shaped, stated)
            len1, len2 = stated[n1].p.batch_size(), stated[n2].p.batch_size()
            offset1, offset2 = stated.offset(n1), stated.offset(n2)
            if n1 == n2:
                outer_index, inner_index = generate_self_pairs(jnp.arange(len1))
            else:
                outer_index, inner_index = generate_pairs(
                    jnp.arange(len1),
                    jnp.arange(len2),
                )
            contacts.append(
                ContactWithMetadata(
                    contact=contact,
                    shape1=shape1.to_shape(),
                    shape2=shape2.to_shape(),
                    outer_index=outer_index + offset1,
                    inner_index=inner_index + offset2,
                )
            )
        # `jax.tree_map` is a deprecated alias that newer JAX releases no
        # longer provide; `jax.tree_util.tree_map` is the stable spelling.
        return jax.tree_util.tree_map(
            lambda *leaves: jnp.concatenate(leaves, axis=0),
            *contacts,
        )

    def n_possible_contacts(self) -> int:
        """Return the number of body pairs the registered routines cover.

        Same-kind pairings count each unordered pair once
        (``n * (n - 1) // 2``); mixed pairings count ``n1 * n2``. Only shape
        kinds present in ``self.shaped`` contribute.
        """
        n = 0
        for n1, n2 in _CONTACT_FUNCTIONS:
            if self.shaped[n1] is None or self.shaped[n2] is None:
                continue
            len1, len2 = len(self.shaped[n1].mass), len(self.shaped[n2].mass)
            if n1 == n2:
                n += len1 * (len1 - 1) // 2
            else:
                n += len1 * len2
        return n
def animate_balls_and_segments(
    fig,
    ax: Axes,
    circles: Circle,
    segments: Segment,
    c_pos: Iterable[Position],
    s_pos: Position,
) -> HTML:
    """Render circles moving over fixed segments as a JS/HTML animation.

    Args:
        fig: Figure handed to ``Camera``.
        ax: Axes the patches are added to.
        circles: Circle shapes; ``radius`` and ``rgba`` are used for drawing.
        segments: Segment shapes; ``length`` gives each rectangle's width.
        c_pos: One batch of circle positions per frame.
        s_pos: Segment positions, the same for every frame.

    Returns:
        ``HTML`` wrapping the animation's ``to_jshtml()`` output.
    """
    camera = Camera(fig)
    circle_list = circles.tolist()
    # Anchor of each rectangle: the point (-length / 2, 0) passed through the
    # segment's position transform.
    segment_ll = s_pos.transform(
        jnp.stack((-segments.length * 0.5, jnp.zeros_like(segments.length)), axis=1)
    )
    # The segments do not depend on the frame, so their rectangle parameters
    # (anchor, width, angle in degrees) are worked out once, not per frame.
    segment_specs = [
        (ll, segment.length, (pj.angle / jnp.pi).item() * 180)
        for ll, pj, segment in zip(segment_ll, s_pos.tolist(), segments.tolist())
    ]
    for frame_pos in c_pos:
        for pij, circle in zip(frame_pos.tolist(), circle_list):
            circle_patch = CirclePatch(
                xy=pij.xy,
                radius=circle.radius,
                fill=False,
                color=circle.rgba.tolist(),
            )
            ax.add_patch(circle_patch)
        # New patch objects are still created for every frame, as before.
        for ll, width, angle_deg in segment_specs:
            rect_patch = Rectangle(
                xy=ll,
                width=width,
                angle=angle_deg,
                height=0.1,
            )
            ax.add_patch(rect_patch)
        camera.snap()
    return HTML(camera.animate().to_jshtml())
def solve_constraints(
    space: Space,
    solver: VelocitySolver,
    p: Position,
    v: Velocity,
    contact_with_meta: ContactWithMetadata,
) -> tuple[Velocity, Position, VelocitySolver]:
    """Resolve collisions by the Sequential Impulse method.

    Steps, in order: build the contact helper, apply the initial impulse,
    run ``space.n_velocity_iter`` velocity iterations, apply the bounce
    term, then run ``space.n_position_iter`` position-correction iterations.

    Args:
        space: Simulation settings.
        solver: Velocity solver state from the caller.
        p: Positions of all bodies.
        v: Velocities of all bodies.
        contact_with_meta: Contacts with the indices of the bodies involved.

    Returns:
        The updated velocities, the updated positions and the velocity
        solver as it stands after the velocity iterations.
    """
    contact = contact_with_meta.contact
    outer_idx = contact_with_meta.outer_index
    inner_idx = contact_with_meta.inner_index

    def split_pairs(p_or_v: _PositionLike) -> tuple[_PositionLike, _PositionLike]:
        # Per-contact view: values of the first and of the second body.
        return p_or_v.get_slice(outer_idx), p_or_v.get_slice(inner_idx)

    p1, p2 = split_pairs(p)
    v1, v2 = split_pairs(v)
    helper = init_contact_helper(
        space,
        contact,
        contact_with_meta.shape1,
        contact_with_meta.shape2,
        p1,
        p2,
        v1,
        v2,
    )
    # Warm up the velocity solver.
    solver = apply_initial_impulse(
        contact,
        helper,
        solver.replace(v1=v1, v2=v2),
    )

    def velocity_iteration(
        _n_iter: int,
        carry: tuple[Velocity, VelocitySolver],
    ) -> tuple[Velocity, VelocitySolver]:
        vel, vsolver = carry
        vsolver_next = apply_velocity_normal(contact, helper, vsolver)
        # Scatter the per-contact values onto the bodies and add them to the
        # current velocities, then refresh the solver's per-pair velocities.
        vel_next = (
            contact_with_meta.gather_p_or_v(vsolver_next.v1, vsolver_next.v2, vel)
            + vel
        )
        pair_v1, pair_v2 = split_pairs(vel_next)
        return vel_next, vsolver_next.replace(v1=pair_v1, v2=pair_v2)

    v, solver = jax.lax.fori_loop(
        0,
        space.n_velocity_iter,
        velocity_iteration,
        (v, solver),
    )
    bounce_v1, bounce_v2 = apply_bounce(contact, helper, solver)
    v = contact_with_meta.gather_p_or_v(bounce_v1, bounce_v2, v) + v

    def position_iteration(
        _n_iter: int,
        carry: tuple[Position, PositionSolver],
    ) -> tuple[Position, PositionSolver]:
        pos, psolver = carry
        psolver_next = correct_position(
            space.bias_factor,
            space.linear_slop,
            space.max_linear_correction,
            contact,
            helper,
            psolver,
        )
        pos_next = (
            contact_with_meta.gather_p_or_v(psolver_next.p1, psolver_next.p2, pos)
            + pos
        )
        pair_p1, pair_p2 = split_pairs(pos_next)
        return pos_next, psolver_next.replace(p1=pair_p1, p2=pair_p2)

    # The position solver starts from the pre-solve pair positions and the
    # contact state held by the velocity solver after its iterations.
    pos_solver = PositionSolver(
        p1=p1,
        p2=p2,
        contact=solver.contact,
        min_separation=jnp.zeros_like(p1.angle),
    )
    p, pos_solver = jax.lax.fori_loop(
        0,
        space.n_position_iter,
        position_iteration,
        (p, pos_solver),
    )
    return v, p, solver
def dont_solve_constraints(
    _space: Space,
    solver: VelocitySolver,
    p: Position,
    v: Velocity,
    _contact_with_meta: ContactWithMetadata,
) -> tuple[Velocity, Position, VelocitySolver]:
    """Return velocities, positions and solver unchanged.

    Takes the same arguments and returns the same tuple layout as
    ``solve_constraints``; ``step`` passes both to ``jax.lax.cond`` and this
    one is the branch taken when no contact has non-negative penetration.
    """
    return v, p, solver
# Demo scene, part 1: three segments. Mass and moment are infinite
# (presumably so the solver treats them as immovable; TODO confirm).
N_SEG = 3
segments = Segment(
    mass=jnp.ones(N_SEG) * jnp.inf,
    moment=jnp.ones(N_SEG) * jnp.inf,
    elasticity=jnp.ones(N_SEG) * 0.5,
    friction=jnp.ones(N_SEG) * 1.0,
    rgba=jnp.ones((N_SEG, 4)),
    # One length per segment, in the same order as the segment positions and
    # angles given to `StateDict` below.
    length=jnp.array([4 * jnp.sqrt(2), 4, 4 * jnp.sqrt(2)]),
)
# Demo scene, part 2: initial state. Circles start at rest at `cpos`; `N` is
# defined elsewhere in the file and must equal the number of rows in `cpos`.
cpos = jnp.array([[2, 2], [4, 3], [3, 6], [6, 5], [5, 7]], dtype=jnp.float32)
stated = StateDict(
    circle=State(
        p=Position(xy=cpos, angle=jnp.zeros(N)),
        v=Velocity.zeros(N),
        f=Force.zeros(N),
        # Sized by N like every other circle field (and like the segment
        # mask below) instead of a hand-written list of five True values.
        is_active=jnp.ones(N, dtype=bool),
    ),
    segment=State(
        p=Position(
            xy=jnp.array([[-2.0, 2.0], [2, 0], [6, 2]], dtype=jnp.float32),
            angle=jnp.array([jnp.pi * 1.75, 0, jnp.pi * 0.25]),
        ),
        v=Velocity.zeros(N_SEG),
        f=Force.zeros(N_SEG),
        is_active=jnp.ones(N_SEG, dtype=bool),
    ),
)
# Demo scene, part 3: simulation settings. Fields not listed here keep the
# `ExtendedSpace` defaults. `circles` is defined elsewhere in the file.
space = ExtendedSpace(
    gravity=jnp.array([0.0, -9.8]),
    # Both damping factors are set to 1.0 (the class defaults are 0.95).
    linear_damping=1.0,
    angular_damping=1.0,
    dt=0.04,
    bias_factor=0.2,
    n_velocity_iter=6,
    shaped=ShapeDict(circle=circles, segment=segments),
)
@jax.jit
def step(stated: StateDict, solver: VelocitySolver) -> StateDict:
    """Advance the simulation by one step and return the new per-kind states.

    The module-level ``space`` supplies the settings and shapes. Velocities
    are updated first, contacts are checked on the result, constraints are
    solved only if some contact has non-negative penetration, and positions
    are updated last.

    NOTE(review): the solver returned by the constraint branch is not part
    of this function's return value, so callers cannot carry it over to the
    next step. Confirm whether that is intended.
    """
    all_shapes = space.shaped.concat()
    state = update_velocity(space, all_shapes, stated.concat())
    contact_with_meta = space.check_contacts(stated.update(state))
    # Mask of contacts with non-negative penetration.
    penetrating = contact_with_meta.contact.penetration >= 0
    new_v, new_p, _solver = jax.lax.cond(
        jnp.any(penetrating),
        solve_constraints,
        dont_solve_constraints,
        space,
        solver.update(penetrating),
        state.p,
        state.v,
        contact_with_meta,
    )
    moved = update_position(space, state.replace(v=new_v, p=new_p))
    return stated.update(moved)