165 lines
5.4 KiB
Python
165 lines
5.4 KiB
Python
"""AstrAI promo: Paged KV Cache — astrai/inference/cache.py & scheduler.py."""
|
|
|
|
from manim import *
|
|
|
|
# Set the default font for every Text mobject built in this file: the label
# helpers, the title and the closing caption all create Text without a
# ``font=`` argument, so they share this typeface.
# NOTE(review): if this font is not installed on the render machine the text
# renderer presumably falls back to another face — confirm on the target box.
Text.set_default(font="Times New Roman")
|
|
|
|
|
|
class _TaskRow:
    """Manage one task's row of logical pages: label, page boxes and arrows.

    Each logical page is drawn as a small rounded box in this row, with an
    arrow from it up to the physical page frame in the pool it is mapped to.
    """

    # x-positions of the first logical-page columns. Columns past the end are
    # extrapolated with the spacing of the last two entries (see ``_col_x``),
    # so this list must hold at least two entries.
    COLS = [-4.0, -3.2, -2.4, -1.6]

    # Half the side of a physical page box (``_page_box`` defaults to
    # ``sz=0.44``); arrows end this far below ``pool_y``, i.e. at the frame's
    # bottom edge.
    _POOL_HALF_SIDE = 0.22

    def __init__(self, scene, label, color, y, pool_pos, pool_y):
        """Create the row; nothing is animated until ``arrive``.

        Args:
            scene: owning scene; must provide ``play``, ``_small`` and
                ``_lp_box``.
            label: task name drawn at the left end of the row.
            color: colour for the label, page boxes and arrows.
            y: vertical position of the row.
            pool_pos: physical frame positions indexed by physical page
                number; only the x-coordinate (``[0]``) is read.
            pool_y: vertical position of the physical page pool.
        """
        self.scene = scene
        self.color = color
        self.y = y
        self.pool_pos = pool_pos
        self.pool_y = pool_y
        self._next_col = 0        # next unused logical-page column
        self.blocks = VGroup()    # label + logical-page boxes
        self.arrows = VGroup()    # one arrow per logical page

        self._label = scene._small(label, 11, color, weight=BOLD)
        self._label.move_to([-5.2, y, 0])
        # Kept in ``blocks`` so ``finish`` fades it out with the boxes.
        self.blocks.add(self._label)

    def arrive(self, *phys_idxs):
        """Write the task label, then map one logical page per physical index."""
        self.scene.play(Write(self._label))
        for pid in phys_idxs:
            self._add(pid, expand=False)

    def expand(self, phys_idx):
        """Grow the task by one logical page mapped to frame ``phys_idx``."""
        self._add(phys_idx, expand=True)

    def _col_x(self, col):
        """Return the x-coordinate of logical-page column ``col``.

        Uses ``COLS[col]`` when it exists; otherwise continues past the last
        entry with the spacing of the last two, so a task can grow beyond
        ``len(COLS)`` pages instead of raising ``IndexError``.
        """
        cols = self.COLS
        if col < len(cols):
            return cols[col]
        step = cols[-1] - cols[-2]
        return cols[-1] + (col - len(cols) + 1) * step

    def _add(self, phys_idx, expand):
        """Draw the next logical page of this row and point it at ``phys_idx``.

        With ``expand=True`` the box grows from its centre (on-demand growth);
        otherwise it fades in quickly (pages mapped when the task arrives).
        """
        col = self._next_col
        self._next_col += 1
        x = self._col_x(col)

        pb = self.scene._lp_box(np.array([x, self.y, 0]), str(col), self.color)
        self.blocks.add(pb)

        # From the top edge of the logical box (``pb[0]`` is its outline) to
        # the bottom edge of the physical frame it maps to.
        arr = Arrow(
            pb[0].get_top(),
            [self.pool_pos[phys_idx][0], self.pool_y - self._POOL_HALF_SIDE, 0],
            color=self.color, stroke_width=1.5, buff=0.03,
            max_tip_length_to_length_ratio=0.12,
        )
        self.arrows.add(arr)

        if expand:
            self.scene.play(GrowFromCenter(pb), GrowArrow(arr), run_time=0.3)
        else:
            self.scene.play(FadeIn(pb, scale=0.5), GrowArrow(arr), run_time=0.12)

    def finish(self):
        """Fade out the label, every logical-page box and every arrow."""
        self.scene.play(FadeOut(self.blocks), FadeOut(self.arrows))
|
|
|
|
|
|
class PagedCache(Scene):
    """Animate a paged KV cache.

    A fixed pool of physical page frames is drawn with a free bitmap beneath
    it. Tasks A, B and C are mapped onto frames page by page; A grows on
    demand, finishes and returns its frames, and B then reuses one of them.
    """

    def _small(self, text, size=10, color=GRAY, **kwargs):
        """Return a ``Text`` with this scene's small-label defaults."""
        return Text(text, font_size=size, color=color, **kwargs)

    def _page_box(self, pos, label, color, sz=0.44):
        """Return a physical page frame: a ``sz``-sided square with a centred label."""
        box = Square(side_length=sz, color=color, fill_opacity=0.12, stroke_width=1.6)
        box.move_to(pos)
        lbl = self._small(label, 10, color).move_to(pos)
        return VGroup(box, lbl)

    def _lp_box(self, pos, label, color, sz=0.38):
        """Return a logical page: a ``sz``-sided rounded square with a centred label."""
        box = RoundedRectangle(width=sz, height=sz, corner_radius=0.06,
                               color=color, fill_opacity=0.22, stroke_width=1.6)
        box.move_to(pos)
        lbl = self._small(label, 12, color).move_to(pos)
        return VGroup(box, lbl)

    @staticmethod
    def _free_bits(free, n_pages):
        """Return the free bitmap of an ``n_pages`` pool, highest page first.

        The ``i``-th character from the right is ``"1"`` when page ``i`` is in
        ``free``; e.g. ``_free_bits({0, 1}, 4) == "0011"``.
        """
        return "".join("1" if i in free else "0" for i in reversed(range(n_pages)))

    def construct(self):
        title = Text("Paged KV Cache", font_size=20, color=BLUE)
        title.to_edge(UP, buff=0.15)
        self.play(Write(title))

        # ── physical page pool ──
        n_pages = 8
        pool_y, pool_x0, sp = 1.45, -3.8, 0.68
        pool_pos = [np.array([pool_x0 + i * sp, pool_y, 0]) for i in range(n_pages)]
        pool_pages = []
        for i, pos in enumerate(pool_pos):
            page = self._page_box(pos, str(i), GRAY)
            pool_pages.append(page)
            self.play(FadeIn(page, scale=0.5), run_time=0.04)
        self.wait(0.1)

        plbl = self._small(f"page frames [0..{n_pages - 1}]", 9, GRAY)
        plbl.next_to(pool_pages[0][0], DOWN, buff=0.25).shift(LEFT * 0.3)
        self.play(Write(plbl))

        # Single source of truth for the on-screen bitmap: it is always
        # rendered from this set, never typed by hand.
        free = set(range(n_pages))
        mask = self._small(f"free: {self._free_bits(free, n_pages)}", 10, GRAY)
        mask.next_to(plbl, DOWN, buff=0.1, aligned_edge=LEFT)
        self.play(Write(mask))

        # NOTE(review): several run_times below (0.04-0.06 s) are shorter than
        # one frame at low frame rates; confirm they render as intended.
        def alloc(idx, color):
            """Take frame ``idx`` off the free set and tint it ``color``."""
            if idx not in free:
                raise ValueError(f"page {idx} is not free")
            free.remove(idx)
            pg = pool_pages[idx][0]
            self.play(pg.animate.set_fill(color, opacity=0.35), run_time=0.1)
            flash = SurroundingRectangle(pool_pages[idx], color=color, buff=0.04)
            self.play(Create(flash), run_time=0.05)
            self.play(FadeOut(flash), run_time=0.04)

        def release(idx):
            """Return frame ``idx`` to the free set and restore its idle look."""
            if idx in free:
                raise ValueError(f"page {idx} is already free")
            free.add(idx)
            pg = pool_pages[idx][0]
            self.play(pg.animate.set_fill(GRAY, opacity=0.12), run_time=0.08)
            flash = SurroundingRectangle(pool_pages[idx], color=YELLOW, buff=0.04)
            self.play(Create(flash), run_time=0.06)
            self.play(FadeOut(flash), run_time=0.04)

        def show_mask():
            """Morph the bitmap text to the current contents of ``free``."""
            m2 = self._small(f"free: {self._free_bits(free, n_pages)}", 10, GRAY)
            m2.next_to(plbl, DOWN, buff=0.1, aligned_edge=LEFT)
            self.play(Transform(mask, m2))

        def grant(color, *idxs):
            """Allocate ``idxs`` in order, refresh the bitmap and return ``idxs``.

            Returning the frames lets each event name them exactly once, so
            the pool and the task's page table cannot disagree.
            """
            for idx in idxs:
                alloc(idx, color)
            show_mask()
            return idxs

        a_y, b_y, c_y = 0.25, -0.45, -1.15

        # ── A arrives ──
        a = _TaskRow(self, "A", GREEN, a_y, pool_pos, pool_y)
        a.arrive(*grant(GREEN, 0, 1))

        # ── A expands 1 ──
        a.expand(*grant(GREEN, 4))

        # ── B arrives ──
        b = _TaskRow(self, "B", ORANGE, b_y, pool_pos, pool_y)
        b.arrive(*grant(ORANGE, 2, 3))

        # ── C arrives ──
        c = _TaskRow(self, "C", BLUE, c_y, pool_pos, pool_y)
        c.arrive(*grant(BLUE, 5, 6))

        # ── A expands 2 ──
        a.expand(*grant(GREEN, 7))

        # ── A finishes: its frames go back to the pool ──
        a.finish()
        for idx in (0, 1, 4, 7):
            release(idx)
        show_mask()

        # ── B expands (reuses a frame A released) ──
        b.expand(*grant(ORANGE, 0))

        self.wait(0.5)
        s = Text("Page-table-indirected, O(1) alloc/free, on-demand growth",
                 font_size=12, color=GREEN)
        s.move_to([-3.0, -2.0, 0])
        self.play(Write(s))
        self.wait(2)
        self.play(*[FadeOut(m) for m in self.mobjects])
|