"""AstrAI promo: Continuous Batching animation.

Shows how tasks flow through the 4-phase pipeline and get batched together.
"""

from manim import *


class ContinuousBatching(Scene):
    """Animate requests moving through the four-stage generation pipeline.

    The scene plays, in order:

    1. the title and a separator bar,
    2. the four stage boxes connected by arrows, plus a caption,
    3. five labelled request dots appearing in the first stage,
    4. the requests advancing stage by stage, with the decode stage
       highlighted while the requests are inside it,
    5. a text-only throughput comparison.
    """

    # (label, colour) of every pipeline stage, in left-to-right order.
    _STAGES = (
        ("Waiting\nQueue", GRAY),
        ("Prefill", BLUE),
        ("Decode\n(Batched)", YELLOW),
        ("Finished", GREEN),
    )
    # Index into _STAGES of the stage that gets the batching highlight.
    _DECODE_STAGE = 2
    # One request dot is spawned per colour.
    _TASK_COLORS = (YELLOW, ORANGE, PINK, TEAL, GREEN)
    # Horizontal distance between the centres of neighbouring stage boxes.
    _STAGE_PITCH = 3.2
    # Horizontal distance between neighbouring request dots (0.2 * 0.3).
    _TASK_PITCH = 0.06

    def construct(self):
        """Build and play the whole animation (Manim entry point)."""
        top_bar = self._show_title()

        stages, arrows = self._build_pipeline()
        pipeline = VGroup(stages, arrows)
        caption = Text(
            "4-Phase Generation Loop", font_size=16, color=GRAY
        ).next_to(pipeline, DOWN, buff=0.4)
        self.play(Write(caption))
        self.wait(0.5)

        tasks = self._spawn_tasks(stages[0].get_center())
        self.wait(0.3)

        self._advance_tasks(tasks, stages)

        # Clear everything except the title before the comparison slide.
        self.play(
            *[FadeOut(task) for task in tasks],
            FadeOut(pipeline),
            FadeOut(caption),
            FadeOut(top_bar),
        )
        self._show_throughput()

    def _show_title(self):
        """Write the title, park it at the top edge and draw the bar below it.

        Returns:
            The separator ``Line`` so the caller can fade it out later.
        """
        title = Text("Continuous Batching", font_size=48, color=BLUE)
        self.play(Write(title))
        self.wait(0.5)
        self.play(title.animate.to_edge(UP).scale(0.6))

        top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN)
        self.play(Create(top_bar))
        return top_bar

    def _build_pipeline(self):
        """Draw one box per stage and an arrow between neighbouring boxes.

        Returns:
            ``(stages, arrows)`` as two ``VGroup`` objects; ``stages[i]`` is
            the box-plus-label group of the i-th entry in ``_STAGES``.
        """
        stages = VGroup()
        arrows = VGroup()
        # Index of the (possibly fractional) middle slot, so the row of boxes
        # is centred horizontally whatever the number of stages.
        middle = (len(self._STAGES) - 1) / 2

        for index, (name, color) in enumerate(self._STAGES):
            box = Rectangle(
                height=1.5, width=2.5, color=color, fill_opacity=0.12
            )
            label = Text(name, font_size=18, color=color)
            stage = VGroup(box, label)
            stage.shift(RIGHT * (index - middle) * self._STAGE_PITCH + DOWN * 0.5)
            stages.add(stage)
            self.play(Create(stage), run_time=0.35)

            if index > 0:
                arrow = Arrow(
                    stages[index - 1].get_right(),
                    stage.get_left(),
                    color=GRAY,
                )
                arrows.add(arrow)
                self.play(Create(arrow), run_time=0.2)

        return stages, arrows

    def _spawn_tasks(self, origin):
        """Fade in one labelled request dot per colour, fanned out around *origin*.

        Args:
            origin: Point the row of dots is centred on.

        Returns:
            A ``VGroup`` whose members are ``VGroup(dot, label)`` pairs.
        """
        tasks = VGroup()
        middle = (len(self._TASK_COLORS) - 1) / 2

        for index, color in enumerate(self._TASK_COLORS):
            # The fan-out is horizontal: the offset is applied along RIGHT.
            x_offset = (index - middle) * self._TASK_PITCH
            dot = Dot(color=color, radius=0.12)
            dot.move_to(origin + RIGHT * x_offset)
            label = Text(f"R{index+1}", font_size=10, color=color).next_to(
                dot, UP, buff=0.1
            )
            task = VGroup(dot, label)
            tasks.add(task)
            self.play(FadeIn(task, scale=0.5), run_time=0.12)

        return tasks

    def _advance_tasks(self, tasks, stages):
        """Move all requests from stage to stage, highlighting the decode stage.

        Args:
            tasks: Request groups currently sitting in ``stages[0]``.
            stages: Stage groups in pipeline order.
        """
        for phase in range(1, len(stages)):
            # Shift by the stage-to-stage delta instead of moving every task
            # to the stage centre, so the tasks keep their relative layout
            # rather than collapsing onto a single point.
            hop = stages[phase].get_center() - stages[phase - 1].get_center()
            self.play(
                *[task.animate.shift(hop) for task in tasks],
                run_time=0.5,
                rate_func=smooth,
            )
            self.wait(0.15)

            # Highlight batching while the requests are inside the decode
            # stage, not after they have already moved on to the last stage.
            if phase == self._DECODE_STAGE:
                self._highlight_decode(stages[phase])

    def _highlight_decode(self, stage):
        """Briefly outline *stage* and show the batch-decoding note.

        Args:
            stage: The stage group to draw the outline around.
        """
        ring = SurroundingRectangle(stage, color=YELLOW, buff=0.12)
        # Placed above the stage: the area directly below the pipeline is
        # already used by the caption.
        note = Text(
            "Same-position batch decoding", font_size=16, color=YELLOW
        ).next_to(ring, UP, buff=0.3)
        self.play(Create(ring), Write(note))
        self.wait(1)
        self.play(FadeOut(ring), FadeOut(note))

    def _show_throughput(self):
        """Show the text-only throughput comparison, then fade it out."""
        compare = VGroup(
            Text("Throughput Comparison", font_size=32, color=BLUE),
            Text(
                "Static Batch:  1.0×  (baseline)",
                font_size=24,
                color=RED,
            ),
            Text(
                "Continuous Batching:  3.4×  (single GPU)",
                font_size=24,
                color=GREEN,
            ),
        ).arrange(DOWN, buff=0.4, aligned_edge=LEFT)
        self.play(Write(compare))
        self.wait(2)
        self.play(FadeOut(compare))