# video-promo/paged_cache.py
#
# 165 lines
# 5.4 KiB
# Python

"""AstrAI promo: Paged KV Cache — astrai/inference/cache.py & scheduler.py."""
from manim import *
# Make "Times New Roman" the default font for Text objects created after this
# call, so the title, captions and box labels share one typeface.
Text.set_default(font="Times New Roman")
class _TaskRow:
"""Manages one task's logical-page row: label, blocks, arrows."""
COLS = [-4.0, -3.2, -2.4, -1.6]
def __init__(self, scene, label, color, y, pool_pos, pool_y):
self.scene = scene
self.color = color
self.y = y
self.pool_pos = pool_pos
self.pool_y = pool_y
self._next_col = 0
self.blocks = VGroup()
self.arrows = VGroup()
lbl = scene._small(label, 11, color, weight=BOLD)
lbl.move_to([-5.2, y, 0])
self._label = lbl
self.blocks.add(lbl)
def arrive(self, *phys_idxs):
self.scene.play(Write(self._label))
for pid in phys_idxs:
self._add(pid, expand=False)
def expand(self, phys_idx):
self._add(phys_idx, expand=True)
def _add(self, phys_idx, expand):
col = self._next_col
self._next_col += 1
x = self.COLS[col]
pos = np.array([x, self.y, 0])
pb = self.scene._lp_box(pos, str(col), self.color)
self.blocks.add(pb)
arr = Arrow(
[x, self.y + 0.19, 0],
[self.pool_pos[phys_idx][0], self.pool_y - 0.22, 0],
color=self.color, stroke_width=1.5, buff=0.03,
max_tip_length_to_length_ratio=0.12,
)
self.arrows.add(arr)
if expand:
target = pb.copy()
pb.scale(0)
self.scene.add(pb)
self.scene.play(Transform(pb, target), GrowArrow(arr), run_time=0.3)
else:
self.scene.play(FadeIn(pb, scale=0.5), GrowArrow(arr), run_time=0.12)
def finish(self):
self.scene.play(FadeOut(self.blocks), FadeOut(self.arrows))
class PagedCache(Scene):
    """Animated walkthrough of a paged KV cache.

    Eight physical page frames are drawn in a row with a "free" bit string
    beneath them (rightmost character = page 0, ``1`` = free). Three tasks
    then allocate, grow and release pages, each shown as a ``_TaskRow``.
    """

    def _small(self, text, size=10, color=GRAY, **kwargs):
        """Return a ``Text`` caption; extra keywords are passed on to ``Text``."""
        caption = Text(text, font_size=size, color=color, **kwargs)
        return caption

    def _page_box(self, pos, label, color, sz=0.44):
        """Return a group of (square frame, label) for a physical page at ``pos``."""
        frame = Square(
            side_length=sz,
            color=color,
            fill_opacity=0.12,
            stroke_width=1.6,
        ).move_to(pos)
        caption = self._small(label, 10, color).move_to(pos)
        return VGroup(frame, caption)

    def _lp_box(self, pos, label, color, sz=0.38):
        """Return a group of (rounded frame, label) for a logical page at ``pos``."""
        frame = RoundedRectangle(
            width=sz,
            height=sz,
            corner_radius=0.06,
            color=color,
            fill_opacity=0.22,
            stroke_width=1.6,
        ).move_to(pos)
        caption = self._small(label, 12, color).move_to(pos)
        return VGroup(frame, caption)

    def construct(self):
        """Play the whole scene: pool set-up, task life cycles, closing caption."""
        heading = Text("Paged KV Cache", font_size=20, color=BLUE)
        heading.to_edge(UP, buff=0.15)
        self.play(Write(heading))

        # Physical page pool: eight frames laid out left to right.
        pool_y = 1.45
        pool_left = -3.8
        pitch = 0.68
        pool_pos = [np.array([pool_left + k * pitch, pool_y, 0]) for k in range(8)]
        pool_pages = []
        for k, centre in enumerate(pool_pos):
            frame = self._page_box(centre, str(k), GRAY)
            pool_pages.append(frame)
            self.play(FadeIn(frame, scale=0.5), run_time=0.04)
        self.wait(0.1)

        pool_caption = self._small("page frames [0..7]", 9, GRAY)
        pool_caption.next_to(pool_pages[0][0], DOWN, buff=0.25).shift(LEFT * 0.3)
        self.play(Write(pool_caption))

        mask = self._small("free: 11111111", 10, GRAY)
        mask.next_to(pool_caption, DOWN, buff=0.1, aligned_edge=LEFT)
        self.play(Write(mask))

        def repaint(idx, fill, opacity, outline, fill_time, outline_time):
            # Recolour one frame, then flash an outline around it.
            square = pool_pages[idx][0]
            self.play(
                square.animate.set_fill(fill, opacity=opacity),
                run_time=fill_time,
            )
            ring = SurroundingRectangle(pool_pages[idx], color=outline, buff=0.04)
            self.play(Create(ring), run_time=outline_time)
            self.play(FadeOut(ring), run_time=0.04)

        def alloc(idx, color):
            # Tint the frame in the owning task's colour.
            repaint(idx, color, 0.35, color, 0.1, 0.05)

        def release(idx):
            # Return the frame to its idle grey look, flashing in yellow.
            repaint(idx, GRAY, 0.12, YELLOW, 0.08, 0.06)

        def set_mask(bits):
            # Morph the displayed bit string into the new value in place.
            updated = self._small(f"free: {bits}", 10, GRAY)
            updated.next_to(pool_caption, DOWN, buff=0.1, aligned_edge=LEFT)
            self.play(Transform(mask, updated))

        a_y = 0.25
        b_y = -0.45
        c_y = -1.15

        # Task A arrives with pages 0 and 1.
        for page in (0, 1):
            alloc(page, GREEN)
        set_mask("11111100")
        a = _TaskRow(self, "A", GREEN, a_y, pool_pos, pool_y)
        a.arrive(0, 1)

        # Task A grows by one page (frame 4).
        alloc(4, GREEN)
        set_mask("11101100")
        a.expand(4)

        # Task B arrives with pages 2 and 3.
        for page in (2, 3):
            alloc(page, ORANGE)
        set_mask("11100000")
        b = _TaskRow(self, "B", ORANGE, b_y, pool_pos, pool_y)
        b.arrive(2, 3)

        # Task C arrives with pages 5 and 6.
        for page in (5, 6):
            alloc(page, BLUE)
        set_mask("10000000")
        c = _TaskRow(self, "C", BLUE, c_y, pool_pos, pool_y)
        c.arrive(5, 6)

        # Task A grows again and takes the last free frame (7).
        alloc(7, GREEN)
        set_mask("00000000")
        a.expand(7)

        # Task A finishes; its four frames go back to the pool.
        a.finish()
        for page in (0, 1, 4, 7):
            release(page)
        set_mask("10010011")

        # Task B grows by reusing frame 0, which A just released.
        alloc(0, ORANGE)
        set_mask("10010010")
        b.expand(0)
        self.wait(0.5)

        summary = Text(
            "Page-table-indirected, O(1) alloc/free, on-demand growth",
            font_size=12,
            color=GREEN,
        )
        summary.move_to([-3.0, -2.0, 0])
        self.play(Write(summary))
        self.wait(2)

        # Clear everything still on screen.
        self.play(*map(FadeOut, self.mobjects))