"""AstrAI promo: Continuous Batching — state-machine driven batch rotation. Shows a 4-state FSM (Cleanup → Refill → Prefill → Decode → Loop → Cleanup) with coloured batch tokens flowing through states, entering & leaving continuously. """ from manim import * # ── palette ── PHASE_COLORS = { "Cleanup": GRAY, "Refill": ORANGE, "Prefill": BLUE, "Decode": YELLOW, } BATCH_COLORS = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE, GOLD, MAROON] class ContinuousBatching(Scene): def construct(self): # ═══════════════════════════════════════════════════ # 0. Title # ═══════════════════════════════════════════════════ title = Text("Continuous Batching", font_size=48, color=BLUE) self.play(Write(title)) self.wait(0.4) self.play(title.animate.to_edge(UP).scale(0.55)) bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15) self.play(Create(bar)) # ═══════════════════════════════════════════════════ # 1. Build state-machine layout (vertical, 4 states) # ═══════════════════════════════════════════════════ state_names = ["Cleanup", "Refill", "Prefill", "Decode"] states = VGroup() trans_arrows = VGroup() for i, name in enumerate(state_names): box = RoundedRectangle( width=3.6, height=0.8, corner_radius=0.15, color=PHASE_COLORS[name], fill_opacity=0.12, stroke_width=2.5, ) lbl = Text(name, font_size=20, color=PHASE_COLORS[name]) states.add(VGroup(box, lbl)) states.arrange(DOWN, buff=0.3) states.shift(LEFT * 3.8 + DOWN * 0.5) for i in range(1, 4): a = Arrow( states[i - 1].get_bottom(), states[i].get_top(), color=LIGHT_GRAY, buff=0.06, max_tip_length_to_length_ratio=0.22, ) trans_arrows.add(a) for i in range(4): self.play(Create(states[i])) if i > 0: self.play(Create(trans_arrows[i - 1])) # loop arrow — Decode returns to Cleanup (multiturn decoding) loop = CurvedArrow( states[-1].get_right() + RIGHT * 0.2, states[0].get_right() + RIGHT * 0.2, color=LIGHT_GRAY, angle=PI / 2, ) loop_lbl = Text("per token", font_size=11, color=GRAY).next_to(loop, RIGHT, buff=0.08) self.play(Create(loop), Write(loop_lbl)) self.wait(0.4) # ═══════════════════════════════════════════════════ # 2. Boot tokens — initial batches placed at mid-cycle # ═══════════════════════════════════════════════════ def make_token(name: str, col: str) -> VGroup: card = RoundedRectangle(width=0.65, height=0.38, corner_radius=0.08, color=col, fill_opacity=0.35, stroke_width=1.8) txt = Text(name, font_size=13, color=col) return VGroup(card, txt) tokens = { "A": make_token("A", BATCH_COLORS[0]), "B": make_token("B", BATCH_COLORS[1]), "C": make_token("C", BATCH_COLORS[2]), } # all three at consecutive stages, Prefill is the entry point tokens["A"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill tokens["B"].move_to(states[3]).shift(RIGHT * 1.5) # Decode tokens["C"].move_to(states[0]).shift(RIGHT * 1.5) # Cleanup for t in tokens.values(): self.play(FadeIn(t, scale=0.7), run_time=0.25) self.wait(0.2) note = Text("Every request starts at Prefill", font_size=16, color=WHITE) \ .next_to(states, DOWN, buff=0.55) self.play(Write(note)) self.wait(1.0) self.play(FadeOut(note)) # ═══════════════════════════════════════════════════ # 3. Tick 1 — advance, C exits, new D enters at Prefill # ═══════════════════════════════════════════════════ slots = [ states[0].get_center() + RIGHT * 1.5, # Cleanup states[1].get_center() + RIGHT * 1.5, # Refill states[2].get_center() + RIGHT * 1.5, # Prefill states[3].get_center() + RIGHT * 1.5, # Decode ] self.play( tokens["A"].animate.move_to(slots[3]), # Prefill → Decode tokens["B"].animate.move_to(slots[0]), # Decode → Cleanup tokens["C"].animate.move_to(slots[1]), # Cleanup → Refill ) self.wait(0.3) # C (now at Refill) exits after completing the loop # new D enters at Prefill self.play(FadeOut(tokens["C"], scale=0.6)) tokens["D"] = make_token("D", BATCH_COLORS[3]) tokens["D"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry self.play(FadeIn(tokens["D"], scale=0.7)) self.wait(0.25) # ═══════════════════════════════════════════════════ # 4. Tick 2 — advance, B exits, new E enters at Prefill # ═══════════════════════════════════════════════════ self.play( tokens["D"].animate.move_to(slots[3]), # Prefill → Decode tokens["A"].animate.move_to(slots[0]), # Decode → Cleanup tokens["B"].animate.move_to(slots[1]), # Cleanup → Refill ) self.wait(0.3) self.play(FadeOut(tokens["B"], scale=0.6)) tokens["E"] = make_token("E", BATCH_COLORS[4]) tokens["E"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry self.play(FadeIn(tokens["E"], scale=0.7)) self.wait(0.25) # ═══════════════════════════════════════════════════ # 5. Tick 3 — advance, A exits, new F enters at Prefill # ═══════════════════════════════════════════════════ self.play( tokens["E"].animate.move_to(slots[3]), # Prefill → Decode tokens["D"].animate.move_to(slots[0]), # Decode → Cleanup tokens["A"].animate.move_to(slots[1]), # Cleanup → Refill ) self.wait(0.25) self.play(FadeOut(tokens["A"], scale=0.6)) tokens["F"] = make_token("F", BATCH_COLORS[5]) tokens["F"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry self.play(FadeIn(tokens["F"], scale=0.7)) self.wait(0.25) # ═══════════════════════════════════════════════════ # 6. Tick 4 — advance, F exits, new G enters at Prefill # ═══════════════════════════════════════════════════ self.play( tokens["F"].animate.move_to(slots[3]), # Prefill → Decode tokens["E"].animate.move_to(slots[0]), # Decode → Cleanup tokens["D"].animate.move_to(slots[1]), # Cleanup → Refill ) self.wait(0.25) self.play(FadeOut(tokens["D"], scale=0.6)) tokens["G"] = make_token("G", BATCH_COLORS[6]) tokens["G"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry self.play(FadeIn(tokens["G"], scale=0.7)) self.wait(0.35) # drop note: constant throughput, all enter at Prefill flow_note = Text("All requests enter at Prefill — pipeline never drains", font_size=15, color=GREEN).next_to(states, DOWN, buff=0.55) self.play(Write(flow_note)) self.wait(1.5) self.play(FadeOut(flow_note)) # clear tokens self.play(*[FadeOut(t) for t in tokens.values()]) # ═══════════════════════════════════════════════════ # 6. Position-Grouped Decode highlight # ═══════════════════════════════════════════════════ # show multiple tokens grouped at Decode d_pos = states[3].get_center() d_tokens = [ make_token("T" + str(i), BATCH_COLORS[i]) for i in range(4) ] positions = [ d_pos + RIGHT * 1.2 + UP * 0.45, d_pos + RIGHT * 1.2, d_pos + RIGHT * 2.5 + UP * 0.45, d_pos + RIGHT * 2.5, ] for i in range(4): d_tokens[i].move_to(positions[i]) self.play(FadeIn(d_tokens[i], scale=0.6), run_time=0.2) ring = SurroundingRectangle(states[3], color=YELLOW, buff=0.12, stroke_width=3) ring_txt = Text( "Position-Grouped Batching\nSame decode position → single matmul", font_size=14, color=YELLOW, line_spacing=0.6, ).next_to(states[3], DOWN, buff=0.5) self.play(Create(ring), Write(ring_txt)) self.wait(2.0) self.play(FadeOut(ring), FadeOut(ring_txt), *[FadeOut(t) for t in d_tokens]) # ═══════════════════════════════════════════════════ # 7. O(1) Bitmask Slot Allocation # ═══════════════════════════════════════════════════ bitmask_title = Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE).next_to(states, DOWN, buff=0.75) bitmask_desc = Text("free_slots = ~occupied_mask (one-clock op)", font_size=15, color=GRAY).next_to(bitmask_title, DOWN, buff=0.15) self.play(Write(bitmask_title), Write(bitmask_desc)) self.wait(1.5) # animate bitmask bits flipping bits_group = VGroup() bit_size = 0.18 for i in range(16): square = Square(side_length=bit_size * 2, color=GRAY, fill_opacity=0.0, stroke_width=1.2) if i in (2, 5, 9, 13): square.set_fill(GRAY, opacity=0.5) bits_group.add(square) bits_group.arrange(RIGHT, buff=0.06) bits_group.next_to(bitmask_desc, DOWN, buff=0.3) occupied_lbl = Text("occupied_mask", font_size=11, color=RED).next_to(bits_group, LEFT, buff=0.4) self.play(Create(bits_group), Write(occupied_lbl)) # flip to ~occupied flipped = VGroup() for i, sq in enumerate(bits_group): copy_sq = Square(side_length=bit_size * 2, color=GRAY, fill_opacity=0.0, stroke_width=1.2).move_to(sq) if i not in (2, 5, 9, 13): copy_sq.set_fill(GRAY, opacity=0.5) flipped.add(copy_sq) free_lbl = Text("free_slots", font_size=11, color=GREEN) \ .next_to(flipped, LEFT, buff=0.4).align_to(occupied_lbl, LEFT) self.play(Transform(bits_group, flipped), Transform(occupied_lbl, free_lbl)) self.wait(1.2) self.play(FadeOut(bits_group), FadeOut(occupied_lbl), FadeOut(bitmask_title), FadeOut(bitmask_desc)) # ═══════════════════════════════════════════════════ # 8. Throughput comparison with animated bars # ═══════════════════════════════════════════════════ self.play( *[FadeOut(m) for m in self.mobjects if m is not title and m is not bar], FadeOut(loop), FadeOut(loop_lbl), ) for s in states: self.play(FadeOut(s), run_time=0.15) for a in trans_arrows: self.play(FadeOut(a), run_time=0.15) self.wait(0.3) # ---- title ---- compare_title = Text("Throughput Comparison", font_size=30, color=BLUE) self.play(Write(compare_title)) self.wait(0.2) self.play(compare_title.animate.to_edge(UP).scale(0.55)) self.wait(0.2) # ---- bar config ---- bar_max_w = 5.0 bar_h = 0.55 row_gap = 0.8 # ratio: static = baseline (1.0), continuous = 3.4x ratio = 1.0 / 3.4 # ---- Static Batching row ---- s_label = Text("Static Batching", font_size=24, color=RED) s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5) s_bar = Rectangle(width=bar_max_w * ratio, height=bar_h, color=RED, fill_opacity=0.55, stroke_width=0) \ .align_to(s_rect, LEFT).align_to(s_rect, UP) s_num = Text("1.0x", font_size=24, color=RED) s_label.next_to(s_rect, LEFT, buff=0.4) s_num.next_to(s_rect, RIGHT, buff=0.4) # ---- Continuous Batching row ---- c_label = Text("Continuous Batching", font_size=24, color=GREEN) c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5) c_bar = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, fill_opacity=0.55, stroke_width=0) \ .align_to(c_rect, LEFT).align_to(c_rect, UP) c_num = Text("3.4x", font_size=24, color=GREEN) c_label.next_to(c_rect, LEFT, buff=0.4) c_num.next_to(c_rect, RIGHT, buff=0.4) # stack rows, centre vertically chart = VGroup( VGroup(s_label, s_rect, s_bar, s_num), VGroup(c_label, c_rect, c_bar, c_num), ) chart.arrange(DOWN, buff=row_gap, aligned_edge=LEFT).move_to(ORIGIN) self.play( Create(s_rect), Create(c_rect), Write(s_label), Write(c_label), ) self.wait(0.3) # grow bars self.play(GrowFromEdge(s_bar, LEFT), rate_func=linear, run_time=0.6) self.wait(0.3) self.play(GrowFromEdge(c_bar, LEFT), rate_func=linear, run_time=0.6) self.wait(0.3) # show values self.play(Write(s_num), Write(c_num)) self.wait(2.5) self.play(*[FadeOut(m) for m in self.mobjects])