fix paged_cache: _TaskRow class, expand rendering bug, and order corrected
This commit is contained in:
parent
bb0c32b032
commit
4f14d09fe3
590
paged_cache.py
590
paged_cache.py
|
|
@ -1,129 +1,94 @@
|
||||||
"""AstrAI promo: Paged KV Cache — matching astrai/inference/cache.py & scheduler.py."""
|
"""AstrAI promo: Paged KV Cache — astrai/inference/cache.py & scheduler.py."""
|
||||||
|
|
||||||
from manim import *
|
from manim import *
|
||||||
|
|
||||||
Text.set_default(font="Times New Roman")
|
Text.set_default(font="Times New Roman")
|
||||||
|
|
||||||
|
|
||||||
class PrefixCache(Scene):
|
class _TaskRow:
|
||||||
"""Animates PagedCache exact logic: alloc→write→free with real code details."""
|
"""Manages one task's logical-page row: label, blocks, arrows."""
|
||||||
|
|
||||||
def _small_text(self, text, size=10, color=GRAY, **kwargs):
|
COLS = [-4.0, -3.2, -2.4, -1.6]
|
||||||
|
|
||||||
|
def __init__(self, scene, label, color, y, pool_pos, pool_y):
|
||||||
|
self.scene = scene
|
||||||
|
self.color = color
|
||||||
|
self.y = y
|
||||||
|
self.pool_pos = pool_pos
|
||||||
|
self.pool_y = pool_y
|
||||||
|
self._next_col = 0
|
||||||
|
self.blocks = VGroup()
|
||||||
|
self.arrows = VGroup()
|
||||||
|
|
||||||
|
lbl = scene._small(label, 11, color, weight=BOLD)
|
||||||
|
lbl.move_to([-5.2, y, 0])
|
||||||
|
self._label = lbl
|
||||||
|
|
||||||
|
self.blocks.add(lbl)
|
||||||
|
|
||||||
|
def arrive(self, *phys_idxs):
|
||||||
|
self.scene.play(Write(self._label))
|
||||||
|
for pid in phys_idxs:
|
||||||
|
self._add(pid, expand=False)
|
||||||
|
|
||||||
|
def expand(self, phys_idx):
|
||||||
|
self._add(phys_idx, expand=True)
|
||||||
|
|
||||||
|
def _add(self, phys_idx, expand):
|
||||||
|
col = self._next_col
|
||||||
|
self._next_col += 1
|
||||||
|
x = self.COLS[col]
|
||||||
|
pos = np.array([x, self.y, 0])
|
||||||
|
|
||||||
|
pb = self.scene._lp_box(pos, str(col), self.color)
|
||||||
|
self.blocks.add(pb)
|
||||||
|
|
||||||
|
arr = Arrow(
|
||||||
|
[x, self.y + 0.19, 0],
|
||||||
|
[self.pool_pos[phys_idx][0], self.pool_y - 0.22, 0],
|
||||||
|
color=self.color, stroke_width=1.5, buff=0.03,
|
||||||
|
max_tip_length_to_length_ratio=0.12,
|
||||||
|
)
|
||||||
|
self.arrows.add(arr)
|
||||||
|
|
||||||
|
if expand:
|
||||||
|
target = pb.copy()
|
||||||
|
pb.scale(0)
|
||||||
|
self.scene.add(pb)
|
||||||
|
self.scene.play(Transform(pb, target), GrowArrow(arr), run_time=0.3)
|
||||||
|
else:
|
||||||
|
self.scene.play(FadeIn(pb, scale=0.5), GrowArrow(arr), run_time=0.12)
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
self.scene.play(FadeOut(self.blocks), FadeOut(self.arrows))
|
||||||
|
|
||||||
|
|
||||||
|
class PagedCache(Scene):
|
||||||
|
def _small(self, text, size=10, color=GRAY, **kwargs):
|
||||||
return Text(text, font_size=size, color=color, **kwargs)
|
return Text(text, font_size=size, color=color, **kwargs)
|
||||||
|
|
||||||
def _page_box(self, pos, label, color, sz=0.44):
|
def _page_box(self, pos, label, color, sz=0.44):
|
||||||
s = Square(side_length=sz, color=color, fill_opacity=0.12, stroke_width=1.6)
|
s = Square(side_length=sz, color=color, fill_opacity=0.12, stroke_width=1.6)
|
||||||
s.move_to(pos)
|
s.move_to(pos)
|
||||||
lbl = self._small_text(label, 10, color).move_to(pos)
|
lbl = self._small(label, 10, color).move_to(pos)
|
||||||
return VGroup(s, lbl)
|
return VGroup(s, lbl)
|
||||||
|
|
||||||
def _make_ptable(self, x, y, rows, color, label):
|
def _lp_box(self, pos, label, color, sz=0.38):
|
||||||
"""Draw OS-style page table. Returns (outline, entry_group, phy_cell_centers)."""
|
s = RoundedRectangle(width=sz, height=sz, corner_radius=0.06,
|
||||||
n = len(rows)
|
color=color, fill_opacity=0.22, stroke_width=1.6)
|
||||||
w = 2.6
|
s.move_to(pos)
|
||||||
h = 0.40
|
lbl = self._small(label, 12, color).move_to(pos)
|
||||||
vx = w * 0.40
|
return VGroup(s, lbl)
|
||||||
|
|
||||||
group = VGroup()
|
|
||||||
|
|
||||||
# Label
|
|
||||||
lbl = self._small_text(label, 12, color, weight=BOLD)
|
|
||||||
lbl.move_to([x + w / 2, y + 0.15, 0])
|
|
||||||
group.add(lbl)
|
|
||||||
|
|
||||||
# Outer rect (header+data rows)
|
|
||||||
r = Rectangle(width=w, height=h * (n + 1), color=color, stroke_width=1.8, fill_opacity=0.06)
|
|
||||||
r.move_to([x + w / 2, y - h * (n + 1) / 2, 0])
|
|
||||||
group.add(r)
|
|
||||||
|
|
||||||
# H dividers
|
|
||||||
for ri in range(1, n + 1):
|
|
||||||
ly = y - ri * h
|
|
||||||
group.add(Line([x, ly, 0], [x + w, ly, 0], color=color, stroke_width=0.8))
|
|
||||||
|
|
||||||
# V divider
|
|
||||||
group.add(Line([x + vx, y, 0], [x + vx, y - h * (n + 1), 0], color=color, stroke_width=0.8))
|
|
||||||
|
|
||||||
# Column headers
|
|
||||||
lh = self._small_text("logical", 10, color, weight=BOLD)
|
|
||||||
lh.move_to([x + vx / 2, y - h / 2, 0])
|
|
||||||
ph = self._small_text("physical", 10, color, weight=BOLD)
|
|
||||||
ph.move_to([x + vx + (w - vx) / 2, y - h / 2, 0])
|
|
||||||
group.add(lh, ph)
|
|
||||||
|
|
||||||
# Row entries
|
|
||||||
entries = VGroup()
|
|
||||||
phy_centers = []
|
|
||||||
for i, (lg, phy) in enumerate(rows):
|
|
||||||
cy = y - h / 2 - (i + 1) * h
|
|
||||||
lt = self._small_text(lg, 11, color)
|
|
||||||
lt.move_to([x + vx / 2, cy, 0])
|
|
||||||
pt = self._small_text(phy, 11, color)
|
|
||||||
pt.move_to([x + vx + (w - vx) / 2, cy, 0])
|
|
||||||
entries.add(VGroup(lt, pt))
|
|
||||||
phy_centers.append([x + vx + (w - vx) / 2, cy, 0])
|
|
||||||
|
|
||||||
return group, entries, phy_centers, r
|
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
# ── Scene setup ──
|
title = Text("Paged KV Cache", font_size=20, color=BLUE)
|
||||||
title = self._small_text("Paged KV Cache — astrai/inference/cache.py", 26, BLUE)
|
|
||||||
title.to_edge(UP, buff=0.15)
|
title.to_edge(UP, buff=0.15)
|
||||||
self.play(Write(title))
|
self.play(Write(title))
|
||||||
self.wait(0.1)
|
|
||||||
|
|
||||||
right_x = 4.5
|
pool_y = 1.45; pool_x0 = -3.8; sp = 0.68
|
||||||
step_y = 3.0
|
pool_pages = []; pool_pos = []
|
||||||
|
|
||||||
def step_msg(text, color=YELLOW):
|
|
||||||
nonlocal step_y
|
|
||||||
m = self._small_text(text, 12, color)
|
|
||||||
m.move_to([right_x, step_y, 0])
|
|
||||||
step_y -= 0.55
|
|
||||||
self.play(Write(m))
|
|
||||||
return m
|
|
||||||
|
|
||||||
def fade(m):
|
|
||||||
self.play(FadeOut(m))
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
# Phase 0: Initialize PagedCache
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
s0 = step_msg("PagedCache(max_batch=8, page_size=64, head_dim=64, n_kv_heads=4)")
|
|
||||||
|
|
||||||
# Show constructor signature
|
|
||||||
ctor = self._small_text(
|
|
||||||
'PagedCache(n_layers=24, n_pages=8, page_size=64,\n'
|
|
||||||
' n_kv_heads=4, head_dim=64, device=cuda, dtype=bfloat16)',
|
|
||||||
9, GRAY
|
|
||||||
)
|
|
||||||
ctor.move_to([-5.5, 2.4, 0])
|
|
||||||
self.play(Write(ctor))
|
|
||||||
|
|
||||||
# Show the tensor shape
|
|
||||||
tensor_shape = self._small_text(
|
|
||||||
'k_cache: [24, 8, 64, 4, 64] v_cache: [24, 8, 64, 4, 64]',
|
|
||||||
8, GRAY
|
|
||||||
)
|
|
||||||
tensor_shape.move_to([-5.5, 2.05, 0])
|
|
||||||
self.play(Write(tensor_shape))
|
|
||||||
|
|
||||||
self.wait(0.3)
|
|
||||||
fade(s0)
|
|
||||||
# Keep ctor & tensor_shape visible
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
# Physical page pool
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
s1 = step_msg("8 page frames, refs[8] = [0,0,0,0,0,0,0,0], free_mask = (1<<8)-1 = 0xFF")
|
|
||||||
|
|
||||||
pool_y = 1.45
|
|
||||||
pool_x0 = -3.8
|
|
||||||
pool_sp = 0.68
|
|
||||||
pool_pages = []
|
|
||||||
pool_pos = []
|
|
||||||
for i in range(8):
|
for i in range(8):
|
||||||
x = pool_x0 + i * pool_sp
|
x = pool_x0 + i * sp
|
||||||
pos = np.array([x, pool_y, 0])
|
pos = np.array([x, pool_y, 0])
|
||||||
pool_pos.append(pos)
|
pool_pos.append(pos)
|
||||||
pb = self._page_box(pos, str(i), GRAY)
|
pb = self._page_box(pos, str(i), GRAY)
|
||||||
|
|
@ -131,394 +96,69 @@ class PrefixCache(Scene):
|
||||||
self.play(FadeIn(pb, scale=0.5), run_time=0.04)
|
self.play(FadeIn(pb, scale=0.5), run_time=0.04)
|
||||||
self.wait(0.1)
|
self.wait(0.1)
|
||||||
|
|
||||||
# Bracket
|
plbl = self._small("page frames [0..7]", 9, GRAY)
|
||||||
brack = Brace(VGroup(*[p[0] for p in pool_pages]), DOWN, buff=0.05)
|
plbl.next_to(pool_pages[0][0], DOWN, buff=0.25).shift(LEFT * 0.3)
|
||||||
blbl = Text("page frames [0..7] each holds 64 KV slots", font_size=10, color=GRAY)
|
self.play(Write(plbl))
|
||||||
blbl.next_to(brack, DOWN, buff=0.02)
|
|
||||||
self.play(Create(brack), Write(blbl), run_time=0.25)
|
|
||||||
|
|
||||||
# Free bitmask — with bit position labels
|
mask = self._small("free: 11111111", 10, GRAY)
|
||||||
mask_y = 0.55
|
mask.next_to(plbl, DOWN, buff=0.1, aligned_edge=LEFT)
|
||||||
bit_labels = VGroup()
|
|
||||||
for i in range(8):
|
|
||||||
bl = self._small_text(f"bit{i}", 7, DARK_GRAY)
|
|
||||||
bl.move_to([pool_x0 + i * pool_sp, mask_y + 0.25, 0])
|
|
||||||
bit_labels.add(bl)
|
|
||||||
self.play(Write(bit_labels), run_time=0.2)
|
|
||||||
|
|
||||||
mask = self._small_text("11111111 (1 = free, 0 = alloc)", 11, GRAY)
|
|
||||||
mask.move_to([-1.5, mask_y, 0])
|
|
||||||
self.play(Write(mask))
|
self.play(Write(mask))
|
||||||
self.wait(0.15)
|
|
||||||
|
|
||||||
refs_lbl = self._small_text("refs = [0,0,0,0,0,0,0,0]", 9, GRAY)
|
def alloc(idx, color):
|
||||||
refs_lbl.move_to([-1.5, mask_y - 0.3, 0])
|
pg = pool_pages[idx][0]
|
||||||
self.play(Write(refs_lbl))
|
self.play(pg.animate.set_fill(color, opacity=0.35), run_time=0.1)
|
||||||
|
flash = SurroundingRectangle(pool_pages[idx], color=color, buff=0.04)
|
||||||
|
self.play(Create(flash), run_time=0.05)
|
||||||
|
self.play(FadeOut(flash), run_time=0.04)
|
||||||
|
|
||||||
def update_mask(bits, desc):
|
def set_mask(bits):
|
||||||
m2 = self._small_text(f"{bits} (1 = free, 0 = alloc)", 11, GRAY)
|
m2 = self._small(f"free: {bits}", 10, GRAY)
|
||||||
m2.move_to([-1.5, mask_y, 0])
|
m2.next_to(plbl, DOWN, buff=0.1, aligned_edge=LEFT)
|
||||||
self.play(Transform(mask, m2))
|
self.play(Transform(mask, m2))
|
||||||
if desc:
|
|
||||||
self.play(Transform(refs_lbl, self._small_text(desc, 9, GRAY).move_to([-1.5, mask_y - 0.3, 0])))
|
|
||||||
|
|
||||||
fade(s1)
|
a_y = 0.25; b_y = -0.45; c_y = -1.15
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
# ── A arrives ──
|
||||||
# Phase 1: Cleanup — nothing to do initially
|
alloc(0, GREEN); alloc(1, GREEN); set_mask("11111100")
|
||||||
# ═══════════════════════════════════════════
|
a = _TaskRow(self, "A", GREEN, a_y, pool_pos, pool_y)
|
||||||
s_phase = self._small_text("Phase 1: Cleanup (no finished tasks)", 11, GRAY)
|
a.arrive(0, 1)
|
||||||
s_phase.move_to([right_x, step_y, 0])
|
|
||||||
step_y -= 0.4
|
|
||||||
self.play(Write(s_phase))
|
|
||||||
self.wait(0.2)
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
# ── A expands 1 ──
|
||||||
# Phase 2: Refill — Request A arrives
|
alloc(4, GREEN); set_mask("11101100")
|
||||||
# ═══════════════════════════════════════════
|
a.expand(4)
|
||||||
step_y2 = step_y
|
|
||||||
s2a = step_msg("Phase 2: Refill — Request A arrives (prompt_len=120 tokens)", GREEN)
|
|
||||||
s2b = step_msg("_n_pages_for(120) = (120 + 64 - 1) // 64 = 183 // 64 = 2", GREEN)
|
|
||||||
|
|
||||||
calc_box = Rectangle(width=4.8, height=1.0, color=GREEN_E, stroke_width=1.2, fill_opacity=0.05)
|
# ── B arrives ──
|
||||||
calc_box.move_to([right_x - 0.3, step_y - 0.1, 0])
|
alloc(2, ORANGE); alloc(3, ORANGE); set_mask("11100000")
|
||||||
calc_lines = VGroup(
|
b = _TaskRow(self, "B", ORANGE, b_y, pool_pos, pool_y)
|
||||||
self._small_text("def _n_pages_for(n_tokens):", 9, GREEN),
|
b.arrive(2, 3)
|
||||||
self._small_text(" return (n_tokens + page_size - 1) // page_size", 9, GREEN),
|
|
||||||
self._small_text("_n_pages_for(120) = (120 + 64 - 1) // 64 = 2", 9, WHITE),
|
|
||||||
)
|
|
||||||
calc_lines.arrange(DOWN, buff=0.08, aligned_edge=LEFT)
|
|
||||||
calc_lines.move_to(calc_box.get_center())
|
|
||||||
calc_grp = VGroup(calc_box, calc_lines)
|
|
||||||
self.play(Create(calc_grp), run_time=0.35)
|
|
||||||
step_y -= 0.6
|
|
||||||
|
|
||||||
# alloc() in action
|
# ── C arrives ──
|
||||||
s2c = step_msg("alloc_n(2) → calls alloc() twice", GREEN)
|
alloc(5, BLUE); alloc(6, BLUE); set_mask("10000000")
|
||||||
|
c = _TaskRow(self, "C", BLUE, c_y, pool_pos, pool_y)
|
||||||
|
c.arrive(5, 6)
|
||||||
|
|
||||||
alloc_code = VGroup(
|
# ── A expands 2 ──
|
||||||
self._small_text("def alloc(self) -> int:", 9, GREEN),
|
alloc(7, GREEN); set_mask("00000000")
|
||||||
self._small_text(" lsb = self._free_mask & -self._free_mask", 9, GREEN),
|
a.expand(7)
|
||||||
self._small_text(" if lsb == 0: return -1", 9, GREEN),
|
|
||||||
self._small_text(" idx = lsb.bit_length() - 1", 9, GREEN),
|
|
||||||
self._small_text(" self._free_mask ^= lsb", 9, GREEN),
|
|
||||||
self._small_text(" self._refs[idx] = 1", 9, GREEN),
|
|
||||||
self._small_text(" return idx", 9, GREEN),
|
|
||||||
self._small_text("", 6, GREEN),
|
|
||||||
self._small_text("1st alloc: free_mask=11111111, lsb=1, idx=0 → page 0", 9, WHITE),
|
|
||||||
self._small_text("2nd alloc: free_mask=11111110, lsb=2, idx=1 → page 1", 9, WHITE),
|
|
||||||
)
|
|
||||||
alloc_code.arrange(DOWN, buff=0.06, aligned_edge=LEFT)
|
|
||||||
acbox = Rectangle(width=5.0, height=2.6, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03)
|
|
||||||
acbox.move_to([right_x - 0.2, step_y - 0.8, 0])
|
|
||||||
alloc_code.move_to(acbox.get_center())
|
|
||||||
alloc_grp = VGroup(acbox, alloc_code)
|
|
||||||
self.play(Create(alloc_grp), run_time=0.4)
|
|
||||||
step_y -= 1.8
|
|
||||||
|
|
||||||
# Now draw page table A (2 rows)
|
# ── A finishes ──
|
||||||
tblA_x = -4.4
|
a.finish()
|
||||||
tblA_y = -0.35
|
for idx in [0, 1, 4, 7]:
|
||||||
tblA_outline, tblA_entries, tblA_phys, tblA_rect = self._make_ptable(
|
|
||||||
tblA_x, tblA_y,
|
|
||||||
[("0", "0"), ("1", "1")],
|
|
||||||
GREEN, "Task A page_table"
|
|
||||||
)
|
|
||||||
self.play(Create(tblA_outline), run_time=0.35)
|
|
||||||
for e in tblA_entries:
|
|
||||||
self.play(FadeIn(e), run_time=0.08)
|
|
||||||
|
|
||||||
# Alloc effect on pool + bitmask
|
|
||||||
for idx in [0, 1]:
|
|
||||||
pg = pool_pages[idx][0]
|
|
||||||
self.play(pg.animate.set_fill(GREEN, opacity=0.35), run_time=0.1)
|
|
||||||
flash = SurroundingRectangle(pool_pages[idx], color=GREEN, buff=0.04)
|
|
||||||
self.play(Create(flash), run_time=0.05)
|
|
||||||
self.play(FadeOut(flash), run_time=0.04)
|
|
||||||
|
|
||||||
update_mask("11111100", "refs = [1,1,0,0,0,0,0,0]")
|
|
||||||
self.wait(0.2)
|
|
||||||
|
|
||||||
# mapping arrows
|
|
||||||
arr_a0 = Arrow(
|
|
||||||
[tblA_phys[0][0] + 0.1, tblA_phys[0][1], 0],
|
|
||||||
[pool_pos[0][0], pool_pos[0][1] - 0.22, 0],
|
|
||||||
color=GREEN, stroke_width=1.5, buff=0.03,
|
|
||||||
max_tip_length_to_length_ratio=0.18,
|
|
||||||
)
|
|
||||||
arr_a1 = Arrow(
|
|
||||||
[tblA_phys[1][0] + 0.1, tblA_phys[1][1], 0],
|
|
||||||
[pool_pos[1][0], pool_pos[1][1] - 0.22, 0],
|
|
||||||
color=GREEN, stroke_width=1.5, buff=0.03,
|
|
||||||
max_tip_length_to_length_ratio=0.18,
|
|
||||||
)
|
|
||||||
self.play(GrowArrow(arr_a0), GrowArrow(arr_a1), run_time=0.35)
|
|
||||||
self.wait(0.3)
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
# Phase 3: Prefill (write KV into allocated pages)
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
s3a = step_msg("Phase 3: Prefill — write() KV into pages 0,1", GREEN)
|
|
||||||
|
|
||||||
write_code = VGroup(
|
|
||||||
self._small_text("def write(self, layer_id, page_table, start_pos, k, v):", 9, GREEN),
|
|
||||||
self._small_text(" first_page = start_pos // page_size # 0 // 64 = 0", 9, GREEN),
|
|
||||||
self._small_text(" last_page = (start_pos+seq_len-1)//page_size", 9, WHITE),
|
|
||||||
self._small_text(" for pi in range(first_page, last_page+1):", 9, WHITE),
|
|
||||||
self._small_text(" phys_pages = page_table[:, pi]", 9, WHITE),
|
|
||||||
self._small_text(" chunk = min(page_start+page_size, start_pos+seq_len) -", 9, WHITE),
|
|
||||||
self._small_text(" max(page_start, start_pos)", 9, WHITE),
|
|
||||||
self._small_text(" k_cache[layer_id, phys_pages, offset:offset+chunk] = k", 9, GREEN),
|
|
||||||
)
|
|
||||||
write_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT)
|
|
||||||
wbox = Rectangle(width=5.0, height=2.3, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03)
|
|
||||||
wbox.move_to([right_x - 0.2, step_y - 0.4, 0])
|
|
||||||
write_code.move_to(wbox.get_center())
|
|
||||||
write_grp = VGroup(wbox, write_code)
|
|
||||||
self.play(Create(write_grp), run_time=0.35)
|
|
||||||
step_y -= 1.6
|
|
||||||
|
|
||||||
# Show a KV block "written" onto the pages
|
|
||||||
k_written = self._small_text("KV written", 8, GREEN)
|
|
||||||
k_written.move_to([pool_pos[0][0] + 0.5, pool_pos[0][1] + 0.6, 0])
|
|
||||||
self.play(Write(k_written))
|
|
||||||
self.wait(0.3)
|
|
||||||
fade(k_written)
|
|
||||||
|
|
||||||
s3b = step_msg("CacheView bundles cache + page_table + total_len for attention", GREEN)
|
|
||||||
cv_code = VGroup(
|
|
||||||
self._small_text("class CacheView:", 9, GREEN),
|
|
||||||
self._small_text(" def __init__(self, cache, page_table, total_len):", 9, GREEN),
|
|
||||||
self._small_text(' def write(self, layer_id, start_pos, k, v):', 9, WHITE),
|
|
||||||
self._small_text(" self._cache.write(layer_id, self._page_table, ...)", 9, WHITE),
|
|
||||||
self._small_text(' def gather(self, layer_id):', 9, WHITE),
|
|
||||||
self._small_text(" for pi in range(page_table.size(1)):", 9, WHITE),
|
|
||||||
self._small_text(" phys_pages = page_table[:, pi]", 9, WHITE),
|
|
||||||
self._small_text(" k_parts.append(k_cache[layer_id, phys_pages])", 9, WHITE),
|
|
||||||
self._small_text(" k = torch.cat(k_parts, dim=1)", 9, GREEN),
|
|
||||||
)
|
|
||||||
cv_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT)
|
|
||||||
cvbox = Rectangle(width=5.0, height=2.5, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03)
|
|
||||||
cvbox.move_to([right_x - 0.2, step_y - 0.6, 0])
|
|
||||||
cv_code.move_to(cvbox.get_center())
|
|
||||||
cv_grp = VGroup(cvbox, cv_code)
|
|
||||||
self.play(Create(cv_grp), run_time=0.35)
|
|
||||||
step_y -= 1.8
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
# Request B arrives
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
s4 = step_msg("Phase 2: Refill — Request B arrives (prompt_len=90)", ORANGE)
|
|
||||||
s4b = step_msg("alloc_n(2) → pages 2, 3", ORANGE)
|
|
||||||
|
|
||||||
tblB_x = -4.4
|
|
||||||
tblB_y = -1.9
|
|
||||||
tblB_outline, tblB_entries, tblB_phys, tblB_rect = self._make_ptable(
|
|
||||||
tblB_x, tblB_y,
|
|
||||||
[("0", "2"), ("1", "3")],
|
|
||||||
ORANGE, "Task B page_table"
|
|
||||||
)
|
|
||||||
self.play(Create(tblB_outline), run_time=0.3)
|
|
||||||
for e in tblB_entries:
|
|
||||||
self.play(FadeIn(e), run_time=0.08)
|
|
||||||
|
|
||||||
for idx in [2, 3]:
|
|
||||||
pg = pool_pages[idx][0]
|
|
||||||
self.play(pg.animate.set_fill(ORANGE, opacity=0.35), run_time=0.1)
|
|
||||||
flash = SurroundingRectangle(pool_pages[idx], color=ORANGE, buff=0.04)
|
|
||||||
self.play(Create(flash), run_time=0.05)
|
|
||||||
self.play(FadeOut(flash), run_time=0.04)
|
|
||||||
|
|
||||||
update_mask("11110000", "refs = [1,1,1,1,0,0,0,0]")
|
|
||||||
|
|
||||||
arr_b0 = Arrow(
|
|
||||||
[tblB_phys[0][0] + 0.1, tblB_phys[0][1], 0],
|
|
||||||
[pool_pos[2][0], pool_pos[2][1] - 0.22, 0],
|
|
||||||
color=ORANGE, stroke_width=1.5, buff=0.03,
|
|
||||||
max_tip_length_to_length_ratio=0.18,
|
|
||||||
)
|
|
||||||
arr_b1 = Arrow(
|
|
||||||
[tblB_phys[1][0] + 0.1, tblB_phys[1][1], 0],
|
|
||||||
[pool_pos[3][0], pool_pos[3][1] - 0.22, 0],
|
|
||||||
color=ORANGE, stroke_width=1.5, buff=0.03,
|
|
||||||
max_tip_length_to_length_ratio=0.18,
|
|
||||||
)
|
|
||||||
self.play(GrowArrow(arr_b0), GrowArrow(arr_b1), run_time=0.35)
|
|
||||||
self.wait(0.3)
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
# Phase 4: Decode — on-demand page growth
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
s5 = step_msg("Phase 4: Decode — Task A generates more tokens", PINK)
|
|
||||||
s5b = step_msg("total_len grows: 120 → 180 tokens", PINK)
|
|
||||||
s5c = step_msg("_n_pages_for(180) = (180+63)//64 = 243//64 = 3", PINK)
|
|
||||||
|
|
||||||
# _maybe_alloc_page logic
|
|
||||||
map_code = VGroup(
|
|
||||||
self._small_text("def _maybe_alloc_page(self, task, pos):", 9, PINK),
|
|
||||||
self._small_text(" needed = _n_pages_for(pos + 1) # _n_pages_for(181) = 3", 9, PINK),
|
|
||||||
self._small_text(" while task.n_pages < needed: # 2 < 3", 9, PINK),
|
|
||||||
self._small_text(" p = self.page_cache.alloc() # alloc page 4", 9, PINK),
|
|
||||||
self._small_text(" task.page_table.append(p)", 9, PINK),
|
|
||||||
self._small_text(" task.n_pages += 1 # 2 → 3", 9, PINK),
|
|
||||||
self._small_text("", 5, PINK),
|
|
||||||
self._small_text("page_table_A: [0, 1] → [0, 1, 4]", 9, WHITE),
|
|
||||||
)
|
|
||||||
map_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT)
|
|
||||||
mbox = Rectangle(width=5.0, height=2.2, color=PINK, stroke_width=1.0, fill_opacity=0.03)
|
|
||||||
mbox.move_to([right_x - 0.2, step_y - 0.2, 0])
|
|
||||||
map_code.move_to(mbox.get_center())
|
|
||||||
map_grp = VGroup(mbox, map_code)
|
|
||||||
self.play(Create(map_grp), run_time=0.35)
|
|
||||||
step_y -= 1.6
|
|
||||||
|
|
||||||
# Alloc page 4
|
|
||||||
idx = 4
|
|
||||||
pg = pool_pages[idx][0]
|
|
||||||
self.play(pg.animate.set_fill(PINK, opacity=0.35), run_time=0.1)
|
|
||||||
flash = SurroundingRectangle(pool_pages[idx], color=PINK, buff=0.04)
|
|
||||||
self.play(Create(flash), run_time=0.05)
|
|
||||||
self.play(FadeOut(flash), run_time=0.04)
|
|
||||||
|
|
||||||
update_mask("11100000", "refs = [1,1,1,1,1,0,0,0]")
|
|
||||||
|
|
||||||
# Expand Table A: add row 3
|
|
||||||
new_bottom = tblA_y - 0.40 * 4 # 3 data rows + 1 header
|
|
||||||
line_3 = Line(
|
|
||||||
[tblA_x, tblA_y - 0.40 * 3, 0],
|
|
||||||
[tblA_x + 2.6, tblA_y - 0.40 * 3, 0],
|
|
||||||
color=GREEN, stroke_width=0.8,
|
|
||||||
)
|
|
||||||
lc3 = [tblA_x + 2.6 * 0.20, tblA_y - 0.40 * 3 - 0.20, 0]
|
|
||||||
pc3 = [tblA_x + 2.6 * 0.40 + (2.6 - 2.6 * 0.40) / 2, tblA_y - 0.40 * 3 - 0.20, 0]
|
|
||||||
lt3 = self._small_text("2", 11, PINK).move_to(lc3)
|
|
||||||
pt3 = self._small_text("4", 11, PINK).move_to(pc3)
|
|
||||||
self.play(Create(line_3), run_time=0.08)
|
|
||||||
self.play(FadeIn(lt3), FadeIn(pt3), run_time=0.08)
|
|
||||||
|
|
||||||
arr_c = Arrow(
|
|
||||||
[pc3[0] + 0.1, pc3[1], 0],
|
|
||||||
[pool_pos[4][0], pool_pos[4][1] - 0.22, 0],
|
|
||||||
color=PINK, stroke_width=1.5, buff=0.03,
|
|
||||||
max_tip_length_to_length_ratio=0.18,
|
|
||||||
)
|
|
||||||
self.play(GrowArrow(arr_c), run_time=0.2)
|
|
||||||
self.wait(0.3)
|
|
||||||
|
|
||||||
# Highlight: page_table list conversion
|
|
||||||
pt_list_old = self._small_text("page_table_A: [0, 1] (2 pages → can hold 128 tokens)", 10, GREEN)
|
|
||||||
pt_list_old.move_to([-4.4, tblB_y - 1.35, 0])
|
|
||||||
self.play(Write(pt_list_old))
|
|
||||||
|
|
||||||
pt_list_new = self._small_text("page_table_A: [0, 1, 4] (3 pages → can hold 192 tokens)", 10, GREEN)
|
|
||||||
pt_list_new.move_to([-4.4, tblB_y - 1.7, 0])
|
|
||||||
self.play(Write(pt_list_new))
|
|
||||||
self.wait(0.4)
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
# Task A finished → free()
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
s6 = step_msg("Task A done → free pages 0, 1, 4", YELLOW)
|
|
||||||
s6b = step_msg("free() → refs[idx] -= 1; if refs[idx]==0: mask |= 1<<idx", YELLOW)
|
|
||||||
|
|
||||||
free_code = VGroup(
|
|
||||||
self._small_text("def free(self, idx):", 9, YELLOW),
|
|
||||||
self._small_text(" self._refs[idx] -= 1", 9, YELLOW),
|
|
||||||
self._small_text(" if self._refs[idx] == 0:", 9, YELLOW),
|
|
||||||
self._small_text(" self._free_mask |= 1 << idx", 9, YELLOW),
|
|
||||||
)
|
|
||||||
free_code.arrange(DOWN, buff=0.06, aligned_edge=LEFT)
|
|
||||||
fbox = Rectangle(width=4.2, height=1.3, color=YELLOW, stroke_width=1.0, fill_opacity=0.03)
|
|
||||||
fbox.move_to([right_x - 0.2, step_y - 0.2, 0])
|
|
||||||
free_code.move_to(fbox.get_center())
|
|
||||||
free_grp = VGroup(fbox, free_code)
|
|
||||||
self.play(Create(free_grp), run_time=0.3)
|
|
||||||
step_y -= 0.9
|
|
||||||
|
|
||||||
for idx in [0, 1]:
|
|
||||||
pg = pool_pages[idx][0]
|
pg = pool_pages[idx][0]
|
||||||
self.play(pg.animate.set_fill(GRAY, opacity=0.12), run_time=0.08)
|
self.play(pg.animate.set_fill(GRAY, opacity=0.12), run_time=0.08)
|
||||||
flash = SurroundingRectangle(pool_pages[idx], color=YELLOW, buff=0.04)
|
flash = SurroundingRectangle(pool_pages[idx], color=YELLOW, buff=0.04)
|
||||||
self.play(Create(flash), run_time=0.06)
|
self.play(Create(flash), run_time=0.06)
|
||||||
self.play(FadeOut(flash), run_time=0.04)
|
self.play(FadeOut(flash), run_time=0.04)
|
||||||
|
set_mask("10010011")
|
||||||
|
|
||||||
update_mask("11100011", "refs = [0,0,1,1,1,0,0,0]")
|
# ── B expands (reuse) ──
|
||||||
|
alloc(0, ORANGE); set_mask("10010010")
|
||||||
|
b.expand(0)
|
||||||
|
|
||||||
for idx in [4]:
|
|
||||||
pg = pool_pages[idx][0]
|
|
||||||
self.play(pg.animate.set_fill(GRAY, opacity=0.12), run_time=0.08)
|
|
||||||
flash = SurroundingRectangle(pool_pages[idx], color=YELLOW, buff=0.04)
|
|
||||||
self.play(Create(flash), run_time=0.06)
|
|
||||||
self.play(FadeOut(flash), run_time=0.04)
|
|
||||||
|
|
||||||
update_mask("11110011", "refs = [0,0,1,1,0,0,0,0]")
|
|
||||||
|
|
||||||
# Show Cleanup phase removes finished tasks
|
|
||||||
s_cleanup = self._small_text("Phase 1: Cleanup — Task A removed from active list", 11, GRAY)
|
|
||||||
s_cleanup.move_to([right_x, step_y, 0])
|
|
||||||
step_y -= 0.3
|
|
||||||
self.play(Write(s_cleanup))
|
|
||||||
|
|
||||||
pt_list_done = self._small_text("page_table_A cleared. Frame 0,1,4 returned to pool.", 10, GRAY)
|
|
||||||
pt_list_done.move_to([-4.4, tblB_y - 2.1, 0])
|
|
||||||
self.play(Write(pt_list_done))
|
|
||||||
self.wait(0.5)
|
self.wait(0.5)
|
||||||
|
s = Text("Page-table-indirected, O(1) alloc/free, on-demand growth",
|
||||||
# ═══════════════════════════════════════════
|
font_size=12, color=GREEN)
|
||||||
# gather() demo — remaining Task B reads KV
|
s.move_to([-3.0, -2.0, 0])
|
||||||
# ═══════════════════════════════════════════
|
self.play(Write(s))
|
||||||
s7 = step_msg("Task B continues — gather() reads KV for attention", BLUE)
|
|
||||||
|
|
||||||
gather_code = VGroup(
|
|
||||||
self._small_text("def gather(self, layer_id):", 9, BLUE),
|
|
||||||
self._small_text(" for pi in range(page_table.size(1)):", 9, BLUE),
|
|
||||||
self._small_text(" phys_pages = page_table[:, pi]", 9, BLUE),
|
|
||||||
self._small_text(" k_parts.append(k_cache[layer_id, phys_pages])", 9, BLUE),
|
|
||||||
self._small_text(" k = torch.cat(k_parts, dim=1)", 9, BLUE),
|
|
||||||
self._small_text(" return k, v", 9, BLUE),
|
|
||||||
self._small_text("", 5, BLUE),
|
|
||||||
self._small_text("gather reads pages [2, 3] → token positions [128..255]", 9, WHITE),
|
|
||||||
)
|
|
||||||
gather_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT)
|
|
||||||
gbox = Rectangle(width=5.0, height=2.0, color=BLUE, stroke_width=1.0, fill_opacity=0.03)
|
|
||||||
gbox.move_to([right_x - 0.2, step_y - 0.4, 0])
|
|
||||||
gather_code.move_to(gbox.get_center())
|
|
||||||
gather_grp = VGroup(gbox, gather_code)
|
|
||||||
self.play(Create(gather_grp), run_time=0.35)
|
|
||||||
step_y -= 1.4
|
|
||||||
|
|
||||||
# Highlight B's remaining pages
|
|
||||||
for idx in [2, 3]:
|
|
||||||
flash = SurroundingRectangle(pool_pages[idx], color=BLUE, buff=0.04)
|
|
||||||
self.play(Create(flash), run_time=0.06)
|
|
||||||
self.play(FadeOut(flash), run_time=0.04)
|
|
||||||
self.wait(0.3)
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
# Summary
|
|
||||||
# ═══════════════════════════════════════════
|
|
||||||
self.wait(0.5)
|
|
||||||
summary = self._small_text("Paged KV Cache — page-table-indirected, O(1) alloc/free, on-demand growth", 20, GREEN)
|
|
||||||
summary.to_edge(DOWN, buff=0.4)
|
|
||||||
self.play(Write(summary))
|
|
||||||
|
|
||||||
benefits = VGroup(
|
|
||||||
self._small_text("✓ No per-request KV pre-allocation — pages allocated on demand", 14, GREEN),
|
|
||||||
self._small_text("✓ Page table decouples logical position from physical storage", 14, GREEN),
|
|
||||||
self._small_text("✓ Ref-counted free → safe concurrent release across tasks", 14, GREEN),
|
|
||||||
self._small_text("✓ Bitmask O(1) alloc/free — no fragmentation", 14, GREEN),
|
|
||||||
)
|
|
||||||
benefits.arrange(DOWN, buff=0.08, aligned_edge=LEFT)
|
|
||||||
benefits.move_to([-5.0, -1.5, 0])
|
|
||||||
self.play(Write(benefits), run_time=0.5)
|
|
||||||
|
|
||||||
# Final mask
|
|
||||||
summary_mask = self._small_text(
|
|
||||||
"Final state: free_mask=11110011 free frames=0,1,4 in-use=2,3 total=8",
|
|
||||||
10, GRAY
|
|
||||||
)
|
|
||||||
summary_mask.next_to(benefits, DOWN, buff=0.25, aligned_edge=LEFT)
|
|
||||||
self.play(Write(summary_mask))
|
|
||||||
|
|
||||||
self.wait(2)
|
self.wait(2)
|
||||||
self.play(*[FadeOut(m) for m in self.mobjects])
|
self.play(*[FadeOut(m) for m in self.mobjects])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue