"""AstrAI promo: Paged KV Cache — astrai/inference/cache.py & scheduler.py.""" from manim import * Text.set_default(font="Times New Roman") class _TaskRow: """Manages one task's logical-page row: label, blocks, arrows.""" COLS = [-4.0, -3.2, -2.4, -1.6] def __init__(self, scene, label, color, y, pool_pos, pool_y): self.scene = scene self.color = color self.y = y self.pool_pos = pool_pos self.pool_y = pool_y self._next_col = 0 self.blocks = VGroup() self.arrows = VGroup() lbl = scene._small(label, 11, color, weight=BOLD) lbl.move_to([-5.2, y, 0]) self._label = lbl self.blocks.add(lbl) def arrive(self, *phys_idxs): self.scene.play(Write(self._label)) for pid in phys_idxs: self._add(pid, expand=False) def expand(self, phys_idx): self._add(phys_idx, expand=True) def _add(self, phys_idx, expand): col = self._next_col self._next_col += 1 x = self.COLS[col] pos = np.array([x, self.y, 0]) pb = self.scene._lp_box(pos, str(col), self.color) self.blocks.add(pb) arr = Arrow( [x, self.y + 0.19, 0], [self.pool_pos[phys_idx][0], self.pool_y - 0.22, 0], color=self.color, stroke_width=1.5, buff=0.03, max_tip_length_to_length_ratio=0.12, ) self.arrows.add(arr) if expand: target = pb.copy() pb.scale(0) self.scene.add(pb) self.scene.play(Transform(pb, target), GrowArrow(arr), run_time=0.3) else: self.scene.play(FadeIn(pb, scale=0.5), GrowArrow(arr), run_time=0.12) def finish(self): self.scene.play(FadeOut(self.blocks), FadeOut(self.arrows)) class PagedCache(Scene): def _small(self, text, size=10, color=GRAY, **kwargs): return Text(text, font_size=size, color=color, **kwargs) def _page_box(self, pos, label, color, sz=0.44): s = Square(side_length=sz, color=color, fill_opacity=0.12, stroke_width=1.6) s.move_to(pos) lbl = self._small(label, 10, color).move_to(pos) return VGroup(s, lbl) def _lp_box(self, pos, label, color, sz=0.38): s = RoundedRectangle(width=sz, height=sz, corner_radius=0.06, color=color, fill_opacity=0.22, stroke_width=1.6) s.move_to(pos) lbl = self._small(label, 12, color).move_to(pos) return VGroup(s, lbl) def construct(self): title = Text("Paged KV Cache", font_size=20, color=BLUE) title.to_edge(UP, buff=0.15) self.play(Write(title)) pool_y = 1.45; pool_x0 = -3.8; sp = 0.68 pool_pages = []; pool_pos = [] for i in range(8): x = pool_x0 + i * sp pos = np.array([x, pool_y, 0]) pool_pos.append(pos) pb = self._page_box(pos, str(i), GRAY) pool_pages.append(pb) self.play(FadeIn(pb, scale=0.5), run_time=0.04) self.wait(0.1) plbl = self._small("page frames [0..7]", 9, GRAY) plbl.next_to(pool_pages[0][0], DOWN, buff=0.25).shift(LEFT * 0.3) self.play(Write(plbl)) mask = self._small("free: 11111111", 10, GRAY) mask.next_to(plbl, DOWN, buff=0.1, aligned_edge=LEFT) self.play(Write(mask)) def alloc(idx, color): pg = pool_pages[idx][0] self.play(pg.animate.set_fill(color, opacity=0.35), run_time=0.1) flash = SurroundingRectangle(pool_pages[idx], color=color, buff=0.04) self.play(Create(flash), run_time=0.05) self.play(FadeOut(flash), run_time=0.04) def set_mask(bits): m2 = self._small(f"free: {bits}", 10, GRAY) m2.next_to(plbl, DOWN, buff=0.1, aligned_edge=LEFT) self.play(Transform(mask, m2)) a_y = 0.25; b_y = -0.45; c_y = -1.15 # ── A arrives ── alloc(0, GREEN); alloc(1, GREEN); set_mask("11111100") a = _TaskRow(self, "A", GREEN, a_y, pool_pos, pool_y) a.arrive(0, 1) # ── A expands 1 ── alloc(4, GREEN); set_mask("11101100") a.expand(4) # ── B arrives ── alloc(2, ORANGE); alloc(3, ORANGE); set_mask("11100000") b = _TaskRow(self, "B", ORANGE, b_y, pool_pos, pool_y) b.arrive(2, 3) # ── C arrives ── alloc(5, BLUE); alloc(6, BLUE); set_mask("10000000") c = _TaskRow(self, "C", BLUE, c_y, pool_pos, pool_y) c.arrive(5, 6) # ── A expands 2 ── alloc(7, GREEN); set_mask("00000000") a.expand(7) # ── A finishes ── a.finish() for idx in [0, 1, 4, 7]: pg = pool_pages[idx][0] self.play(pg.animate.set_fill(GRAY, opacity=0.12), run_time=0.08) flash = SurroundingRectangle(pool_pages[idx], color=YELLOW, buff=0.04) self.play(Create(flash), run_time=0.06) self.play(FadeOut(flash), run_time=0.04) set_mask("10010011") # ── B expands (reuse) ── alloc(0, ORANGE); set_mask("10010010") b.expand(0) self.wait(0.5) s = Text("Page-table-indirected, O(1) alloc/free, on-demand growth", font_size=12, color=GREEN) s.move_to([-3.0, -2.0, 0]) self.play(Write(s)) self.wait(2) self.play(*[FadeOut(m) for m in self.mobjects])