From f5499866adeb9216d4cac8c5f6fa8bffd8f0307c Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 8 May 2026 22:38:14 +0800 Subject: [PATCH] refactor: replace prefix cache with paged KV cache across promo guide, architecture, and animation --- PROMO_GUIDE.md | 32 +-- README.md | 2 +- architecture.py | 18 +- paged_cache.py | 524 ++++++++++++++++++++++++++++++++++++++++++++++++ prefix_cache.py | 118 ----------- render_all.py | 2 +- 6 files changed, 551 insertions(+), 145 deletions(-) create mode 100644 paged_cache.py delete mode 100644 prefix_cache.py diff --git a/PROMO_GUIDE.md b/PROMO_GUIDE.md index d096277..8f49745 100644 --- a/PROMO_GUIDE.md +++ b/PROMO_GUIDE.md @@ -28,7 +28,7 @@ |------|------|---------| | **单卡可跑** | 1B 参数,RTX 3090/4090 即可运行 | 巨大服务器集群 vs 单张显卡对比 | | **连续批处理** | 动态合并请求,吞吐量 3x+ | 任务流经 Cleanup→Refill→Prefill→Decode 动画 | -| **前缀缓存零拷贝** | 相同前缀直接复用 KV,无需重算 | Radix Tree 生长动画 | +| **分页 KV 缓存** | 固定大小页表 + O(1) 分配,按需扩容 | 页表分配与写入动画 | | **OpenAI 兼容 API** | 一行代码切换 | curl 命令对比 | | **流式输出** | 逐 token 返回,低首延迟 | 终端逐字喷出效果 | | **全过程开源** | 训练+推理+权重全部开源 | GitHub 页面展示 | @@ -50,7 +50,7 @@ │ │Cleanup │→ │Refill│→ │Prefill │→ │ Decode │ │ │ └────────┘ └──────┘ └────────┘ └────────┘ │ ├──────────────────────────────────────────────────┤ -│ Prefix Cache (Radix Tree) + KV Cache │ +│ Paged KV Cache (Page Table + Page Pool) │ ├──────────────────────────────────────────────────┤ │ Transformer (24层 GQA, RoPE, SwiGLU) │ └──────────────────────────────────────────────────┘ @@ -114,17 +114,17 @@ --- -### Segment 4:前缀缓存(1:20 - 1:50) +### Segment 4:分页 KV 缓存(1:20 - 1:50) | 镜头 | 画面 | 旁白 | 时长 | |------|------|------|------| -| 4.1 | 两个请求有相同 system prompt:"你是一个AI助手" | "如果两个请求有相同的前缀——比如相同的系统提示词——" | 5s | -| 4.2 | 普通做法:两个请求各自独立计算前 20 个 token | "普通框架会各自从头计算一遍,白白浪费算力。" | 5s | -| 4.3 | Radix Tree 生长动画:第一个请求插入,第二个请求匹配共享前缀 | "AstrAI 用一颗字典树缓存所有前缀的 KV——第二个请求直接命中。" | 8s | -| 4.4 | 高亮 Slot 复用:直接用原 slot 继续写,零拷贝 | "如果原始 slot 空闲,直接原地续写,连 GPU 内存拷贝都不需要。" | 7s | -| 4.5 | 首 token 延迟对比:有缓存 vs 无缓存(-50%) | "首 token 延迟降低一半以上。" | 5s | +| 4.1 | 展示 KV 缓存是一个固定大小的张量,被划分为多个相同大小的 page | "KV 缓存不再按请求预分配——而是划分为固定大小的页。" | 5s | +| 4.2 | 请求 A 到来,通过页表分配 2 个物理页,写入数据 | "请求到达时,通过页表分配物理页,按需写入。" | 7s | +| 4.3 | 请求 B 到来,分配新页,展示页表将逻辑位置映射到不同物理页 | "页表机制让逻辑位置和物理存储解耦——不同请求的页可以分散排列。" | 8s | +| 4.4 | Decode 阶段,请求继续生成 token,展示按需分配新页(_maybe_alloc_page) | "生成过程中如果当前页写满,自动追加新页——按需扩容,不浪费显存。" | 7s | +| 4.5 | 请求结束时展示页面回收(bitmask 置位) | "请求结束后,页面通过 O(1) 位掩码回收,即刻复用。" | 3s | -**视觉素材**:`prefix_cache.py` 动画、延迟对比 +**视觉素材**:`paged_cache.py` 动画、页表分配示意 --- @@ -260,7 +260,7 @@ manim -ql promo/continuous_batching.py ContinuousBatching |------|-----------|------|---------| | `transformer.py` | `Transformer` | 模型架构:Embed → GQA → SwiGLU → ×24 → LM Head | ~35s | | `continuous_batching.py` | `ContinuousBatching` | 4 阶段流水线动画 + 吞吐对比 | ~30s | -| `prefix_cache.py` | `PrefixCache` | Radix Tree 生长 + 多分支前缀复用 | ~30s | +| `paged_cache.py` | `PrefixCache` | 分页 KV 缓存:页表分配、按需扩容、回收 | ~30s | | `architecture.py` | `Architecture` | 全栈架构逐层展开 + 数据流 | ~25s | ### 自定义动画 @@ -297,10 +297,10 @@ Text.set_default(font="Microsoft YaHei") [01:06] 只有处于相同 KV 缓存位置的任务才一起解码,从根本上避免 RoPE 位置错乱。 [01:14] 实测吞吐量提升 3 倍以上。 -[01:20] 如果两个请求有相同的前缀,普通框架会各自从头计算。 -[01:25] AstrAI 用一颗字典树缓存所有前缀的 KV——第二个请求直接命中。 -[01:33] 如果原始 slot 空闲,直接原地续写,连 GPU 内存拷贝都不需要。 -[01:40] 首 token 延迟降低一半以上。 +[01:20] 传统 KV 缓存预分配整段显存,浪费严重。 +[01:25] AstrAI 采用分页 KV 缓存——固定大小的页,通过页表间接寻址,按需分配。 +[01:33] 生成过程中页写满了自动追加,请求结束后 O(1) 回收。 +[01:40] 显存利用率大幅提升,支持更多并发请求。 [01:50] 来实际看看效果。 [01:52] (现场演示部分,自由发挥) @@ -327,7 +327,7 @@ Text.set_default(font="Microsoft YaHei") | Transformer 架构动画 | Manim 渲染 `transformer.py` | ✅ 已渲染 | | 架构动画 | Manim 渲染 `architecture.py` | ✅ 已渲染 | | 连续批处理动画 | Manim 渲染 `continuous_batching.py` | ✅ 已渲染 | -| 前缀缓存动画 | Manim 渲染 `prefix_cache.py` | ✅ 已渲染 | +| 分页缓存动画 | Manim 渲染 `paged_cache.py` | 需重新渲染 | ### 音频素材 @@ -368,6 +368,6 @@ Text.set_default(font="Microsoft YaHei") | `scripts/promo/README.md` | 动画渲染说明(已移至 promo/) | | `promo/render_all.py` | 一键渲染所有动画 | | `promo/continuous_batching.py` | 连续批处理 Manim 场景 | -| `promo/prefix_cache.py` | 前缀缓存 Manim 场景 | +| `promo/paged_cache.py` | 分页 KV 缓存 Manim 场景 | | `promo/architecture.py` | 架构总览 Manim 场景 | | `params/config.json` | 模型配置 | diff --git a/README.md b/README.md index 506e349..d2317ab 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ python promo/render_all.py |------|-------------|------|---------| | `transformer.py` | `Transformer` | GQA 注意力机制详解 (Q/K/V→RoPE→Attention→O) + 规格卡 | ~25s | | `continuous_batching.py` | `ContinuousBatching` | 4 阶段流水线 + 吞吐对比 | ~30s | -| `prefix_cache.py` | `PrefixCache` | Radix Tree 生长 + 前缀复用 | ~30s | +| `paged_cache.py` | `PrefixCache` | 分页 KV 缓存:页表分配、按需扩容、回收 | ~30s | | `architecture.py` | `Architecture` | 全栈架构逐层展开 | ~30s | ## 导入视频剪辑 diff --git a/architecture.py b/architecture.py index 1d7f24a..526c188 100644 --- a/architecture.py +++ b/architecture.py @@ -39,10 +39,10 @@ class Architecture(Scene): "Position-grouped decode · Bitmask O(1) slots"], "astrai/inference/engine.py · scheduler.py") - L3 = make_box("Prefix Cache + KV Cache", ORANGE, - ["Radix Tree prefix matching · LRU eviction", - "Slot versioning · GPU copy_() zero-copy reuse"], - "astrai/inference/scheduler.py") + L3 = make_box("Paged KV Cache", ORANGE, + ["Page-table-indirected read/write · Page pool", + "O(1) bitmask alloc/free · On-demand page growth"], + "astrai/inference/cache.py · scheduler.py") L4 = make_box("Transformer Model", PURPLE, ["24× DecoderBlock · GQA 6:1 · RoPE", @@ -67,11 +67,11 @@ class Architecture(Scene): "4-phase loop: Cleanup tasks,", "Refill batch, Prefill prompts,", "Decode tokens one by one."], - ["Prefix Cache + KV Cache", - "Caches key-value states using", - "a Radix Tree for O(n) prefix lookup.", - "Reuses matched prefixes via GPU", - "memcpy — zero recomputation."], + ["Paged KV Cache", + "Divides KV cache into fixed-size pages", + "with page-table-indirected access.", + "Per-task page tables map logical pages", + "to physical pages — O(1) alloc/free."], ["Transformer Model (1B params)", "Decoder-only Transformer with", "Grouped-Query Attention (GQA 6:1).", diff --git a/paged_cache.py b/paged_cache.py new file mode 100644 index 0000000..4f337b2 --- /dev/null +++ b/paged_cache.py @@ -0,0 +1,524 @@ +"""AstrAI promo: Paged KV Cache — matching astrai/inference/cache.py & scheduler.py.""" + +from manim import * + +Text.set_default(font="Times New Roman") + + +class PrefixCache(Scene): + """Animates PagedCache exact logic: alloc→write→free with real code details.""" + + def _small_text(self, text, size=10, color=GRAY, **kwargs): + return Text(text, font_size=size, color=color, **kwargs) + + def _page_box(self, pos, label, color, sz=0.44): + s = Square(side_length=sz, color=color, fill_opacity=0.12, stroke_width=1.6) + s.move_to(pos) + lbl = self._small_text(label, 10, color).move_to(pos) + return VGroup(s, lbl) + + def _make_ptable(self, x, y, rows, color, label): + """Draw OS-style page table. Returns (outline, entry_group, phy_cell_centers).""" + n = len(rows) + w = 2.6 + h = 0.40 + vx = w * 0.40 + + group = VGroup() + + # Label + lbl = self._small_text(label, 12, color, weight=BOLD) + lbl.move_to([x + w / 2, y + 0.15, 0]) + group.add(lbl) + + # Outer rect (header+data rows) + r = Rectangle(width=w, height=h * (n + 1), color=color, stroke_width=1.8, fill_opacity=0.06) + r.move_to([x + w / 2, y - h * (n + 1) / 2, 0]) + group.add(r) + + # H dividers + for ri in range(1, n + 1): + ly = y - ri * h + group.add(Line([x, ly, 0], [x + w, ly, 0], color=color, stroke_width=0.8)) + + # V divider + group.add(Line([x + vx, y, 0], [x + vx, y - h * (n + 1), 0], color=color, stroke_width=0.8)) + + # Column headers + lh = self._small_text("logical", 10, color, weight=BOLD) + lh.move_to([x + vx / 2, y - h / 2, 0]) + ph = self._small_text("physical", 10, color, weight=BOLD) + ph.move_to([x + vx + (w - vx) / 2, y - h / 2, 0]) + group.add(lh, ph) + + # Row entries + entries = VGroup() + phy_centers = [] + for i, (lg, phy) in enumerate(rows): + cy = y - h / 2 - (i + 1) * h + lt = self._small_text(lg, 11, color) + lt.move_to([x + vx / 2, cy, 0]) + pt = self._small_text(phy, 11, color) + pt.move_to([x + vx + (w - vx) / 2, cy, 0]) + entries.add(VGroup(lt, pt)) + phy_centers.append([x + vx + (w - vx) / 2, cy, 0]) + + return group, entries, phy_centers, r + + def construct(self): + # ── Scene setup ── + title = self._small_text("Paged KV Cache — astrai/inference/cache.py", 26, BLUE) + title.to_edge(UP, buff=0.15) + self.play(Write(title)) + self.wait(0.1) + + right_x = 4.5 + step_y = 3.0 + + def step_msg(text, color=YELLOW): + nonlocal step_y + m = self._small_text(text, 12, color) + m.move_to([right_x, step_y, 0]) + step_y -= 0.55 + self.play(Write(m)) + return m + + def fade(m): + self.play(FadeOut(m)) + + # ═══════════════════════════════════════════ + # Phase 0: Initialize PagedCache + # ═══════════════════════════════════════════ + s0 = step_msg("PagedCache(max_batch=8, page_size=64, head_dim=64, n_kv_heads=4)") + + # Show constructor signature + ctor = self._small_text( + 'PagedCache(n_layers=24, n_pages=8, page_size=64,\n' + ' n_kv_heads=4, head_dim=64, device=cuda, dtype=bfloat16)', + 9, GRAY + ) + ctor.move_to([-5.5, 2.4, 0]) + self.play(Write(ctor)) + + # Show the tensor shape + tensor_shape = self._small_text( + 'k_cache: [24, 8, 64, 4, 64] v_cache: [24, 8, 64, 4, 64]', + 8, GRAY + ) + tensor_shape.move_to([-5.5, 2.05, 0]) + self.play(Write(tensor_shape)) + + self.wait(0.3) + fade(s0) + # Keep ctor & tensor_shape visible + + # ═══════════════════════════════════════════ + # Physical page pool + # ═══════════════════════════════════════════ + s1 = step_msg("8 page frames, refs[8] = [0,0,0,0,0,0,0,0], free_mask = (1<<8)-1 = 0xFF") + + pool_y = 1.45 + pool_x0 = -3.8 + pool_sp = 0.68 + pool_pages = [] + pool_pos = [] + for i in range(8): + x = pool_x0 + i * pool_sp + pos = np.array([x, pool_y, 0]) + pool_pos.append(pos) + pb = self._page_box(pos, str(i), GRAY) + pool_pages.append(pb) + self.play(FadeIn(pb, scale=0.5), run_time=0.04) + self.wait(0.1) + + # Bracket + brack = Brace(VGroup(*[p[0] for p in pool_pages]), DOWN, buff=0.05) + blbl = Text("page frames [0..7] each holds 64 KV slots", font_size=10, color=GRAY) + blbl.next_to(brack, DOWN, buff=0.02) + self.play(Create(brack), Write(blbl), run_time=0.25) + + # Free bitmask — with bit position labels + mask_y = 0.55 + bit_labels = VGroup() + for i in range(8): + bl = self._small_text(f"bit{i}", 7, DARK_GRAY) + bl.move_to([pool_x0 + i * pool_sp, mask_y + 0.25, 0]) + bit_labels.add(bl) + self.play(Write(bit_labels), run_time=0.2) + + mask = self._small_text("11111111 (1 = free, 0 = alloc)", 11, GRAY) + mask.move_to([-1.5, mask_y, 0]) + self.play(Write(mask)) + self.wait(0.15) + + refs_lbl = self._small_text("refs = [0,0,0,0,0,0,0,0]", 9, GRAY) + refs_lbl.move_to([-1.5, mask_y - 0.3, 0]) + self.play(Write(refs_lbl)) + + def update_mask(bits, desc): + m2 = self._small_text(f"{bits} (1 = free, 0 = alloc)", 11, GRAY) + m2.move_to([-1.5, mask_y, 0]) + self.play(Transform(mask, m2)) + if desc: + self.play(Transform(refs_lbl, self._small_text(desc, 9, GRAY).move_to([-1.5, mask_y - 0.3, 0]))) + + fade(s1) + + # ═══════════════════════════════════════════ + # Phase 1: Cleanup — nothing to do initially + # ═══════════════════════════════════════════ + s_phase = self._small_text("Phase 1: Cleanup (no finished tasks)", 11, GRAY) + s_phase.move_to([right_x, step_y, 0]) + step_y -= 0.4 + self.play(Write(s_phase)) + self.wait(0.2) + + # ═══════════════════════════════════════════ + # Phase 2: Refill — Request A arrives + # ═══════════════════════════════════════════ + step_y2 = step_y + s2a = step_msg("Phase 2: Refill — Request A arrives (prompt_len=120 tokens)", GREEN) + s2b = step_msg("_n_pages_for(120) = (120 + 64 - 1) // 64 = 183 // 64 = 2", GREEN) + + calc_box = Rectangle(width=4.8, height=1.0, color=GREEN_E, stroke_width=1.2, fill_opacity=0.05) + calc_box.move_to([right_x - 0.3, step_y - 0.1, 0]) + calc_lines = VGroup( + self._small_text("def _n_pages_for(n_tokens):", 9, GREEN), + self._small_text(" return (n_tokens + page_size - 1) // page_size", 9, GREEN), + self._small_text("_n_pages_for(120) = (120 + 64 - 1) // 64 = 2", 9, WHITE), + ) + calc_lines.arrange(DOWN, buff=0.08, aligned_edge=LEFT) + calc_lines.move_to(calc_box.get_center()) + calc_grp = VGroup(calc_box, calc_lines) + self.play(Create(calc_grp), run_time=0.35) + step_y -= 0.6 + + # alloc() in action + s2c = step_msg("alloc_n(2) → calls alloc() twice", GREEN) + + alloc_code = VGroup( + self._small_text("def alloc(self) -> int:", 9, GREEN), + self._small_text(" lsb = self._free_mask & -self._free_mask", 9, GREEN), + self._small_text(" if lsb == 0: return -1", 9, GREEN), + self._small_text(" idx = lsb.bit_length() - 1", 9, GREEN), + self._small_text(" self._free_mask ^= lsb", 9, GREEN), + self._small_text(" self._refs[idx] = 1", 9, GREEN), + self._small_text(" return idx", 9, GREEN), + self._small_text("", 6, GREEN), + self._small_text("1st alloc: free_mask=11111111, lsb=1, idx=0 → page 0", 9, WHITE), + self._small_text("2nd alloc: free_mask=11111110, lsb=2, idx=1 → page 1", 9, WHITE), + ) + alloc_code.arrange(DOWN, buff=0.06, aligned_edge=LEFT) + acbox = Rectangle(width=5.0, height=2.6, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) + acbox.move_to([right_x - 0.2, step_y - 0.8, 0]) + alloc_code.move_to(acbox.get_center()) + alloc_grp = VGroup(acbox, alloc_code) + self.play(Create(alloc_grp), run_time=0.4) + step_y -= 1.8 + + # Now draw page table A (2 rows) + tblA_x = -4.4 + tblA_y = -0.35 + tblA_outline, tblA_entries, tblA_phys, tblA_rect = self._make_ptable( + tblA_x, tblA_y, + [("0", "0"), ("1", "1")], + GREEN, "Task A page_table" + ) + self.play(Create(tblA_outline), run_time=0.35) + for e in tblA_entries: + self.play(FadeIn(e), run_time=0.08) + + # Alloc effect on pool + bitmask + for idx in [0, 1]: + pg = pool_pages[idx][0] + self.play(pg.animate.set_fill(GREEN, opacity=0.35), run_time=0.1) + flash = SurroundingRectangle(pool_pages[idx], color=GREEN, buff=0.04) + self.play(Create(flash), run_time=0.05) + self.play(FadeOut(flash), run_time=0.04) + + update_mask("11111100", "refs = [1,1,0,0,0,0,0,0]") + self.wait(0.2) + + # mapping arrows + arr_a0 = Arrow( + [tblA_phys[0][0] + 0.1, tblA_phys[0][1], 0], + [pool_pos[0][0], pool_pos[0][1] - 0.22, 0], + color=GREEN, stroke_width=1.5, buff=0.03, + max_tip_length_to_length_ratio=0.18, + ) + arr_a1 = Arrow( + [tblA_phys[1][0] + 0.1, tblA_phys[1][1], 0], + [pool_pos[1][0], pool_pos[1][1] - 0.22, 0], + color=GREEN, stroke_width=1.5, buff=0.03, + max_tip_length_to_length_ratio=0.18, + ) + self.play(GrowArrow(arr_a0), GrowArrow(arr_a1), run_time=0.35) + self.wait(0.3) + + # ═══════════════════════════════════════════ + # Phase 3: Prefill (write KV into allocated pages) + # ═══════════════════════════════════════════ + s3a = step_msg("Phase 3: Prefill — write() KV into pages 0,1", GREEN) + + write_code = VGroup( + self._small_text("def write(self, layer_id, page_table, start_pos, k, v):", 9, GREEN), + self._small_text(" first_page = start_pos // page_size # 0 // 64 = 0", 9, GREEN), + self._small_text(" last_page = (start_pos+seq_len-1)//page_size", 9, WHITE), + self._small_text(" for pi in range(first_page, last_page+1):", 9, WHITE), + self._small_text(" phys_pages = page_table[:, pi]", 9, WHITE), + self._small_text(" chunk = min(page_start+page_size, start_pos+seq_len) -", 9, WHITE), + self._small_text(" max(page_start, start_pos)", 9, WHITE), + self._small_text(" k_cache[layer_id, phys_pages, offset:offset+chunk] = k", 9, GREEN), + ) + write_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) + wbox = Rectangle(width=5.0, height=2.3, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) + wbox.move_to([right_x - 0.2, step_y - 0.4, 0]) + write_code.move_to(wbox.get_center()) + write_grp = VGroup(wbox, write_code) + self.play(Create(write_grp), run_time=0.35) + step_y -= 1.6 + + # Show a KV block "written" onto the pages + k_written = self._small_text("KV written", 8, GREEN) + k_written.move_to([pool_pos[0][0] + 0.5, pool_pos[0][1] + 0.6, 0]) + self.play(Write(k_written)) + self.wait(0.3) + fade(k_written) + + s3b = step_msg("CacheView bundles cache + page_table + total_len for attention", GREEN) + cv_code = VGroup( + self._small_text("class CacheView:", 9, GREEN), + self._small_text(" def __init__(self, cache, page_table, total_len):", 9, GREEN), + self._small_text(' def write(self, layer_id, start_pos, k, v):', 9, WHITE), + self._small_text(" self._cache.write(layer_id, self._page_table, ...)", 9, WHITE), + self._small_text(' def gather(self, layer_id):', 9, WHITE), + self._small_text(" for pi in range(page_table.size(1)):", 9, WHITE), + self._small_text(" phys_pages = page_table[:, pi]", 9, WHITE), + self._small_text(" k_parts.append(k_cache[layer_id, phys_pages])", 9, WHITE), + self._small_text(" k = torch.cat(k_parts, dim=1)", 9, GREEN), + ) + cv_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) + cvbox = Rectangle(width=5.0, height=2.5, color=GREEN_E, stroke_width=1.0, fill_opacity=0.03) + cvbox.move_to([right_x - 0.2, step_y - 0.6, 0]) + cv_code.move_to(cvbox.get_center()) + cv_grp = VGroup(cvbox, cv_code) + self.play(Create(cv_grp), run_time=0.35) + step_y -= 1.8 + + # ═══════════════════════════════════════════ + # Request B arrives + # ═══════════════════════════════════════════ + s4 = step_msg("Phase 2: Refill — Request B arrives (prompt_len=90)", ORANGE) + s4b = step_msg("alloc_n(2) → pages 2, 3", ORANGE) + + tblB_x = -4.4 + tblB_y = -1.9 + tblB_outline, tblB_entries, tblB_phys, tblB_rect = self._make_ptable( + tblB_x, tblB_y, + [("0", "2"), ("1", "3")], + ORANGE, "Task B page_table" + ) + self.play(Create(tblB_outline), run_time=0.3) + for e in tblB_entries: + self.play(FadeIn(e), run_time=0.08) + + for idx in [2, 3]: + pg = pool_pages[idx][0] + self.play(pg.animate.set_fill(ORANGE, opacity=0.35), run_time=0.1) + flash = SurroundingRectangle(pool_pages[idx], color=ORANGE, buff=0.04) + self.play(Create(flash), run_time=0.05) + self.play(FadeOut(flash), run_time=0.04) + + update_mask("11110000", "refs = [1,1,1,1,0,0,0,0]") + + arr_b0 = Arrow( + [tblB_phys[0][0] + 0.1, tblB_phys[0][1], 0], + [pool_pos[2][0], pool_pos[2][1] - 0.22, 0], + color=ORANGE, stroke_width=1.5, buff=0.03, + max_tip_length_to_length_ratio=0.18, + ) + arr_b1 = Arrow( + [tblB_phys[1][0] + 0.1, tblB_phys[1][1], 0], + [pool_pos[3][0], pool_pos[3][1] - 0.22, 0], + color=ORANGE, stroke_width=1.5, buff=0.03, + max_tip_length_to_length_ratio=0.18, + ) + self.play(GrowArrow(arr_b0), GrowArrow(arr_b1), run_time=0.35) + self.wait(0.3) + + # ═══════════════════════════════════════════ + # Phase 4: Decode — on-demand page growth + # ═══════════════════════════════════════════ + s5 = step_msg("Phase 4: Decode — Task A generates more tokens", PINK) + s5b = step_msg("total_len grows: 120 → 180 tokens", PINK) + s5c = step_msg("_n_pages_for(180) = (180+63)//64 = 243//64 = 3", PINK) + + # _maybe_alloc_page logic + map_code = VGroup( + self._small_text("def _maybe_alloc_page(self, task, pos):", 9, PINK), + self._small_text(" needed = _n_pages_for(pos + 1) # _n_pages_for(181) = 3", 9, PINK), + self._small_text(" while task.n_pages < needed: # 2 < 3", 9, PINK), + self._small_text(" p = self.page_cache.alloc() # alloc page 4", 9, PINK), + self._small_text(" task.page_table.append(p)", 9, PINK), + self._small_text(" task.n_pages += 1 # 2 → 3", 9, PINK), + self._small_text("", 5, PINK), + self._small_text("page_table_A: [0, 1] → [0, 1, 4]", 9, WHITE), + ) + map_code.arrange(DOWN, buff=0.05, aligned_edge=LEFT) + mbox = Rectangle(width=5.0, height=2.2, color=PINK, stroke_width=1.0, fill_opacity=0.03) + mbox.move_to([right_x - 0.2, step_y - 0.2, 0]) + map_code.move_to(mbox.get_center()) + map_grp = VGroup(mbox, map_code) + self.play(Create(map_grp), run_time=0.35) + step_y -= 1.6 + + # Alloc page 4 + idx = 4 + pg = pool_pages[idx][0] + self.play(pg.animate.set_fill(PINK, opacity=0.35), run_time=0.1) + flash = SurroundingRectangle(pool_pages[idx], color=PINK, buff=0.04) + self.play(Create(flash), run_time=0.05) + self.play(FadeOut(flash), run_time=0.04) + + update_mask("11100000", "refs = [1,1,1,1,1,0,0,0]") + + # Expand Table A: add row 3 + new_bottom = tblA_y - 0.40 * 4 # 3 data rows + 1 header + line_3 = Line( + [tblA_x, tblA_y - 0.40 * 3, 0], + [tblA_x + 2.6, tblA_y - 0.40 * 3, 0], + color=GREEN, stroke_width=0.8, + ) + lc3 = [tblA_x + 2.6 * 0.20, tblA_y - 0.40 * 3 - 0.20, 0] + pc3 = [tblA_x + 2.6 * 0.40 + (2.6 - 2.6 * 0.40) / 2, tblA_y - 0.40 * 3 - 0.20, 0] + lt3 = self._small_text("2", 11, PINK).move_to(lc3) + pt3 = self._small_text("4", 11, PINK).move_to(pc3) + self.play(Create(line_3), run_time=0.08) + self.play(FadeIn(lt3), FadeIn(pt3), run_time=0.08) + + arr_c = Arrow( + [pc3[0] + 0.1, pc3[1], 0], + [pool_pos[4][0], pool_pos[4][1] - 0.22, 0], + color=PINK, stroke_width=1.5, buff=0.03, + max_tip_length_to_length_ratio=0.18, + ) + self.play(GrowArrow(arr_c), run_time=0.2) + self.wait(0.3) + + # Highlight: page_table list conversion + pt_list_old = self._small_text("page_table_A: [0, 1] (2 pages → can hold 128 tokens)", 10, GREEN) + pt_list_old.move_to([-4.4, tblB_y - 1.35, 0]) + self.play(Write(pt_list_old)) + + pt_list_new = self._small_text("page_table_A: [0, 1, 4] (3 pages → can hold 192 tokens)", 10, GREEN) + pt_list_new.move_to([-4.4, tblB_y - 1.7, 0]) + self.play(Write(pt_list_new)) + self.wait(0.4) + + # ═══════════════════════════════════════════ + # Task A finished → free() + # ═══════════════════════════════════════════ + s6 = step_msg("Task A done → free pages 0, 1, 4", YELLOW) + s6b = step_msg("free() → refs[idx] -= 1; if refs[idx]==0: mask |= 1<