diff --git a/architecture.py b/architecture.py index 9fe2009..56de9dc 100644 --- a/architecture.py +++ b/architecture.py @@ -58,8 +58,8 @@ class Architecture(Scene): self.wait(0.5) hl = SurroundingRectangle(layers[3], color=GREEN, buff=0.12) - hl_note = Text("Zero-Copy Prefix Reuse", font_size=22, color=GREEN) - hl_note.next_to(hl, RIGHT, buff=0.8) + hl_note = Text("Zero-Copy Prefix Reuse", font_size=18, color=GREEN) + hl_note.next_to(hl, LEFT, buff=0.4) self.play(Create(hl), Write(hl_note)) self.wait(1.5) self.play(FadeOut(hl), FadeOut(hl_note)) diff --git a/continuous_batching.py b/continuous_batching.py index 264259b..016cd32 100644 --- a/continuous_batching.py +++ b/continuous_batching.py @@ -1,98 +1,158 @@ """AstrAI promo: Continuous Batching animation. -Shows how tasks flow through the 4-phase pipeline and get batched together. +Shows 4-phase pipeline with multiple requests concurrently at different stages, +and position-grouped decode batching — the key advantage over static batching. """ from manim import * class ContinuousBatching(Scene): - """Animates tasks flowing through the prefill->decode pipeline.""" - def construct(self): - # ── title ── title = Text("Continuous Batching", font_size=48, color=BLUE) self.play(Write(title)) - self.wait(0.5) - self.play(title.animate.to_edge(UP).scale(0.6)) - top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN) + self.wait(0.3) + self.play(title.animate.to_edge(UP).scale(0.55)) + + top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.2) self.play(Create(top_bar)) - # ── pipeline stages ── - stage_names = ["Waiting\nQueue", "Prefill", "Decode\n(Batched)", "Finished"] - stage_color = [GRAY, BLUE, YELLOW, GREEN] + # ── 4-phase loop (vertical) ── + phase_names = ["Cleanup", "Refill", "Prefill", "Decode (Batched)"] + phase_colors = [GRAY, ORANGE, BLUE, YELLOW] + phase_descs = [ + "Evict finished slots", + "Admit new requests", + "Compute KV cache", + "Group by position", + ] - stages = VGroup() - arrows = VGroup() - for i, (name, color) in enumerate(zip(stage_names, stage_color)): - box = Rectangle(height=1.5, width=2.5, color=color, fill_opacity=0.12) + phases = VGroup() + phase_arrows = VGroup() + for i, (name, color, desc) in enumerate(zip(phase_names, phase_colors, phase_descs)): + box = Rectangle(width=3.2, height=0.7, color=color, fill_opacity=0.12) lbl = Text(name, font_size=18, color=color) grp = VGroup(box, lbl) - grp.shift(RIGHT * (i - 1.5) * 3.2 + DOWN * 0.5) - stages.add(grp) - self.play(Create(grp), run_time=0.35) + phases.add(grp) if i > 0: - a = Arrow(stages[i - 1].get_right(), stages[i].get_left(), color=GRAY) - arrows.add(a) - self.play(Create(a), run_time=0.2) + a = Arrow( + phases[i - 1].get_bottom(), + phases[i].get_top(), + color=GRAY, buff=0.08, + max_tip_length_to_length_ratio=0.2, + ) + phase_arrows.add(a) - pipeline = VGroup(stages, arrows) - plabel = Text("4-Phase Generation Loop", font_size=16, color=GRAY).next_to( - pipeline, DOWN, buff=0.4 + phases.arrange(DOWN, buff=0.25) + phases.shift(LEFT * 3.5 + DOWN * 0.6) + + for i in range(4): + self.play(Create(phases[i])) + if i > 0: + self.play(Create(phase_arrows[i - 1])) + self.wait(0.3) + + # cycle arrow back from Decode to Cleanup + loop_arrow = CurvedArrow( + phases[-1].get_right() + RIGHT * 0.15, + phases[0].get_right() + RIGHT * 0.15, + color=GRAY, angle=PI / 2, ) - self.play(Write(plabel)) - self.wait(0.5) + loop_label = Text("Loop", font_size=12, color=GRAY).next_to(loop_arrow, RIGHT, buff=0.1) + self.play(Create(loop_arrow), Write(loop_label)) + self.wait(0.3) - # ── spawn tasks ── - task_colors = [YELLOW, ORANGE, PINK, TEAL, GREEN] - tasks = VGroup() - box_center = stages[0].get_center() - for i, c in enumerate(task_colors): - dot = Dot(color=c, radius=0.12) - y_off = (i - 2) * 0.2 - dot.move_to(box_center + RIGHT * y_off * 0.3) - lbl = Text(f"R{i+1}", font_size=10, color=c).next_to(dot, UP, buff=0.1) - tg = VGroup(dot, lbl) - tasks.add(tg) - self.play(FadeIn(tg, scale=0.5), run_time=0.12) + # ── animate requests at different stages concurrently ── + colors = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE] + requests = [] + + def make_req(name, color): + dot = Dot(color=color, radius=0.13) + lbl = Text(name, font_size=14, color=color) + lbl.next_to(dot, LEFT, buff=0.15) + return VGroup(dot, lbl) + + # R1 in Prefill, R2 in Decode, R3 in Refill (concurrent!) + r1 = make_req("R1", colors[0]) + r1.next_to(phases[2], RIGHT, buff=1.2) + r2 = make_req("R2", colors[1]) + r2.next_to(phases[3], RIGHT, buff=1.2) + r3 = make_req("R3", colors[2]) + r3.next_to(phases[1], RIGHT, buff=1.2) + + r4 = make_req("R4", colors[3]) + r4.next_to(phases[0], RIGHT, buff=1.2) + + self.play(FadeIn(r1, scale=0.7), FadeIn(r2, scale=0.7), + FadeIn(r3, scale=0.7), FadeIn(r4, scale=0.7)) + self.wait(0.4) + + concurrent_note = Text("3 requests at different phases simultaneously", + font_size=15, color=WHITE).next_to(phases, DOWN, buff=0.5) + self.play(Write(concurrent_note)) + self.wait(1.2) + self.play(FadeOut(concurrent_note)) + + # ── animate rotation through phases ── + # R1: Prefill -> Decode, R2: Decode -> Cleanup, R3: Refill -> Prefill, R4: Cleanup -> Refill + self.play( + r1.animate.next_to(phases[3], RIGHT, buff=1.2), + r2.animate.next_to(phases[0], RIGHT, buff=1.2), + r3.animate.next_to(phases[2], RIGHT, buff=1.2), + r4.animate.next_to(phases[1], RIGHT, buff=1.2), + ) + self.wait(0.3) + + # R2 done -> spawn R5 + new_r5 = make_req("R5", colors[4]) + new_r5.next_to(phases[1], RIGHT, buff=1.2) + self.play(FadeOut(r2), FadeIn(new_r5, scale=0.7)) + r2_target = new_r5 self.wait(0.3) - # ── animate through stages ── - for phase in range(1, 4): - target = stages[phase].get_center() - anims = [t.animate.move_to(target) for t in tasks] - self.play(*anims, run_time=0.5, rate_func=smooth) - self.wait(0.15) - - # ── highlight decode batching ── - ring = SurroundingRectangle(stages[2], color=YELLOW, buff=0.12) - note = Text( - "Same-position batch decoding", font_size=16, color=YELLOW - ).next_to(stages[2], DOWN, buff=0.5) - self.play(Create(ring), Write(note)) - self.wait(1) - self.play(FadeOut(ring), FadeOut(note)) - - # ── throughput comparison (text) ── + # ── highlight decode: position-grouped batching ── + # gather multiple requests into decode self.play( - *[FadeOut(t) for t in tasks], - FadeOut(pipeline), - FadeOut(plabel), - FadeOut(top_bar), + r1.animate.next_to(phases[3], RIGHT, buff=1.2), + r3.animate.next_to(phases[3], RIGHT, buff=2.5), + FadeOut(r4), FadeOut(new_r5), ) + self.wait(0.2) + + ring = SurroundingRectangle(phases[3], color=YELLOW, buff=0.12) + ring_note = Text("Position-Grouped Decode\nSame pos same matmul batch", + font_size=15, color=YELLOW, line_spacing=0.6) + ring_note.next_to(ring, DOWN, buff=0.35) + hbox = SurroundingRectangle( + VGroup(r1, r3).copy().arrange(RIGHT, buff=0.5).move_to(r1), + color=GREEN, buff=0.2, + ) + hbox.next_to(phases[3], RIGHT, buff=1.85) + + self.play(Create(ring), Write(ring_note)) + self.play(Create(hbox)) + self.wait(1.8) + self.play(FadeOut(ring), FadeOut(ring_note), FadeOut(hbox), + FadeOut(r1), FadeOut(r3), FadeOut(loop_arrow), FadeOut(loop_label)) + + # O(1) slot allocation highlight + bitmask_box = VGroup( + Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE), + Text("free_slots = ~occupied_mask one-bit op", font_size=16, color=GRAY), + ).arrange(DOWN, buff=0.2).next_to(phases, DOWN, buff=0.8) + + self.play(Write(bitmask_box)) + self.wait(1.2) + + # ── clear for throughput ── + self.play(*[FadeOut(m) for m in self.mobjects if m is not title and m is not top_bar]) compare = VGroup( Text("Throughput Comparison", font_size=32, color=BLUE), - Text( - "Static Batch: 1.0× (baseline)", - font_size=24, color=RED, - ), - Text( - "Continuous Batching: 3.4× (single GPU)", - font_size=24, color=GREEN, - ), + Text("Static Batching: 1.0x (baseline)", font_size=22, color=RED), + Text("Continuous Batching: 3.4x (single GPU)", font_size=22, color=GREEN), ).arrange(DOWN, buff=0.4, aligned_edge=LEFT) self.play(Write(compare)) self.wait(2) - self.play(FadeOut(compare)) + self.play(FadeOut(compare), FadeOut(title), FadeOut(top_bar)) diff --git a/render_all.py b/render_all.py index 0b4cb23..61afa9f 100644 --- a/render_all.py +++ b/render_all.py @@ -2,6 +2,9 @@ import subprocess import sys +from pathlib import Path + +ROOT = Path(__file__).parent SCENES = [ ("transformer.py", "Transformer"), @@ -12,19 +15,21 @@ SCENES = [ def render(file_name, scene_name, quality="-qh"): + script_path = ROOT / file_name + media_dir = ROOT / "output" cmd = [ sys.executable, "-m", "manim", - f"promo/{file_name}", + str(script_path), scene_name, quality, "--media_dir", - "promo/output", + str(media_dir), ] print(f"Rendering {scene_name}...") subprocess.run(cmd, check=True) - print(f" Done → promo/output/{scene_name}.mp4") + print(f" Done → {media_dir / 'videos' / scene_name.lower()}.mp4") if __name__ == "__main__":