"""AstrAI promo: Continuous Batching — static contrast → queue → state-machine pipeline.

Sections (numbering matches the section banners in the scene helpers):
    0. Title
    1. Problem: Static Batching — requests padded to the longest, run serially, GPU idle
    2. Transition: retitle the scene to "Continuous Batching"
    3. State machine: Cleanup → Refill → Prefill → Decode, looping per token
    4. Waiting queue that feeds the Refill stage
    5. Token flow: requests rotate through the pipeline while the queue refills it
    6. Position-Grouped Decode highlight
    7. O(1) Bitmask Slot Allocation
    8. Throughput comparison (animated bars)
"""

from manim import *

# ── palette ──
PHASE_COLORS = {
    "Cleanup": GRAY,
    "Refill": ORANGE,
    "Prefill": BLUE,
    "Decode": YELLOW,
}
BATCH_COLORS = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE, GOLD, MAROON]


def _make_req_token(name: str, col) -> VGroup:
    """Return a small rounded card labelled *name*, drawn in color *col*."""
    card = RoundedRectangle(width=0.65, height=0.38, corner_radius=0.08,
                            color=col, fill_opacity=0.35, stroke_width=1.8)
    txt = Text(name, font_size=13, color=col)
    return VGroup(card, txt)


class ContinuousBatching(Scene):
    """Contrast static batching with continuous batching in a single scene.

    ``construct`` plays sections 0–8 (see the module docstring) in order;
    each private ``_show_*`` / ``_build_*`` helper renders one or two of them.
    """

    def construct(self) -> None:
        title, top_bar = self._show_title()
        self._show_static_batching(top_bar)
        self._retitle(title)
        states, trans_arrows, loop_lbl = self._build_state_machine()
        self._show_queue_and_token_flow(states)
        self._show_position_grouped_decode(states)
        self._show_bitmask_allocation(states, loop_lbl)
        self._show_throughput_comparison(title, top_bar, states, trans_arrows)

    # ═══════════════════════════════════════════════════
    # 0. Title
    # ═══════════════════════════════════════════════════
    def _show_title(self):
        """Write the title, park it at the top edge and draw a divider under it.

        Returns ``(title, top_bar)``; both stay on screen for later sections.
        """
        title = Text("Batching", font_size=48, color=BLUE)
        self.play(Write(title))
        self.wait(0.5)
        self.play(title.animate.to_edge(UP).scale(0.55))

        top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15)
        self.play(Create(top_bar))
        return title, top_bar

    # ═══════════════════════════════════════════════════
    # 1. Problem: Static Batching
    # ═══════════════════════════════════════════════════
    def _show_static_batching(self, top_bar):
        """Pad three requests to the longest one, then run them one batch at a time."""
        subtitle = Text("Static Batching", font_size=32, color=RED).next_to(top_bar, DOWN, buff=0.5)
        self.play(Write(subtitle))
        self.wait(0.4)

        # three requests with different token lengths
        req_w_base = 0.35  # scene width per token
        req_h = 0.4
        lengths = [8, 12, 5]
        colors_req = [ORANGE, BLUE, PINK]
        reqs = []
        for i, (n, c) in enumerate(zip(lengths, colors_req)):
            bar = Rectangle(width=req_w_base * n, height=req_h, color=c,
                            fill_opacity=0.5, stroke_width=1.5)
            lbl = Text(f"Req {i+1}", font_size=18, color=c).next_to(bar, UP, buff=0.1)
            reqs.append(VGroup(bar, lbl))
        VGroup(*reqs).arrange(DOWN, buff=0.55).move_to(ORIGIN).shift(LEFT * 1.5 + DOWN * 0.2)
        for r in reqs:
            self.play(FadeIn(r, scale=0.9), run_time=0.3)
        self.wait(0.5)

        diff_note = Text("Different prompt lengths", font_size=16, color=LIGHT_GRAY) \
            .next_to(reqs[-1], DOWN, buff=0.4)
        self.play(Write(diff_note))
        self.wait(0.8)

        # Padding: compare integer token counts rather than float widths so the
        # longest request never gets a degenerate near-zero-width pad.
        max_len = max(lengths)
        max_w = req_w_base * max_len
        pad_lines = VGroup()
        for r, n in zip(reqs, lengths):
            if n < max_len:
                pad = Rectangle(width=req_w_base * (max_len - n), height=req_h,
                                fill_opacity=0.25, stroke_width=0, color=GRAY)
                pad.next_to(r[0], RIGHT, buff=0)
                pad_lines.add(pad)
        if len(pad_lines) > 0:  # Scene.play rejects an empty animation list
            self.play(*[Create(p) for p in pad_lines])
        pad_note = Text("Padded to longest → waste", font_size=15, color=RED) \
            .next_to(diff_note, DOWN, buff=0.15)
        self.play(Write(pad_note))
        self.wait(0.8)
        self.play(FadeOut(diff_note), FadeOut(pad_note), *[FadeOut(p) for p in pad_lines])

        # Serial processing: each request slides into its own slot, is highlighted
        # inside the batch box, then dimmed; idle gaps are shown between batches.
        serial_slots = [
            LEFT * 2.5 + DOWN * 1.0,
            DOWN * 1.0,
            RIGHT * 2.5 + DOWN * 1.0,
        ]
        batch_box = Rectangle(width=max_w + 0.4, height=req_h + 0.25, color=RED,
                              stroke_width=1.5, stroke_opacity=0.5)
        gpu_active = Text("GPU Busy", font_size=14, color=GREEN)
        last = len(reqs) - 1
        for i, (req, slot) in enumerate(zip(reqs, serial_slots)):
            batch_box.move_to(slot)
            # No instant move_to beforehand: the request must visibly travel.
            anims = [req.animate.move_to(slot).set_opacity(1), Create(batch_box)]
            if i > 0:
                # keep the "GPU Busy" label attached to the current batch
                anims.append(gpu_active.animate.next_to(batch_box, UP, buff=0.15))
            self.play(*anims, run_time=0.5)
            if i == 0:
                gpu_active.next_to(batch_box, UP, buff=0.15)
                self.play(Write(gpu_active))
            self.wait(0.25 if i < last else 0.1)

            if i < last:
                gpu_idle = Text("Idle...", font_size=13, color=GRAY) \
                    .next_to(batch_box, RIGHT, buff=0.4).shift(UP * 0.1)
                self.play(Write(gpu_idle))
                self.wait(0.6)
                self.play(FadeOut(gpu_idle))
            else:
                self.play(FadeOut(gpu_active))
            self.play(FadeOut(batch_box))
            req.set_opacity(0.35)

        # big X over static batching
        cross = VGroup(
            Line(LEFT * 2 + UP * 1.5, RIGHT * 2 + DOWN * 1.5, color=RED, stroke_width=5),
            Line(LEFT * 2 + DOWN * 1.5, RIGHT * 2 + UP * 1.5, color=RED, stroke_width=5),
        )
        reason = Text("Serial → GPU idle 60%+ of the time", font_size=16, color=RED) \
            .next_to(cross, DOWN, buff=0.3)
        self.play(Create(cross), Write(reason))
        self.wait(1.5)

        # clear everything for the transition
        self.play(
            FadeOut(cross), FadeOut(reason), FadeOut(subtitle),
            *[FadeOut(r) for r in reqs],
        )

    # ═══════════════════════════════════════════════════
    # 2. Transition: Continuous Batching — the solution
    # ═══════════════════════════════════════════════════
    def _retitle(self, title):
        """Morph the header text into "Continuous Batching" in place."""
        new_title = Text("Continuous Batching", font_size=48, color=BLUE).scale(0.55).move_to(title)
        self.play(Transform(title, new_title))
        self.wait(0.5)

    # ═══════════════════════════════════════════════════
    # 3. Build the state machine
    # ═══════════════════════════════════════════════════
    def _build_state_machine(self):
        """Draw the four phase boxes, their transition arrows and the per-token loop.

        Returns ``(states, trans_arrows, loop_lbl)``. ``states`` is a VGroup ordered
        Cleanup, Refill, Prefill, Decode, and stays on screen until section 8.
        """
        state_names = ["Cleanup", "Refill", "Prefill", "Decode"]
        states = VGroup()
        for name in state_names:
            color = PHASE_COLORS[name]
            box = RoundedRectangle(width=3.6, height=0.8, corner_radius=0.15, color=color,
                                   fill_opacity=0.12, stroke_width=2.5)
            states.add(VGroup(box, Text(name, font_size=20, color=color)))
        states.arrange(DOWN, buff=0.3)
        states.shift(LEFT * 3.5 + DOWN * 0.4)

        trans_arrows = VGroup(*[
            Arrow(states[i - 1].get_bottom(), states[i].get_top(), color=LIGHT_GRAY,
                  buff=0.06, max_tip_length_to_length_ratio=0.22)
            for i in range(1, len(states))
        ])

        for i, state in enumerate(states):
            self.play(Create(state))
            if i > 0:
                self.play(Create(trans_arrows[i - 1]))

        # loop arrow from the last phase back up to the first
        loop = CurvedArrow(
            states[-1].get_right() + RIGHT * 0.2,
            states[0].get_right() + RIGHT * 0.2,
            color=LIGHT_GRAY, angle=PI / 2,
        )
        loop_lbl = Text("per token", font_size=11, color=GRAY).next_to(loop, RIGHT, buff=0.08)
        self.play(Create(loop), Write(loop_lbl))
        self.wait(0.3)
        return states, trans_arrows, loop_lbl

    # ═══════════════════════════════════════════════════
    # 4. Waiting queue + 5. Token flow
    # ═══════════════════════════════════════════════════
    def _show_queue_and_token_flow(self, states):
        """Show the waiting queue feeding Refill, then rotate tokens for three ticks."""
        queue_box = Rectangle(width=1.6, height=2.5, color=LIGHT_GRAY,
                              stroke_width=1.5, fill_opacity=0.06)
        queue_label = Text("Queue", font_size=14, color=LIGHT_GRAY).next_to(queue_box, UP, buff=0.12)
        VGroup(queue_box, queue_label).next_to(states[1], RIGHT, buff=2.0)  # Refill level
        self.play(Create(queue_box), Write(queue_label))

        # pending requests; each card's color matches the pipeline token it becomes
        queue_entries = VGroup(
            _make_req_token("X", BATCH_COLORS[2]),
            _make_req_token("Y", BATCH_COLORS[1]),
            _make_req_token("Z", BATCH_COLORS[0]),
        )
        queue_entries.arrange(DOWN, buff=0.25).move_to(queue_box)
        for q in queue_entries:
            self.play(FadeIn(q, scale=0.7), run_time=0.2)

        # arrow from queue into Refill
        queue_arrow = Arrow(
            queue_box.get_left() + LEFT * 0.05,
            states[1].get_right() + RIGHT * 0.15,
            color=ORANGE, buff=0.06, max_tip_length_to_length_ratio=0.2,
        )
        queue_arrow_lbl = Text("admit", font_size=11, color=ORANGE).next_to(queue_arrow, UP, buff=0.05)
        self.play(Create(queue_arrow), Write(queue_arrow_lbl))
        self.wait(0.5)

        # token parking spots, just right of each phase label
        cleanup, refill, prefill, decode = [s.get_center() + RIGHT * 1.2 for s in states]

        tokens = {
            "A": _make_req_token("A", BATCH_COLORS[0]).move_to(prefill),
            "B": _make_req_token("B", BATCH_COLORS[1]).move_to(decode),
            "C": _make_req_token("C", BATCH_COLORS[2]).move_to(cleanup),
        }
        for t in tokens.values():
            self.play(FadeIn(t, scale=0.7), run_time=0.2)
        self.wait(0.3)

        note = Text("Every request starts at Prefill", font_size=16, color=WHITE) \
            .next_to(states, DOWN, buff=0.6)
        self.play(Write(note))
        self.wait(0.8)
        self.play(FadeOut(note))

        # Each tick: Prefill→Decode, Decode→Cleanup, Cleanup→Refill. The token that
        # reaches Refill exits and the head of the queue enters at Prefill.
        schedule = [
            # (→Decode, →Cleanup, →Refill then exits, new key, queue card, card color)
            ("A", "B", "C", "D", "X", BATCH_COLORS[2]),
            ("D", "A", "B", "E", "Y", BATCH_COLORS[1]),
            ("E", "D", "A", "F", "Z", BATCH_COLORS[0]),  # Z's card is BATCH_COLORS[0]
        ]
        for to_decode, to_cleanup, to_refill, new_key, card, color in schedule:
            self.play(
                tokens[to_decode].animate.move_to(decode),
                tokens[to_cleanup].animate.move_to(cleanup),
                tokens[to_refill].animate.move_to(refill),
            )
            self.wait(0.3)

            q_first = queue_entries[0]
            # pop the exiting token so the final cleanup does not fade it out a
            # second time (that would re-add and flash it at the Refill slot)
            self.play(FadeOut(tokens.pop(to_refill), scale=0.6))
            tokens[new_key] = _make_req_token(card, color).move_to(prefill)
            self.play(
                FadeOut(q_first, scale=0.5),
                FadeIn(tokens[new_key], scale=0.7),
            )
            queue_entries.remove(q_first)
            if len(queue_entries) > 0:
                # slide the remaining cards into their re-centred positions
                targets = queue_entries.copy().arrange(DOWN, buff=0.25).move_to(queue_box)
                self.play(
                    *[e.animate.move_to(t) for e, t in zip(queue_entries, targets)],
                    run_time=0.3,
                )
            self.wait(0.3)

        # queue is now empty — a brand-new request W flies into it
        w_token = _make_req_token("W", BATCH_COLORS[4]).next_to(queue_box, RIGHT, buff=0.8)
        self.play(w_token.animate.move_to(queue_box), run_time=0.6)
        queue_entries.add(w_token)
        self.wait(0.2)

        flow_note = Text(
            "Pipeline never drains — new requests always entering",
            font_size=15, color=GREEN,
        ).next_to(states, DOWN, buff=0.6)
        self.play(Write(flow_note))
        self.wait(1.5)
        self.play(FadeOut(flow_note))

        # clean up: only tokens still on screen remain in `tokens`
        self.play(
            *[FadeOut(t) for t in tokens.values()],
            FadeOut(queue_arrow), FadeOut(queue_arrow_lbl),
            FadeOut(queue_box), FadeOut(queue_label),
            *[FadeOut(q) for q in queue_entries],
        )

    # ═══════════════════════════════════════════════════
    # 6. Position-Grouped Decode highlight
    # ═══════════════════════════════════════════════════
    def _show_position_grouped_decode(self, states):
        """Cluster four tokens at Decode and ring it with the grouping explanation."""
        d_pos = states[3].get_center()
        offsets = [
            RIGHT * 1.0 + UP * 0.45,
            RIGHT * 1.0,
            RIGHT * 2.3 + UP * 0.45,
            RIGHT * 2.3,
        ]
        d_tokens = []
        for i, off in enumerate(offsets):
            tok = _make_req_token(f"T{i}", BATCH_COLORS[i]).move_to(d_pos + off)
            d_tokens.append(tok)
            self.play(FadeIn(tok, scale=0.6), run_time=0.2)

        ring = SurroundingRectangle(states[3], color=YELLOW, buff=0.12, stroke_width=3)
        ring_txt = Text(
            "Position-Grouped Batching\nSame decode position → single matmul",
            font_size=14, color=YELLOW, line_spacing=0.6,
        ).next_to(states[3], DOWN, buff=0.5)
        self.play(Create(ring), Write(ring_txt))
        self.wait(2.0)
        self.play(FadeOut(ring), FadeOut(ring_txt), *[FadeOut(t) for t in d_tokens])

    # ═══════════════════════════════════════════════════
    # 7. O(1) Bitmask Slot Allocation
    # ═══════════════════════════════════════════════════
    def _show_bitmask_allocation(self, states, loop_lbl):
        """Invert a 16-slot occupied mask into the free-slot mask.

        The panel sits to the right of the state machine (right of ``loop_lbl``)
        and is scaled down if needed so all of it stays inside the frame.
        Stacking it under the state machine pushed the bit row below the frame.
        """
        occupied = {2, 5, 9, 13}  # demo slots drawn as in use
        n_slots = 16
        bit_side = 0.3

        def bit_row(is_filled):
            """Return *n_slots* squares in a row, filled where ``is_filled(i)``."""
            row = VGroup()
            for i in range(n_slots):
                sq = Square(side_length=bit_side, color=GRAY, fill_opacity=0.0, stroke_width=1.2)
                if is_filled(i):
                    sq.set_fill(GRAY, opacity=0.5)
                row.add(sq)
            return row.arrange(RIGHT, buff=0.06)

        bitmask_title = Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE)
        bitmask_desc = Text("free_slots = ~occupied_mask (one-clock op)", font_size=15, color=GRAY) \
            .next_to(bitmask_title, DOWN, buff=0.15)
        bits_group = bit_row(lambda i: i in occupied).next_to(bitmask_desc, DOWN, buff=0.5)
        flipped = bit_row(lambda i: i not in occupied).move_to(bits_group)
        # labels go above the row (not to its left) so the panel fits the frame width
        occupied_lbl = Text("occupied_mask", font_size=11, color=RED) \
            .next_to(bits_group, UP, buff=0.1).align_to(bits_group, LEFT)
        free_lbl = Text("free_slots", font_size=11, color=GREEN) \
            .next_to(flipped, UP, buff=0.1).align_to(flipped, LEFT)

        # Lay out the whole panel (including the not-yet-shown transform targets)
        # in the free region between the loop label and the right frame edge.
        panel = VGroup(bitmask_title, bitmask_desc, bits_group, occupied_lbl, flipped, free_lbl)
        left_limit = loop_lbl.get_right()[0] + 0.4
        right_limit = config.frame_width / 2 - 0.4
        available = right_limit - left_limit
        if panel.get_width() > available:
            panel.scale_to_fit_width(available)
        panel.move_to(RIGHT * ((left_limit + right_limit) / 2) + UP * states.get_center()[1])

        self.play(Write(bitmask_title), Write(bitmask_desc))
        self.wait(1.5)
        self.play(Create(bits_group), Write(occupied_lbl))
        self.play(Transform(bits_group, flipped), Transform(occupied_lbl, free_lbl))
        self.wait(1.2)
        self.play(FadeOut(bits_group), FadeOut(occupied_lbl),
                  FadeOut(bitmask_title), FadeOut(bitmask_desc))

    # ═══════════════════════════════════════════════════
    # 8. Throughput comparison with animated bars
    # ═══════════════════════════════════════════════════
    def _show_throughput_comparison(self, title, top_bar, states, trans_arrows):
        """Clear the stage, swap the header, and grow the two throughput bars."""
        # Fade leftovers together, then the phase boxes/arrows one by one. Each
        # mobject gets exactly one FadeOut: fading an already-removed mobject
        # re-adds it and makes it flicker.
        staged = [*states, *trans_arrows]
        keep = {id(m) for m in (title, top_bar, *staged)}
        leftovers = [m for m in self.mobjects if id(m) not in keep]
        if leftovers:
            self.play(*[FadeOut(m) for m in leftovers])
        for m in staged:
            self.play(FadeOut(m), run_time=0.15)
        self.wait(0.3)

        compare_title = Text("Throughput Comparison", font_size=30, color=BLUE)
        self.play(Write(compare_title))
        self.wait(0.2)
        # the new header takes the old title's spot, so retire the old one
        self.play(compare_title.animate.to_edge(UP).scale(0.55), FadeOut(title))
        self.wait(0.2)

        bar_max_w = 5.0
        bar_h = 0.55
        row_gap = 0.8
        speedup = 3.4  # continuous vs. static, as shown on the bar labels

        s_label = Text("Static Batching", font_size=24, color=RED)
        s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5)
        s_bar = Rectangle(width=bar_max_w / speedup, height=bar_h, color=RED,
                          fill_opacity=0.55, stroke_width=0)
        s_num = Text("1.0x", font_size=24, color=RED)

        c_label = Text("Continuous Batching", font_size=24, color=GREEN)
        c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5)
        c_bar = Rectangle(width=bar_max_w, height=bar_h, color=GREEN,
                          fill_opacity=0.55, stroke_width=0)
        c_num = Text(f"{speedup:.1f}x", font_size=24, color=GREEN)

        row_offset = UP * (row_gap / 2 + bar_h / 2)
        s_rect.move_to(row_offset)
        c_rect.move_to(-row_offset)
        s_bar.align_to(s_rect, LEFT).align_to(s_rect, UP)
        c_bar.align_to(c_rect, LEFT).align_to(c_rect, UP)
        s_label.next_to(s_rect, LEFT, buff=0.4)
        c_label.next_to(c_rect, LEFT, buff=0.4)
        s_num.next_to(s_rect, RIGHT, buff=0.4)
        c_num.next_to(c_rect, RIGHT, buff=0.4)

        self.play(
            Create(s_rect), Create(c_rect),
            Write(s_label), Write(c_label),
        )
        self.wait(0.3)
        self.play(GrowFromEdge(s_bar, LEFT), rate_func=linear, run_time=0.6)
        self.wait(0.3)
        self.play(GrowFromEdge(c_bar, LEFT), rate_func=linear, run_time=0.6)
        self.wait(0.3)
        self.play(Write(s_num), Write(c_num))
        self.wait(2.5)
        self.play(*[FadeOut(m) for m in self.mobjects])