# Code
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.text import Annotation
# Some types for annotations
# Element type parameter for ``Array``.
T = TypeVar("T")


class Array(Sequence[T]):
    """Generic marker for (possibly nested) sequences of ``T``.

    It adds nothing to ``Sequence``. It exists so that annotations can spell out
    the nesting depth via the aliases below (``Array1``, ``Array2``, ``Array3``).
    """


# Nested float arrays of depth 1, 2 and 3, and a 2-D point ``(x, y)``.
Array1 = Array[float]
Array2 = Array[Array1]
Array3 = Array[Array2]
Point = Tuple[float, float]
def a_to_b(
    ax: Axes,
    a: Point,
    b: Point,
    text: str = "",
    style: str = "normal",
    **kwargs,
) -> Annotation:
    """Draw an arrow from point ``a`` to point ``b`` on ``ax``.

    Optionally labels the arrow with ``text``, drawn inside a rounded white box
    whose edge colour and opacity follow the arrow's.

    Args:
        ax: Axes to draw on.
        a: Tail of the arrow (passed to ``ax.annotate`` as ``xytext``).
        b: Head of the arrow (passed to ``ax.annotate`` as ``xy``).
        text: Label text. When empty, no bounding box is drawn.
        style: An alias (``"normal"`` or ``"self"``) or a raw matplotlib
            ``connectionstyle`` string, used verbatim if it is not an alias.
        **kwargs: Keys starting with ``"arrow"`` have that prefix stripped and
            are put into ``arrowprops`` (e.g. ``arrowcolor`` -> ``color``,
            ``arrowalpha`` -> ``alpha``). They override the default arrow
            properties below. All other keys are forwarded to ``ax.annotate``.

    Returns:
        The ``Annotation`` returned by ``ax.annotate``.
    """
    STYLE_ALIASES: Dict[str, str] = {
        "normal": "arc3,rad=-0.4",
        # Strongly curved arc, used to draw self-loops.
        "self": "arc3,rad=-1.6",
    }
    arrow_prefix = "arrow"
    arrowkwargs: Dict[str, object] = {}
    # Snapshot the matching keys first: popping while iterating kwargs would fail.
    for key in [k for k in kwargs if k.startswith(arrow_prefix)]:
        arrowkwargs[key[len(arrow_prefix):]] = kwargs.pop(key)
    if text:
        bbox = dict(
            boxstyle="round",
            fc="w",
            ec=arrowkwargs.get("color", "k"),
            alpha=arrowkwargs.get("alpha", 1.0),
        )
    else:
        bbox = None
    # Defaults first, caller overrides last. Passing both as keyword arguments
    # to ``dict(...)`` raised TypeError whenever a caller overrode a default.
    arrowprops = {
        "shrinkA": 10,
        "shrinkB": 10,
        "width": 1.0,
        "headwidth": 6.0,
        "connectionstyle": STYLE_ALIASES.get(style, style),
        **arrowkwargs,
    }
    return ax.annotate(
        text,
        xy=b,
        xytext=a,
        arrowprops=arrowprops,
        bbox=bbox,
        **kwargs,
    )
class ChainMDP:
    """Chain MDP with N states and two actions.

    States ``0 .. N-1`` lie on a line. In each state:

    - Action 0 tries to move one state to the right.
    - Action 1 tries to move one state to the left.

    The move succeeds with a per-state probability; otherwise the agent stays
    where it is. A successful move past either end of the chain also leaves
    the agent in place.
    """

    # Plot colours for action 0 and action 1.
    ACT_COLORS: List[str] = ["xkcd:vermillion", "xkcd:light royal blue"]
    # Horizontal gap between neighbouring state circles (circles have diameter 1).
    INTERVAL: float = 1.2
    # Horizontal margin before the first circle.
    OFFSET: float = 0.8
    # Base distance used to place arrows and labels around each circle.
    SHIFT: float = 0.5
    # Figure height; also the y-extent of the axes.
    HEIGHT: float = 4.0

    def __init__(
        self,
        success_probs: Sequence[Sequence[float]],
        reward_function: Sequence[Sequence[float]],
    ) -> None:
        """Build the transition tensor ``self.p`` and the reward table ``self.r``.

        Args:
            success_probs: ``N x 2`` table. ``success_probs[s][a]`` is the
                probability that action ``a`` in state ``s`` moves the agent
                (action 0: right, action 1: left).
            reward_function: ``N x 2`` table. ``reward_function[s][a]`` is the
                reward for taking action ``a`` in state ``s``.

        Raises:
            AssertionError: if either table does not have shape ``(N, 2)`` or
                a success probability lies outside ``[0, 1]``.
        """
        success_probs = np.array(success_probs)
        # Check ndim first so a 1-D input fails the assertion, not with IndexError.
        assert success_probs.ndim == 2 and success_probs.shape[1] == 2
        self.n_states = success_probs.shape[0]
        np.testing.assert_almost_equal(success_probs >= 0, np.ones_like(success_probs))
        np.testing.assert_almost_equal(success_probs <= 1, np.ones_like(success_probs))
        # p[s][a][s'] = probability of landing in s' after taking action a in s.
        self.p = np.zeros((self.n_states, 2, self.n_states))
        for si in range(self.n_states):
            # At the chain ends the "neighbour" is the state itself.
            left, right = max(0, si - 1), min(self.n_states - 1, si + 1)
            # Action 0 is for right; on failure the agent stays put.
            self.p[si][0][right] += success_probs[si][0]
            self.p[si][0][si] += 1.0 - success_probs[si][0]
            # Action 1 is for left; on failure the agent stays put.
            self.p[si][1][left] += success_probs[si][1]
            self.p[si][1][si] += 1.0 - success_probs[si][1]
        self.r = np.array(reward_function)  # |S| x 2
        assert self.r.shape == (self.n_states, 2)
        # For plotting: populated by ``show``, which also caches its axes here.
        self.circles: List = []
        self.cached_ax: Optional[Axes] = None

    def figure_shape(self) -> Tuple[float, float]:
        """Return ``(width, height)`` of the drawing.

        ``show`` uses these values both as the figure size and as the axes'
        data limits.
        """
        width = self.n_states + (self.n_states - 1) * self.INTERVAL + self.OFFSET * 2.5
        height = self.HEIGHT
        return width, height

    def show(self, title: str = "", ax: Optional[Axes] = None) -> Axes:
        """Draw states, transition arrows and rewards, and return the axes.

        The result is cached. Once drawn, later calls return the same axes and
        ignore both ``title`` and ``ax``.

        Args:
            title: Optional title. It names a newly created figure and is drawn
                in the top-left corner.
            ax: Axes to draw on. If ``None``, a new figure of size
                ``figure_shape()`` is created with axes filling it.

        Returns:
            The axes the chain was drawn on.
        """
        if self.cached_ax is not None:
            return self.cached_ax

        from matplotlib.patches import Circle

        width, height = self.figure_shape()
        circle_position = height / 2 - height / 10
        if ax is None:
            fig = plt.figure(title or "ChainMDP", (width, height))
            ax = fig.add_axes([0, 0, 1, 1], aspect=1.0)
        ax.set_xlim(0, width)
        ax.set_ylim(0, height)
        ax.set_xticks([])
        ax.set_yticks([])

        def xi(si: int) -> float:
            # x-coordinate of the centre of state ``si``'s circle.
            return self.OFFSET + (1.0 + self.INTERVAL) * si + 0.5

        self.circles = [
            Circle((xi(i), circle_position), 0.5, fc="w", ec="k")
            for i in range(self.n_states)
        ]
        for i in range(self.n_states):
            x = self.OFFSET + (1.0 + self.INTERVAL) * i + 0.1
            ax.text(x, height * 0.85, f"State {i}", fontsize=16)

        def annon(act: int, prob: float, *args, **kwargs) -> None:
            # We don't hold references to annotations (i.e., we treat them immutable).
            # The arrow's opacity is set to the transition probability.
            a_to_b(
                ax,
                *args,
                **kwargs,
                arrowcolor=self.ACT_COLORS[act],
                text=f"P: {prob:.02}",
                arrowalpha=prob,
                fontsize=11,
            )

        for si in range(self.n_states):
            ax.add_patch(self.circles[si])
            x = xi(si)

            # Action 0 (right): drawn above the circle centres.
            y = circle_position + self.SHIFT
            if si < self.n_states - 1 and 1e-3 < self.p[si][0][si + 1]:
                p_right = self.p[si][0][si + 1]
                annon(
                    0,
                    p_right,
                    (x + self.SHIFT, y),
                    (xi(si + 1) - self.SHIFT * 1.2, y - self.SHIFT * 0.3),
                    verticalalignment="center_baseline",
                )
            else:
                p_right = 0.0
            # Self-loop carrying the probability mass that stays in state si.
            if p_right + 1e-3 < 1.0:
                annon(
                    0,
                    1.0 - p_right,
                    (x - self.SHIFT * 1.2, y),
                    (x + self.SHIFT * 0.5, y - self.SHIFT * 0.1),
                    style="self",
                    verticalalignment="bottom",
                )
            ax.text(
                x - self.SHIFT * 1.2,
                y + self.SHIFT * 1.4,
                f"r({si}, 0): {self.r[si][0]:+.02}",
                color=self.ACT_COLORS[0],
                fontsize=14,
            )

            # Action 1 (left): drawn below the circle centres.
            y = circle_position - self.SHIFT
            if 0 < si and 1e-3 < self.p[si][1][si - 1]:
                p_left = self.p[si][1][si - 1]
                annon(
                    1,
                    p_left,
                    (x - self.SHIFT * 1.6, y),
                    (xi(si - 1) + self.SHIFT * 1.4, y + self.SHIFT * 0.2),
                    verticalalignment="top",
                )
            else:
                p_left = 0.0
            # Self-loop carrying the probability mass that stays in state si.
            if p_left + 1e-3 < 1.0:
                annon(
                    1,
                    1.0 - p_left,
                    (x + self.SHIFT * 0.4, y),
                    (x - self.SHIFT * 0.45, y + self.SHIFT * 0.1),
                    style="self",
                    verticalalignment="top",
                )
            ax.text(
                x - self.SHIFT * 1.2,
                y - self.SHIFT * 1.4,
                f"r({si}, 1): {self.r[si][1]:+.02}",
                color=self.ACT_COLORS[1],
                fontsize=14,
            )

        # Single-point lines exist only to give the legend one entry per action.
        for i in range(2):
            ax.plot([0.0], [0.0], color=self.ACT_COLORS[i], label=f"Action {i}")
        ax.legend(fontsize=11, loc="upper right")
        if title:
            ax.text(0.06, height * 0.9, title, fontsize=18)
        self.cached_ax = ax
        return ax