diff --git a/paged_cache.py b/paged_cache.py index 4f337b2..3ed81d1 100644 --- a/paged_cache.py +++ b/paged_cache.py @@ -1,129 +1,94 @@ -"""AstrAI promo: Paged KV Cache — matching astrai/inference/cache.py & scheduler.py.""" +"""AstrAI promo: Paged KV Cache — astrai/inference/cache.py & scheduler.py.""" from manim import * Text.set_default(font="Times New Roman") -class PrefixCache(Scene): - """Animates PagedCache exact logic: alloc→write→free with real code details.""" +class _TaskRow: + """Manages one task's logical-page row: label, blocks, arrows.""" - def _small_text(self, text, size=10, color=GRAY, **kwargs): + COLS = [-4.0, -3.2, -2.4, -1.6] + + def __init__(self, scene, label, color, y, pool_pos, pool_y): + self.scene = scene + self.color = color + self.y = y + self.pool_pos = pool_pos + self.pool_y = pool_y + self._next_col = 0 + self.blocks = VGroup() + self.arrows = VGroup() + + lbl = scene._small(label, 11, color, weight=BOLD) + lbl.move_to([-5.2, y, 0]) + self._label = lbl + + self.blocks.add(lbl) + + def arrive(self, *phys_idxs): + self.scene.play(Write(self._label)) + for pid in phys_idxs: + self._add(pid, expand=False) + + def expand(self, phys_idx): + self._add(phys_idx, expand=True) + + def _add(self, phys_idx, expand): + col = self._next_col + self._next_col += 1 + x = self.COLS[col] + pos = np.array([x, self.y, 0]) + + pb = self.scene._lp_box(pos, str(col), self.color) + self.blocks.add(pb) + + arr = Arrow( + [x, self.y + 0.19, 0], + [self.pool_pos[phys_idx][0], self.pool_y - 0.22, 0], + color=self.color, stroke_width=1.5, buff=0.03, + max_tip_length_to_length_ratio=0.12, + ) + self.arrows.add(arr) + + if expand: + target = pb.copy() + pb.scale(0) + self.scene.add(pb) + self.scene.play(Transform(pb, target), GrowArrow(arr), run_time=0.3) + else: + self.scene.play(FadeIn(pb, scale=0.5), GrowArrow(arr), run_time=0.12) + + def finish(self): + self.scene.play(FadeOut(self.blocks), FadeOut(self.arrows)) + + +class PagedCache(Scene): + def _small(self, text, size=10, color=GRAY, **kwargs): return Text(text, font_size=size, color=color, **kwargs) def _page_box(self, pos, label, color, sz=0.44): s = Square(side_length=sz, color=color, fill_opacity=0.12, stroke_width=1.6) s.move_to(pos) - lbl = self._small_text(label, 10, color).move_to(pos) + lbl = self._small(label, 10, color).move_to(pos) return VGroup(s, lbl) - def _make_ptable(self, x, y, rows, color, label): - """Draw OS-style page table. Returns (outline, entry_group, phy_cell_centers).""" - n = len(rows) - w = 2.6 - h = 0.40 - vx = w * 0.40 - - group = VGroup() - - # Label - lbl = self._small_text(label, 12, color, weight=BOLD) - lbl.move_to([x + w / 2, y + 0.15, 0]) - group.add(lbl) - - # Outer rect (header+data rows) - r = Rectangle(width=w, height=h * (n + 1), color=color, stroke_width=1.8, fill_opacity=0.06) - r.move_to([x + w / 2, y - h * (n + 1) / 2, 0]) - group.add(r) - - # H dividers - for ri in range(1, n + 1): - ly = y - ri * h - group.add(Line([x, ly, 0], [x + w, ly, 0], color=color, stroke_width=0.8)) - - # V divider - group.add(Line([x + vx, y, 0], [x + vx, y - h * (n + 1), 0], color=color, stroke_width=0.8)) - - # Column headers - lh = self._small_text("logical", 10, color, weight=BOLD) - lh.move_to([x + vx / 2, y - h / 2, 0]) - ph = self._small_text("physical", 10, color, weight=BOLD) - ph.move_to([x + vx + (w - vx) / 2, y - h / 2, 0]) - group.add(lh, ph) - - # Row entries - entries = VGroup() - phy_centers = [] - for i, (lg, phy) in enumerate(rows): - cy = y - h / 2 - (i + 1) * h - lt = self._small_text(lg, 11, color) - lt.move_to([x + vx / 2, cy, 0]) - pt = self._small_text(phy, 11, color) - pt.move_to([x + vx + (w - vx) / 2, cy, 0]) - entries.add(VGroup(lt, pt)) - phy_centers.append([x + vx + (w - vx) / 2, cy, 0]) - - return group, entries, phy_centers, r + def _lp_box(self, pos, label, color, sz=0.38): + s = RoundedRectangle(width=sz, height=sz, corner_radius=0.06, + color=color, fill_opacity=0.22, stroke_width=1.6) + s.move_to(pos) + lbl = self._small(label, 12, color).move_to(pos) + return VGroup(s, lbl) def construct(self): - # ── Scene setup ── - title = self._small_text("Paged KV Cache — astrai/inference/cache.py", 26, BLUE) + title = Text("Paged KV Cache", font_size=20, color=BLUE) title.to_edge(UP, buff=0.15) self.play(Write(title)) - self.wait(0.1) - right_x = 4.5 - step_y = 3.0 - - def step_msg(text, color=YELLOW): - nonlocal step_y - m = self._small_text(text, 12, color) - m.move_to([right_x, step_y, 0]) - step_y -= 0.55 - self.play(Write(m)) - return m - - def fade(m): - self.play(FadeOut(m)) - - # ═══════════════════════════════════════════ - # Phase 0: Initialize PagedCache - # ═══════════════════════════════════════════ - s0 = step_msg("PagedCache(max_batch=8, page_size=64, head_dim=64, n_kv_heads=4)") - - # Show constructor signature - ctor = self._small_text( - 'PagedCache(n_layers=24, n_pages=8, page_size=64,\n' - ' n_kv_heads=4, head_dim=64, device=cuda, dtype=bfloat16)', - 9, GRAY - ) - ctor.move_to([-5.5, 2.4, 0]) - self.play(Write(ctor)) - - # Show the tensor shape - tensor_shape = self._small_text( - 'k_cache: [24, 8, 64, 4, 64] v_cache: [24, 8, 64, 4, 64]', - 8, GRAY - ) - tensor_shape.move_to([-5.5, 2.05, 0]) - self.play(Write(tensor_shape)) - - self.wait(0.3) - fade(s0) - # Keep ctor & tensor_shape visible - - # ═══════════════════════════════════════════ - # Physical page pool - # ═══════════════════════════════════════════ - s1 = step_msg("8 page frames, refs[8] = [0,0,0,0,0,0,0,0], free_mask = (1<<8)-1 = 0xFF") - - pool_y = 1.45 - pool_x0 = -3.8 - pool_sp = 0.68 - pool_pages = [] - pool_pos = [] + pool_y = 1.45; pool_x0 = -3.8; sp = 0.68 + pool_pages = []; pool_pos = [] for i in range(8): - x = pool_x0 + i * pool_sp + x = pool_x0 + i * sp pos = np.array([x, pool_y, 0]) pool_pos.append(pos) pb = self._page_box(pos, str(i), GRAY) @@ -131,394 +96,69 @@ class PrefixCache(Scene): self.play(FadeIn(pb, scale=0.5), run_time=0.04) self.wait(0.1) - # Bracket - brack = Brace(VGroup(*[p[0] for p in pool_pages]), DOWN, buff=0.05) - blbl = Text("page frames [0..7] each holds 64 KV slots", font_size=10, color=GRAY) - blbl.next_to(brack, DOWN, buff=0.02) - self.play(Create(brack), Write(blbl), run_time=0.25) + plbl = self._small("page frames [0..7]", 9, GRAY) + plbl.next_to(pool_pages[0][0], DOWN, buff=0.25).shift(LEFT * 0.3) + self.play(Write(plbl)) - # Free bitmask — with bit position labels - mask_y = 0.55 - bit_labels = VGroup() - for i in range(8): - bl = self._small_text(f"bit{i}", 7, DARK_GRAY) - bl.move_to([pool_x0 + i * pool_sp, mask_y + 0.25, 0]) - bit_labels.add(bl) - self.play(Write(bit_labels), run_time=0.2) - - mask = self._small_text("11111111 (1 = free, 0 = alloc)", 11, GRAY) - mask.move_to([-1.5, mask_y, 0]) + mask = self._small("free: 11111111", 10, GRAY) + mask.next_to(plbl, DOWN, buff=0.1, aligned_edge=LEFT) self.play(Write(mask)) - self.wait(0.15) - refs_lbl = self._small_text("refs = [0,0,0,0,0,0,0,0]", 9, GRAY) - refs_lbl.move_to([-1.5, mask_y - 0.3, 0]) - self.play(Write(refs_lbl)) + def alloc(idx, color): + pg = pool_pages[idx][0] + self.play(pg.animate.set_fill(color, opacity=0.35), run_time=0.1) + flash = SurroundingRectangle(pool_pages[idx], color=color, buff=0.04) + self.play(Create(flash), run_time=0.05) + self.play(FadeOut(flash), run_time=0.04) - def update_mask(bits, desc): - m2 = self._small_text(f"{bits} (1 = free, 0 = alloc)", 11, GRAY) - m2.move_to([-1.5, mask_y, 0]) + def set_mask(bits): + m2 = self._small(f"free: {bits}", 10, GRAY) + m2.next_to(plbl, DOWN, buff=0.1, aligned_edge=LEFT) self.play(Transform(mask, m2)) - if desc: - self.play(Transform(refs_lbl, self._small_text(desc, 9, GRAY).move_to([-1.5, mask_y - 0.3, 0]))) - fade(s1) + a_y = 0.25; b_y = -0.45; c_y = -1.15 - # ═══════════════════════════════════════════ - # Phase 1: Cleanup — nothing to do initially - # ═══════════════════════════════════════════ - s_phase = self._small_text("Phase 1: Cleanup (no finished tasks)", 11, GRAY) - s_phase.move_to([right_x, step_y, 0]) - step_y -= 0.4 - self.play(Write(s_phase)) - self.wait(0.2) + # ── A arrives ── + alloc(0, GREEN); alloc(1, GREEN); set_mask("11111100") + a = _TaskRow(self, "A", GREEN, a_y, pool_pos, pool_y) + a.arrive(0, 1) - # ═══════════════════════════════════════════ - # Phase 2: Refill — Request A arrives - # ═══════════════════════════════════════════ - step_y2 = step_y - s2a = step_msg("Phase 2: Refill — Request A arrives (prompt_len=120 tokens)", GREEN) - s2b = step_msg("_n_pages_for(120) = (120 + 64 - 1) // 64 = 183 // 64 = 2", GREEN) + # ── A expands 1 ── + alloc(4, GREEN); set_mask("11101100") + a.expand(4) - calc_box = Rectangle(width=4.8, height=1.0, color=GREEN_E, stroke_width=1.2, fill_opacity=0.05) - calc_box.move_to([right_x - 0.3, step_y - 0.1, 0]) - calc_lines = VGroup( - self._small_text("def _n_pages_for(n_tokens):", 9, GREEN), - self._small_text(" return (n_tokens + page_size - 1) // page_size", 9, GREEN), - self._small_text("_n_pages_for(120) = (120 + 64 - 1) // 64 = 2", 9, WHITE), - ) - calc_lines.arrange(DOWN, buff=0.08, aligned_edge=LEFT) - calc_lines.move_to(calc_box.get_center()) - calc_grp = VGroup(calc_box, calc_lines) - self.play(Create(calc_grp), run_time=0.35) - step_y -= 0.6 + # ── B arrives ── + alloc(2, ORANGE); alloc(3, ORANGE); set_mask("11100000") + b = _TaskRow(self, "B", ORANGE, b_y, pool_pos, pool_y) + b.arrive(2, 3) - # alloc() in action - s2c = step_msg("alloc_n(2) → calls alloc() twice", GREEN) + # ── C arrives ── + alloc(5, BLUE); alloc(6, BLUE); set_mask("10000000") + c = _TaskRow(self, "C", BLUE, c_y, pool_pos, pool_y) + c.arrive(5, 6) - alloc_code = VGroup( - self._small_text("def alloc(self) -> int:", 9, GREEN), - self._small_text(" lsb = self._free_mask & -self._free_mask", 9, GREEN), - self._small_text(" if lsb == 0: return -1", 9, GREEN), - self._small_text(" idx = lsb.bit_length() - 1", 9, GREEN), - self._small_text(" self._free_mask ^= lsb", 9, GREEN), - self._small_text(" self._refs[idx] = 1", 9, GREEN), - self._small_text(" return idx", 9, GREEN), - self._small_text("", 6, GREEN), - self._small_text("1st alloc: free_mask=11111111, lsb=1, idx=0 → page 0", 9, WHITE), - self._small_text("2nd alloc: free_mask=11111110, lsb=2, idx=1 → page 1", 9, WHITE), - ) - alloc_code.arrange(DOWN, buff=0.06, aligned_edge=LEFT) - acbox = Rectangle(width=5.0, height=2.6, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) - acbox.move_to([right_x - 0.2, step_y - 0.8, 0]) - alloc_code.move_to(acbox.get_center()) - alloc_grp = VGroup(acbox, alloc_code) - self.play(Create(alloc_grp), run_time=0.4) - step_y -= 1.8 + # ── A expands 2 ── + alloc(7, GREEN); set_mask("00000000") + a.expand(7) - # Now draw page table A (2 rows) - tblA_x = -4.4 - tblA_y = -0.35 - tblA_outline, tblA_entries, tblA_phys, tblA_rect = self._make_ptable( - tblA_x, tblA_y, - [("0", "0"), ("1", "1")], - GREEN, "Task A page_table" - ) - self.play(Create(tblA_outline), run_time=0.35) - for e in tblA_entries: - self.play(FadeIn(e), run_time=0.08) - - # Alloc effect on pool + bitmask - for idx in [0, 1]: - pg = pool_pages[idx][0] - self.play(pg.animate.set_fill(GREEN, opacity=0.35), run_time=0.1) - flash = SurroundingRectangle(pool_pages[idx], color=GREEN, buff=0.04) - self.play(Create(flash), run_time=0.05) - self.play(FadeOut(flash), run_time=0.04) - - update_mask("11111100", "refs = [1,1,0,0,0,0,0,0]") - self.wait(0.2) - - # mapping arrows - arr_a0 = Arrow( - [tblA_phys[0][0] + 0.1, tblA_phys[0][1], 0], - [pool_pos[0][0], pool_pos[0][1] - 0.22, 0], - color=GREEN, stroke_width=1.5, buff=0.03, - max_tip_length_to_length_ratio=0.18, - ) - arr_a1 = Arrow( - [tblA_phys[1][0] + 0.1, tblA_phys[1][1], 0], - [pool_pos[1][0], pool_pos[1][1] - 0.22, 0], - color=GREEN, stroke_width=1.5, buff=0.03, - max_tip_length_to_length_ratio=0.18, - ) - self.play(GrowArrow(arr_a0), GrowArrow(arr_a1), run_time=0.35) - self.wait(0.3) - - # ═══════════════════════════════════════════ - # Phase 3: Prefill (write KV into allocated pages) - # ═══════════════════════════════════════════ - s3a = step_msg("Phase 3: Prefill — write() KV into pages 0,1", GREEN) - - write_code = VGroup( - self._small_text("def write(self, layer_id, page_table, start_pos, k, v):", 9, GREEN), - self._small_text(" first_page = start_pos // page_size # 0 // 64 = 0", 9, GREEN), - self._small_text(" last_page = (start_pos+seq_len-1)//page_size", 9, WHITE), - self._small_text(" for pi in range(first_page, last_page+1):", 9, WHITE), - self._small_text(" phys_pages = page_table[:, pi]", 9, WHITE), - self._small_text(" chunk = min(page_start+page_size, start_pos+seq_len) -", 9, WHITE), - self._small_text(" max(page_start, start_pos)", 9, WHITE), - self._small_text(" k_cache[layer_id, phys_pages, offset:offset+chunk] = k", 9, GREEN), - ) - write_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) - wbox = Rectangle(width=5.0, height=2.3, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) - wbox.move_to([right_x - 0.2, step_y - 0.4, 0]) - write_code.move_to(wbox.get_center()) - write_grp = VGroup(wbox, write_code) - self.play(Create(write_grp), run_time=0.35) - step_y -= 1.6 - - # Show a KV block "written" onto the pages - k_written = self._small_text("KV written", 8, GREEN) - k_written.move_to([pool_pos[0][0] + 0.5, pool_pos[0][1] + 0.6, 0]) - self.play(Write(k_written)) - self.wait(0.3) - fade(k_written) - - s3b = step_msg("CacheView bundles cache + page_table + total_len for attention", GREEN) - cv_code = VGroup( - self._small_text("class CacheView:", 9, GREEN), - self._small_text(" def __init__(self, cache, page_table, total_len):", 9, GREEN), - self._small_text(' def write(self, layer_id, start_pos, k, v):', 9, WHITE), - self._small_text(" self._cache.write(layer_id, self._page_table, ...)", 9, WHITE), - self._small_text(' def gather(self, layer_id):', 9, WHITE), - self._small_text(" for pi in range(page_table.size(1)):", 9, WHITE), - self._small_text(" phys_pages = page_table[:, pi]", 9, WHITE), - self._small_text(" k_parts.append(k_cache[layer_id, phys_pages])", 9, WHITE), - self._small_text(" k = torch.cat(k_parts, dim=1)", 9, GREEN), - ) - cv_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) - cvbox = Rectangle(width=5.0, height=2.5, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) - cvbox.move_to([right_x - 0.2, step_y - 0.6, 0]) - cv_code.move_to(cvbox.get_center()) - cv_grp = VGroup(cvbox, cv_code) - self.play(Create(cv_grp), run_time=0.35) - step_y -= 1.8 - - # ═══════════════════════════════════════════ - # Request B arrives - # ═══════════════════════════════════════════ - s4 = step_msg("Phase 2: Refill — Request B arrives (prompt_len=90)", ORANGE) - s4b = step_msg("alloc_n(2) → pages 2, 3", ORANGE) - - tblB_x = -4.4 - tblB_y = -1.9 - tblB_outline, tblB_entries, tblB_phys, tblB_rect = self._make_ptable( - tblB_x, tblB_y, - [("0", "2"), ("1", "3")], - ORANGE, "Task B page_table" - ) - self.play(Create(tblB_outline), run_time=0.3) - for e in tblB_entries: - self.play(FadeIn(e), run_time=0.08) - - for idx in [2, 3]: - pg = pool_pages[idx][0] - self.play(pg.animate.set_fill(ORANGE, opacity=0.35), run_time=0.1) - flash = SurroundingRectangle(pool_pages[idx], color=ORANGE, buff=0.04) - self.play(Create(flash), run_time=0.05) - self.play(FadeOut(flash), run_time=0.04) - - update_mask("11110000", "refs = [1,1,1,1,0,0,0,0]") - - arr_b0 = Arrow( - [tblB_phys[0][0] + 0.1, tblB_phys[0][1], 0], - [pool_pos[2][0], pool_pos[2][1] - 0.22, 0], - color=ORANGE, stroke_width=1.5, buff=0.03, - max_tip_length_to_length_ratio=0.18, - ) - arr_b1 = Arrow( - [tblB_phys[1][0] + 0.1, tblB_phys[1][1], 0], - [pool_pos[3][0], pool_pos[3][1] - 0.22, 0], - color=ORANGE, stroke_width=1.5, buff=0.03, - max_tip_length_to_length_ratio=0.18, - ) - self.play(GrowArrow(arr_b0), GrowArrow(arr_b1), run_time=0.35) - self.wait(0.3) - - # ═══════════════════════════════════════════ - # Phase 4: Decode — on-demand page growth - # ═══════════════════════════════════════════ - s5 = step_msg("Phase 4: Decode — Task A generates more tokens", PINK) - s5b = step_msg("total_len grows: 120 → 180 tokens", PINK) - s5c = step_msg("_n_pages_for(180) = (180+63)//64 = 243//64 = 3", PINK) - - # _maybe_alloc_page logic - map_code = VGroup( - self._small_text("def _maybe_alloc_page(self, task, pos):", 9, PINK), - self._small_text(" needed = _n_pages_for(pos + 1) # _n_pages_for(181) = 3", 9, PINK), - self._small_text(" while task.n_pages < needed: # 2 < 3", 9, PINK), - self._small_text(" p = self.page_cache.alloc() # alloc page 4", 9, PINK), - self._small_text(" task.page_table.append(p)", 9, PINK), - self._small_text(" task.n_pages += 1 # 2 → 3", 9, PINK), - self._small_text("", 5, PINK), - self._small_text("page_table_A: [0, 1] → [0, 1, 4]", 9, WHITE), - ) - map_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) - mbox = Rectangle(width=5.0, height=2.2, color=PINK, stroke_width=1.0, fill_opacity=0.03) - mbox.move_to([right_x - 0.2, step_y - 0.2, 0]) - map_code.move_to(mbox.get_center()) - map_grp = VGroup(mbox, map_code) - self.play(Create(map_grp), run_time=0.35) - step_y -= 1.6 - - # Alloc page 4 - idx = 4 - pg = pool_pages[idx][0] - self.play(pg.animate.set_fill(PINK, opacity=0.35), run_time=0.1) - flash = SurroundingRectangle(pool_pages[idx], color=PINK, buff=0.04) - self.play(Create(flash), run_time=0.05) - self.play(FadeOut(flash), run_time=0.04) - - update_mask("11100000", "refs = [1,1,1,1,1,0,0,0]") - - # Expand Table A: add row 3 - new_bottom = tblA_y - 0.40 * 4 # 3 data rows + 1 header - line_3 = Line( - [tblA_x, tblA_y - 0.40 * 3, 0], - [tblA_x + 2.6, tblA_y - 0.40 * 3, 0], - color=GREEN, stroke_width=0.8, - ) - lc3 = [tblA_x + 2.6 * 0.20, tblA_y - 0.40 * 3 - 0.20, 0] - pc3 = [tblA_x + 2.6 * 0.40 + (2.6 - 2.6 * 0.40) / 2, tblA_y - 0.40 * 3 - 0.20, 0] - lt3 = self._small_text("2", 11, PINK).move_to(lc3) - pt3 = self._small_text("4", 11, PINK).move_to(pc3) - self.play(Create(line_3), run_time=0.08) - self.play(FadeIn(lt3), FadeIn(pt3), run_time=0.08) - - arr_c = Arrow( - [pc3[0] + 0.1, pc3[1], 0], - [pool_pos[4][0], pool_pos[4][1] - 0.22, 0], - color=PINK, stroke_width=1.5, buff=0.03, - max_tip_length_to_length_ratio=0.18, - ) - self.play(GrowArrow(arr_c), run_time=0.2) - self.wait(0.3) - - # Highlight: page_table list conversion - pt_list_old = self._small_text("page_table_A: [0, 1] (2 pages → can hold 128 tokens)", 10, GREEN) - pt_list_old.move_to([-4.4, tblB_y - 1.35, 0]) - self.play(Write(pt_list_old)) - - pt_list_new = self._small_text("page_table_A: [0, 1, 4] (3 pages → can hold 192 tokens)", 10, GREEN) - pt_list_new.move_to([-4.4, tblB_y - 1.7, 0]) - self.play(Write(pt_list_new)) - self.wait(0.4) - - # ═══════════════════════════════════════════ - # Task A finished → free() - # ═══════════════════════════════════════════ - s6 = step_msg("Task A done → free pages 0, 1, 4", YELLOW) - s6b = step_msg("free() → refs[idx] -= 1; if refs[idx]==0: mask |= 1<