"""AstrAI promo: Paged KV Cache — matching astrai/inference/cache.py & scheduler.py.""" from manim import * Text.set_default(font="Times New Roman") class PrefixCache(Scene): """Animates PagedCache exact logic: alloc→write→free with real code details.""" def _small_text(self, text, size=10, color=GRAY, **kwargs): return Text(text, font_size=size, color=color, **kwargs) def _page_box(self, pos, label, color, sz=0.44): s = Square(side_length=sz, color=color, fill_opacity=0.12, stroke_width=1.6) s.move_to(pos) lbl = self._small_text(label, 10, color).move_to(pos) return VGroup(s, lbl) def _make_ptable(self, x, y, rows, color, label): """Draw OS-style page table. Returns (outline, entry_group, phy_cell_centers).""" n = len(rows) w = 2.6 h = 0.40 vx = w * 0.40 group = VGroup() # Label lbl = self._small_text(label, 12, color, weight=BOLD) lbl.move_to([x + w / 2, y + 0.15, 0]) group.add(lbl) # Outer rect (header+data rows) r = Rectangle(width=w, height=h * (n + 1), color=color, stroke_width=1.8, fill_opacity=0.06) r.move_to([x + w / 2, y - h * (n + 1) / 2, 0]) group.add(r) # H dividers for ri in range(1, n + 1): ly = y - ri * h group.add(Line([x, ly, 0], [x + w, ly, 0], color=color, stroke_width=0.8)) # V divider group.add(Line([x + vx, y, 0], [x + vx, y - h * (n + 1), 0], color=color, stroke_width=0.8)) # Column headers lh = self._small_text("logical", 10, color, weight=BOLD) lh.move_to([x + vx / 2, y - h / 2, 0]) ph = self._small_text("physical", 10, color, weight=BOLD) ph.move_to([x + vx + (w - vx) / 2, y - h / 2, 0]) group.add(lh, ph) # Row entries entries = VGroup() phy_centers = [] for i, (lg, phy) in enumerate(rows): cy = y - h / 2 - (i + 1) * h lt = self._small_text(lg, 11, color) lt.move_to([x + vx / 2, cy, 0]) pt = self._small_text(phy, 11, color) pt.move_to([x + vx + (w - vx) / 2, cy, 0]) entries.add(VGroup(lt, pt)) phy_centers.append([x + vx + (w - vx) / 2, cy, 0]) return group, entries, phy_centers, r def construct(self): # ── Scene setup ── title = self._small_text("Paged KV Cache — astrai/inference/cache.py", 26, BLUE) title.to_edge(UP, buff=0.15) self.play(Write(title)) self.wait(0.1) right_x = 4.5 step_y = 3.0 def step_msg(text, color=YELLOW): nonlocal step_y m = self._small_text(text, 12, color) m.move_to([right_x, step_y, 0]) step_y -= 0.55 self.play(Write(m)) return m def fade(m): self.play(FadeOut(m)) # ═══════════════════════════════════════════ # Phase 0: Initialize PagedCache # ═══════════════════════════════════════════ s0 = step_msg("PagedCache(max_batch=8, page_size=64, head_dim=64, n_kv_heads=4)") # Show constructor signature ctor = self._small_text( 'PagedCache(n_layers=24, n_pages=8, page_size=64,\n' ' n_kv_heads=4, head_dim=64, device=cuda, dtype=bfloat16)', 9, GRAY ) ctor.move_to([-5.5, 2.4, 0]) self.play(Write(ctor)) # Show the tensor shape tensor_shape = self._small_text( 'k_cache: [24, 8, 64, 4, 64] v_cache: [24, 8, 64, 4, 64]', 8, GRAY ) tensor_shape.move_to([-5.5, 2.05, 0]) self.play(Write(tensor_shape)) self.wait(0.3) fade(s0) # Keep ctor & tensor_shape visible # ═══════════════════════════════════════════ # Physical page pool # ═══════════════════════════════════════════ s1 = step_msg("8 page frames, refs[8] = [0,0,0,0,0,0,0,0], free_mask = (1<<8)-1 = 0xFF") pool_y = 1.45 pool_x0 = -3.8 pool_sp = 0.68 pool_pages = [] pool_pos = [] for i in range(8): x = pool_x0 + i * pool_sp pos = np.array([x, pool_y, 0]) pool_pos.append(pos) pb = self._page_box(pos, str(i), GRAY) pool_pages.append(pb) self.play(FadeIn(pb, scale=0.5), run_time=0.04) self.wait(0.1) # Bracket brack = Brace(VGroup(*[p[0] for p in pool_pages]), DOWN, buff=0.05) blbl = Text("page frames [0..7] each holds 64 KV slots", font_size=10, color=GRAY) blbl.next_to(brack, DOWN, buff=0.02) self.play(Create(brack), Write(blbl), run_time=0.25) # Free bitmask — with bit position labels mask_y = 0.55 bit_labels = VGroup() for i in range(8): bl = self._small_text(f"bit{i}", 7, DARK_GRAY) bl.move_to([pool_x0 + i * pool_sp, mask_y + 0.25, 0]) bit_labels.add(bl) self.play(Write(bit_labels), run_time=0.2) mask = self._small_text("11111111 (1 = free, 0 = alloc)", 11, GRAY) mask.move_to([-1.5, mask_y, 0]) self.play(Write(mask)) self.wait(0.15) refs_lbl = self._small_text("refs = [0,0,0,0,0,0,0,0]", 9, GRAY) refs_lbl.move_to([-1.5, mask_y - 0.3, 0]) self.play(Write(refs_lbl)) def update_mask(bits, desc): m2 = self._small_text(f"{bits} (1 = free, 0 = alloc)", 11, GRAY) m2.move_to([-1.5, mask_y, 0]) self.play(Transform(mask, m2)) if desc: self.play(Transform(refs_lbl, self._small_text(desc, 9, GRAY).move_to([-1.5, mask_y - 0.3, 0]))) fade(s1) # ═══════════════════════════════════════════ # Phase 1: Cleanup — nothing to do initially # ═══════════════════════════════════════════ s_phase = self._small_text("Phase 1: Cleanup (no finished tasks)", 11, GRAY) s_phase.move_to([right_x, step_y, 0]) step_y -= 0.4 self.play(Write(s_phase)) self.wait(0.2) # ═══════════════════════════════════════════ # Phase 2: Refill — Request A arrives # ═══════════════════════════════════════════ step_y2 = step_y s2a = step_msg("Phase 2: Refill — Request A arrives (prompt_len=120 tokens)", GREEN) s2b = step_msg("_n_pages_for(120) = (120 + 64 - 1) // 64 = 183 // 64 = 2", GREEN) calc_box = Rectangle(width=4.8, height=1.0, color=GREEN_E, stroke_width=1.2, fill_opacity=0.05) calc_box.move_to([right_x - 0.3, step_y - 0.1, 0]) calc_lines = VGroup( self._small_text("def _n_pages_for(n_tokens):", 9, GREEN), self._small_text(" return (n_tokens + page_size - 1) // page_size", 9, GREEN), self._small_text("_n_pages_for(120) = (120 + 64 - 1) // 64 = 2", 9, WHITE), ) calc_lines.arrange(DOWN, buff=0.08, aligned_edge=LEFT) calc_lines.move_to(calc_box.get_center()) calc_grp = VGroup(calc_box, calc_lines) self.play(Create(calc_grp), run_time=0.35) step_y -= 0.6 # alloc() in action s2c = step_msg("alloc_n(2) → calls alloc() twice", GREEN) alloc_code = VGroup( self._small_text("def alloc(self) -> int:", 9, GREEN), self._small_text(" lsb = self._free_mask & -self._free_mask", 9, GREEN), self._small_text(" if lsb == 0: return -1", 9, GREEN), self._small_text(" idx = lsb.bit_length() - 1", 9, GREEN), self._small_text(" self._free_mask ^= lsb", 9, GREEN), self._small_text(" self._refs[idx] = 1", 9, GREEN), self._small_text(" return idx", 9, GREEN), self._small_text("", 6, GREEN), self._small_text("1st alloc: free_mask=11111111, lsb=1, idx=0 → page 0", 9, WHITE), self._small_text("2nd alloc: free_mask=11111110, lsb=2, idx=1 → page 1", 9, WHITE), ) alloc_code.arrange(DOWN, buff=0.06, aligned_edge=LEFT) acbox = Rectangle(width=5.0, height=2.6, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) acbox.move_to([right_x - 0.2, step_y - 0.8, 0]) alloc_code.move_to(acbox.get_center()) alloc_grp = VGroup(acbox, alloc_code) self.play(Create(alloc_grp), run_time=0.4) step_y -= 1.8 # Now draw page table A (2 rows) tblA_x = -4.4 tblA_y = -0.35 tblA_outline, tblA_entries, tblA_phys, tblA_rect = self._make_ptable( tblA_x, tblA_y, [("0", "0"), ("1", "1")], GREEN, "Task A page_table" ) self.play(Create(tblA_outline), run_time=0.35) for e in tblA_entries: self.play(FadeIn(e), run_time=0.08) # Alloc effect on pool + bitmask for idx in [0, 1]: pg = pool_pages[idx][0] self.play(pg.animate.set_fill(GREEN, opacity=0.35), run_time=0.1) flash = SurroundingRectangle(pool_pages[idx], color=GREEN, buff=0.04) self.play(Create(flash), run_time=0.05) self.play(FadeOut(flash), run_time=0.04) update_mask("11111100", "refs = [1,1,0,0,0,0,0,0]") self.wait(0.2) # mapping arrows arr_a0 = Arrow( [tblA_phys[0][0] + 0.1, tblA_phys[0][1], 0], [pool_pos[0][0], pool_pos[0][1] - 0.22, 0], color=GREEN, stroke_width=1.5, buff=0.03, max_tip_length_to_length_ratio=0.18, ) arr_a1 = Arrow( [tblA_phys[1][0] + 0.1, tblA_phys[1][1], 0], [pool_pos[1][0], pool_pos[1][1] - 0.22, 0], color=GREEN, stroke_width=1.5, buff=0.03, max_tip_length_to_length_ratio=0.18, ) self.play(GrowArrow(arr_a0), GrowArrow(arr_a1), run_time=0.35) self.wait(0.3) # ═══════════════════════════════════════════ # Phase 3: Prefill (write KV into allocated pages) # ═══════════════════════════════════════════ s3a = step_msg("Phase 3: Prefill — write() KV into pages 0,1", GREEN) write_code = VGroup( self._small_text("def write(self, layer_id, page_table, start_pos, k, v):", 9, GREEN), self._small_text(" first_page = start_pos // page_size # 0 // 64 = 0", 9, GREEN), self._small_text(" last_page = (start_pos+seq_len-1)//page_size", 9, WHITE), self._small_text(" for pi in range(first_page, last_page+1):", 9, WHITE), self._small_text(" phys_pages = page_table[:, pi]", 9, WHITE), self._small_text(" chunk = min(page_start+page_size, start_pos+seq_len) -", 9, WHITE), self._small_text(" max(page_start, start_pos)", 9, WHITE), self._small_text(" k_cache[layer_id, phys_pages, offset:offset+chunk] = k", 9, GREEN), ) write_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) wbox = Rectangle(width=5.0, height=2.3, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) wbox.move_to([right_x - 0.2, step_y - 0.4, 0]) write_code.move_to(wbox.get_center()) write_grp = VGroup(wbox, write_code) self.play(Create(write_grp), run_time=0.35) step_y -= 1.6 # Show a KV block "written" onto the pages k_written = self._small_text("KV written", 8, GREEN) k_written.move_to([pool_pos[0][0] + 0.5, pool_pos[0][1] + 0.6, 0]) self.play(Write(k_written)) self.wait(0.3) fade(k_written) s3b = step_msg("CacheView bundles cache + page_table + total_len for attention", GREEN) cv_code = VGroup( self._small_text("class CacheView:", 9, GREEN), self._small_text(" def __init__(self, cache, page_table, total_len):", 9, GREEN), self._small_text(' def write(self, layer_id, start_pos, k, v):', 9, WHITE), self._small_text(" self._cache.write(layer_id, self._page_table, ...)", 9, WHITE), self._small_text(' def gather(self, layer_id):', 9, WHITE), self._small_text(" for pi in range(page_table.size(1)):", 9, WHITE), self._small_text(" phys_pages = page_table[:, pi]", 9, WHITE), self._small_text(" k_parts.append(k_cache[layer_id, phys_pages])", 9, WHITE), self._small_text(" k = torch.cat(k_parts, dim=1)", 9, GREEN), ) cv_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) cvbox = Rectangle(width=5.0, height=2.5, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) cvbox.move_to([right_x - 0.2, step_y - 0.6, 0]) cv_code.move_to(cvbox.get_center()) cv_grp = VGroup(cvbox, cv_code) self.play(Create(cv_grp), run_time=0.35) step_y -= 1.8 # ═══════════════════════════════════════════ # Request B arrives # ═══════════════════════════════════════════ s4 = step_msg("Phase 2: Refill — Request B arrives (prompt_len=90)", ORANGE) s4b = step_msg("alloc_n(2) → pages 2, 3", ORANGE) tblB_x = -4.4 tblB_y = -1.9 tblB_outline, tblB_entries, tblB_phys, tblB_rect = self._make_ptable( tblB_x, tblB_y, [("0", "2"), ("1", "3")], ORANGE, "Task B page_table" ) self.play(Create(tblB_outline), run_time=0.3) for e in tblB_entries: self.play(FadeIn(e), run_time=0.08) for idx in [2, 3]: pg = pool_pages[idx][0] self.play(pg.animate.set_fill(ORANGE, opacity=0.35), run_time=0.1) flash = SurroundingRectangle(pool_pages[idx], color=ORANGE, buff=0.04) self.play(Create(flash), run_time=0.05) self.play(FadeOut(flash), run_time=0.04) update_mask("11110000", "refs = [1,1,1,1,0,0,0,0]") arr_b0 = Arrow( [tblB_phys[0][0] + 0.1, tblB_phys[0][1], 0], [pool_pos[2][0], pool_pos[2][1] - 0.22, 0], color=ORANGE, stroke_width=1.5, buff=0.03, max_tip_length_to_length_ratio=0.18, ) arr_b1 = Arrow( [tblB_phys[1][0] + 0.1, tblB_phys[1][1], 0], [pool_pos[3][0], pool_pos[3][1] - 0.22, 0], color=ORANGE, stroke_width=1.5, buff=0.03, max_tip_length_to_length_ratio=0.18, ) self.play(GrowArrow(arr_b0), GrowArrow(arr_b1), run_time=0.35) self.wait(0.3) # ═══════════════════════════════════════════ # Phase 4: Decode — on-demand page growth # ═══════════════════════════════════════════ s5 = step_msg("Phase 4: Decode — Task A generates more tokens", PINK) s5b = step_msg("total_len grows: 120 → 180 tokens", PINK) s5c = step_msg("_n_pages_for(180) = (180+63)//64 = 243//64 = 3", PINK) # _maybe_alloc_page logic map_code = VGroup( self._small_text("def _maybe_alloc_page(self, task, pos):", 9, PINK), self._small_text(" needed = _n_pages_for(pos + 1) # _n_pages_for(181) = 3", 9, PINK), self._small_text(" while task.n_pages < needed: # 2 < 3", 9, PINK), self._small_text(" p = self.page_cache.alloc() # alloc page 4", 9, PINK), self._small_text(" task.page_table.append(p)", 9, PINK), self._small_text(" task.n_pages += 1 # 2 → 3", 9, PINK), self._small_text("", 5, PINK), self._small_text("page_table_A: [0, 1] → [0, 1, 4]", 9, WHITE), ) map_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) mbox = Rectangle(width=5.0, height=2.2, color=PINK, stroke_width=1.0, fill_opacity=0.03) mbox.move_to([right_x - 0.2, step_y - 0.2, 0]) map_code.move_to(mbox.get_center()) map_grp = VGroup(mbox, map_code) self.play(Create(map_grp), run_time=0.35) step_y -= 1.6 # Alloc page 4 idx = 4 pg = pool_pages[idx][0] self.play(pg.animate.set_fill(PINK, opacity=0.35), run_time=0.1) flash = SurroundingRectangle(pool_pages[idx], color=PINK, buff=0.04) self.play(Create(flash), run_time=0.05) self.play(FadeOut(flash), run_time=0.04) update_mask("11100000", "refs = [1,1,1,1,1,0,0,0]") # Expand Table A: add row 3 new_bottom = tblA_y - 0.40 * 4 # 3 data rows + 1 header line_3 = Line( [tblA_x, tblA_y - 0.40 * 3, 0], [tblA_x + 2.6, tblA_y - 0.40 * 3, 0], color=GREEN, stroke_width=0.8, ) lc3 = [tblA_x + 2.6 * 0.20, tblA_y - 0.40 * 3 - 0.20, 0] pc3 = [tblA_x + 2.6 * 0.40 + (2.6 - 2.6 * 0.40) / 2, tblA_y - 0.40 * 3 - 0.20, 0] lt3 = self._small_text("2", 11, PINK).move_to(lc3) pt3 = self._small_text("4", 11, PINK).move_to(pc3) self.play(Create(line_3), run_time=0.08) self.play(FadeIn(lt3), FadeIn(pt3), run_time=0.08) arr_c = Arrow( [pc3[0] + 0.1, pc3[1], 0], [pool_pos[4][0], pool_pos[4][1] - 0.22, 0], color=PINK, stroke_width=1.5, buff=0.03, max_tip_length_to_length_ratio=0.18, ) self.play(GrowArrow(arr_c), run_time=0.2) self.wait(0.3) # Highlight: page_table list conversion pt_list_old = self._small_text("page_table_A: [0, 1] (2 pages → can hold 128 tokens)", 10, GREEN) pt_list_old.move_to([-4.4, tblB_y - 1.35, 0]) self.play(Write(pt_list_old)) pt_list_new = self._small_text("page_table_A: [0, 1, 4] (3 pages → can hold 192 tokens)", 10, GREEN) pt_list_new.move_to([-4.4, tblB_y - 1.7, 0]) self.play(Write(pt_list_new)) self.wait(0.4) # ═══════════════════════════════════════════ # Task A finished → free() # ═══════════════════════════════════════════ s6 = step_msg("Task A done → free pages 0, 1, 4", YELLOW) s6b = step_msg("free() → refs[idx] -= 1; if refs[idx]==0: mask |= 1<