"""AstrAI promo: Continuous Batching animation.

Shows the 4-phase scheduler pipeline with multiple requests concurrently at
different stages, and position-grouped decode batching — the key advantage
over static batching.
"""

from manim import *


class ContinuousBatching(Scene):
    """Scene: scheduler loop, concurrent requests, decode batching, throughput."""

    # Horizontal gap between a phase box and the first request parked beside it.
    _REQ_BUFF = 1.2
    # Gap for a second request sharing the same phase (sits right of the first).
    _REQ_BUFF_PAIR = 2.5

    def construct(self):
        """Play the full animation; each section is a private helper below."""
        title, top_bar = self._show_title()
        phases, loop_arrow, loop_label = self._show_phase_loop()
        self._show_request_flow(phases, loop_arrow, loop_label)
        self._show_bitmask_note(phases)
        self._show_throughput(title, top_bar)

    def _show_title(self):
        """Write the title, shrink it to the top edge and underline it.

        Returns:
            (title, top_bar): both stay on screen until the final fade-out.
        """
        title = Text("Continuous Batching", font_size=48, color=BLUE)
        self.play(Write(title))
        self.wait(0.3)
        self.play(title.animate.to_edge(UP).scale(0.55))
        top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.2)
        self.play(Create(top_bar))
        return title, top_bar

    def _show_phase_loop(self):
        """Draw the four scheduler phases top-to-bottom and the loop-back arrow.

        Returns:
            (phases, loop_arrow, loop_label): ``phases[i]`` is the group for
            phase *i* (0 Cleanup, 1 Refill, 2 Prefill, 3 Decode).
        """
        phase_specs = (
            ("Cleanup", GRAY, "Evict finished slots"),
            ("Refill", ORANGE, "Admit new requests"),
            ("Prefill", BLUE, "Compute KV cache"),
            ("Decode (Batched)", YELLOW, "Group by position"),
        )
        phases = VGroup()
        for name, color, desc in phase_specs:
            box = Rectangle(width=3.2, height=0.7, color=color, fill_opacity=0.12)
            # Name plus a one-line description, centred inside the box so the
            # group's bounding box (used for positioning requests) is the box.
            caption = VGroup(
                Text(name, font_size=18, color=color),
                Text(desc, font_size=12, color=GRAY),
            ).arrange(DOWN, buff=0.06).move_to(box)
            phases.add(VGroup(box, caption))
        phases.arrange(DOWN, buff=0.25)
        phases.shift(LEFT * 3.5 + DOWN * 0.6)

        # Connectors must be built AFTER layout: get_bottom()/get_top() read the
        # current positions, and every box sits at ORIGIN before arrange().
        # The gap is only 0.25, so use a small buff and let the tip and stroke
        # take a proportionally larger share of the short arrow.
        boxes = phases.submobjects
        phase_arrows = [
            Arrow(
                upper.get_bottom(),
                lower.get_top(),
                color=GRAY,
                buff=0.02,
                max_tip_length_to_length_ratio=0.5,
                max_stroke_width_to_length_ratio=20,
            )
            for upper, lower in zip(boxes, boxes[1:])
        ]

        for i, phase in enumerate(boxes):
            self.play(Create(phase))
            if i > 0:
                self.play(Create(phase_arrows[i - 1]))
        self.wait(0.3)

        # Cycle arrow from Decode back up to Cleanup.
        loop_arrow = CurvedArrow(
            boxes[-1].get_right() + RIGHT * 0.15,
            boxes[0].get_right() + RIGHT * 0.15,
            color=GRAY,
            angle=PI / 2,
        )
        loop_label = Text("Loop", font_size=12, color=GRAY).next_to(loop_arrow, RIGHT, buff=0.1)
        self.play(Create(loop_arrow), Write(loop_label))
        self.wait(0.3)
        return phases, loop_arrow, loop_label

    @staticmethod
    def _make_request(name, color):
        """Return a request marker: a coloured dot with its name to the left."""
        dot = Dot(color=color, radius=0.13)
        label = Text(name, font_size=14, color=color).next_to(dot, LEFT, buff=0.15)
        return VGroup(dot, label)

    def _show_request_flow(self, phases, loop_arrow, loop_label):
        """Animate requests moving through the loop, then a batched decode step.

        When done, fades out the remaining requests, the decode highlights and
        the loop arrow/label, leaving only the phase boxes and connectors.
        """
        near, pair = self._REQ_BUFF, self._REQ_BUFF_PAIR
        colors = (YELLOW, ORANGE, PINK, TEAL, GREEN)

        # One request in every phase at once:
        # R1 Prefill, R2 Decode, R3 Refill, R4 Cleanup.
        r1, r2, r3, r4 = [
            self._make_request(f"R{k + 1}", colors[k]).next_to(phases[p], RIGHT, buff=near)
            for k, p in enumerate((2, 3, 1, 0))
        ]
        self.play(*[FadeIn(r, scale=0.7) for r in (r1, r2, r3, r4)])
        self.wait(0.4)

        concurrent_note = Text(
            "4 requests at different phases simultaneously", font_size=15, color=WHITE
        ).next_to(phases, DOWN, buff=0.5)
        self.play(Write(concurrent_note))
        self.wait(1.2)
        self.play(FadeOut(concurrent_note))

        # Every request advances one phase:
        # R1 Prefill->Decode, R2 Decode->Cleanup, R3 Refill->Prefill, R4 Cleanup->Refill.
        self.play(
            r1.animate.next_to(phases[3], RIGHT, buff=near),
            r2.animate.next_to(phases[0], RIGHT, buff=near),
            r3.animate.next_to(phases[2], RIGHT, buff=near),
            r4.animate.next_to(phases[1], RIGHT, buff=near),
        )
        self.wait(0.3)

        # R2 is done and evicted; R5 is admitted in Refill. R4 already occupies
        # the near spot there, so R5 goes beside it rather than on top of it.
        r5 = self._make_request("R5", colors[4]).next_to(phases[1], RIGHT, buff=pair)
        self.play(FadeOut(r2), FadeIn(r5, scale=0.7))
        self.wait(0.3)

        # Gather into Decode: R3 joins R1 (already parked there) as a batch.
        self.play(
            r3.animate.next_to(phases[3], RIGHT, buff=pair),
            FadeOut(r4),
            FadeOut(r5),
        )
        self.wait(0.2)

        ring = SurroundingRectangle(phases[3], color=YELLOW, buff=0.12)
        ring_note = Text(
            "Position-Grouped Decode\nSame pos → same matmul batch",
            font_size=15,
            color=YELLOW,
            line_spacing=0.6,
        ).next_to(ring, DOWN, buff=0.35)
        # Frame R1 and R3 at their actual on-screen positions.
        batch_box = SurroundingRectangle(VGroup(r1, r3), color=GREEN, buff=0.2)

        self.play(Create(ring), Write(ring_note))
        self.play(Create(batch_box))
        self.wait(1.8)
        self.play(
            *[
                FadeOut(m)
                for m in (ring, ring_note, batch_box, r1, r3, loop_arrow, loop_label)
            ]
        )

    def _show_bitmask_note(self, phases):
        """Show the O(1) bitmask slot-allocation callout below the phase loop."""
        bitmask_box = VGroup(
            Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE),
            Text("free_slots = ~occupied_mask → one-bit op", font_size=16, color=GRAY),
        ).arrange(DOWN, buff=0.2).next_to(phases, DOWN, buff=0.8)
        self.play(Write(bitmask_box))
        self.wait(1.2)

    def _show_throughput(self, title, top_bar):
        """Clear everything but the title bar, show the throughput comparison, then end."""
        leftovers = [m for m in self.mobjects if m is not title and m is not top_bar]
        # Guard against issuing an empty play() call if nothing is left to clear.
        if leftovers:
            self.play(*[FadeOut(m) for m in leftovers])

        compare = VGroup(
            Text("Throughput Comparison", font_size=32, color=BLUE),
            Text("Static Batching: 1.0x (baseline)", font_size=22, color=RED),
            Text("Continuous Batching: 3.4x (single GPU)", font_size=22, color=GREEN),
        ).arrange(DOWN, buff=0.4, aligned_edge=LEFT)
        self.play(Write(compare))
        self.wait(2)
        self.play(FadeOut(compare), FadeOut(title), FadeOut(top_bar))