"""AstrAI promo: Continuous Batching — state-machine driven batch rotation.

Shows a 4-state FSM (Cleanup → Refill → Prefill → Decode → Loop → Cleanup)
with coloured batch tokens flowing through states, entering & leaving continuously.
"""

from manim import *

# ── palette ──
PHASE_COLORS = {
    "Cleanup": GRAY,
    "Refill": ORANGE,
    "Prefill": BLUE,
    "Decode": YELLOW,
}
BATCH_COLORS = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE, GOLD, MAROON]

# ── layout ──
# FSM order, top to bottom; dicts keep insertion order, so this is
# Cleanup, Refill, Prefill, Decode.
STATE_NAMES = tuple(PHASE_COLORS)
# Where a token is parked relative to the centre of its state box.
TOKEN_OFFSET = RIGHT * 1.5

# ── bitmask demo ──
N_BITS = 16
OCCUPIED_BITS = frozenset({2, 5, 9, 13})
BIT_SIDE = 0.36

# ── throughput chart: (label, speed-up, colour) ──
THROUGHPUT = (
    ("Static Batching", 1.0, RED),
    ("Continuous Batching", 3.4, GREEN),
)


class ContinuousBatching(Scene):
    """Storyboard: title → FSM → token rotation → decode → bitmask → throughput."""

    def construct(self):
        """Play the storyboard sections in order."""
        title, bar = self._show_title()
        states = self._build_state_machine()
        self._run_rotation(states)
        self._show_position_grouped_decode(states)
        self._show_bitmask_allocation(states)
        self._show_throughput(keep=(title, bar))

    # ═══════════════════════════════════════════════════
    # 0. Title
    # ═══════════════════════════════════════════════════
    def _show_title(self):
        """Write the title, dock it at the top edge and underline it.

        Returns the ``(title, bar)`` pair so later sections can keep both on
        screen while clearing everything else.
        """
        title = Text("Continuous Batching", font_size=48, color=BLUE)
        self.play(Write(title))
        self.wait(0.4)
        self.play(title.animate.to_edge(UP).scale(0.55))
        bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15)
        self.play(Create(bar))
        return title, bar

    # ═══════════════════════════════════════════════════
    # 1. State-machine layout (vertical, one box per state)
    # ═══════════════════════════════════════════════════
    @staticmethod
    def _make_state(name) -> VGroup:
        """Return a labelled rounded box for the FSM state ``name``."""
        color = PHASE_COLORS[name]
        box = RoundedRectangle(
            width=3.6, height=0.8, corner_radius=0.15,
            color=color, fill_opacity=0.12, stroke_width=2.5,
        )
        lbl = Text(name, font_size=20, color=color)
        return VGroup(box, lbl)

    def _build_state_machine(self) -> VGroup:
        """Draw the state column, its transition arrows and the loop arrow.

        Returns the ``VGroup`` of state boxes, ordered like ``STATE_NAMES``.
        """
        states = VGroup(*[self._make_state(name) for name in STATE_NAMES])
        states.arrange(DOWN, buff=0.3)
        states.shift(LEFT * 3.8 + DOWN * 0.5)

        # Arrows are built after the column is arranged so that their end
        # points refer to the final box positions.
        boxes = list(states)
        trans_arrows = [
            Arrow(
                upper.get_bottom(), lower.get_top(),
                color=LIGHT_GRAY, buff=0.06,
                max_tip_length_to_length_ratio=0.22,
            )
            for upper, lower in zip(boxes, boxes[1:])
        ]

        self.play(Create(boxes[0]))
        for box, arrow in zip(boxes[1:], trans_arrows):
            self.play(Create(box))
            self.play(Create(arrow))

        # Loop arrow — the last state feeds back into the first one.
        loop = CurvedArrow(
            boxes[-1].get_right() + RIGHT * 0.2,
            boxes[0].get_right() + RIGHT * 0.2,
            color=LIGHT_GRAY, angle=PI / 2,
        )
        loop_lbl = Text("per token", font_size=11, color=GRAY).next_to(loop, RIGHT, buff=0.08)
        self.play(Create(loop), Write(loop_lbl))
        self.wait(0.4)
        return states

    # ═══════════════════════════════════════════════════
    # 2-5. Token rotation
    # ═══════════════════════════════════════════════════
    @staticmethod
    def _make_token(name, color) -> VGroup:
        """Return a small coloured card labelled ``name`` representing one batch."""
        card = RoundedRectangle(
            width=0.65, height=0.38, corner_radius=0.08,
            color=color, fill_opacity=0.35, stroke_width=1.8,
        )
        txt = Text(name, font_size=13, color=color)
        return VGroup(card, txt)

    def _run_rotation(self, states):
        """Rotate batch tokens through the FSM for three ticks.

        On every tick each live token advances one state. The token that
        wraps around into Cleanup is finished: it fades out and a new token
        fades in at that same, now empty, Cleanup slot. Spawning the
        newcomer there (rather than at Refill) keeps it from being drawn on
        top of the token that currently occupies Refill.
        """
        slots = [box.get_center() + TOKEN_OFFSET for box in states]
        cleanup = STATE_NAMES.index("Cleanup")
        colors = dict(zip("ABCDEFG", BATCH_COLORS))

        # Boot tokens — the initial batches start spread over all states.
        start_state = {"A": "Prefill", "B": "Decode", "C": "Refill", "D": "Cleanup"}
        tokens = {}  # name -> mobject, only tokens currently on screen
        where = {}   # name -> index into STATE_NAMES / slots
        for name, state in start_state.items():
            where[name] = STATE_NAMES.index(state)
            tokens[name] = self._make_token(name, colors[name]).move_to(slots[where[name]])
            self.play(FadeIn(tokens[name], scale=0.7), run_time=0.25)
        self.wait(0.2)

        note = Text("4 batches distributed across 4 states", font_size=16, color=WHITE)
        note.next_to(states, DOWN, buff=0.55)
        self.play(Write(note))
        self.wait(1.0)
        self.play(FadeOut(note))

        for newcomer in "EFG":
            # Tick: every live token advances one state (Decode wraps to Cleanup).
            where = {name: (idx + 1) % len(slots) for name, idx in where.items()}
            self.play(*[
                tokens[name].animate.move_to(slots[idx]) for name, idx in where.items()
            ])
            self.wait(0.3)

            # The token that just reached Cleanup is done; drop it from the
            # bookkeeping so it is never animated again.
            finished = next(name for name, idx in where.items() if idx == cleanup)
            del where[finished]
            self.play(FadeOut(tokens.pop(finished), scale=0.6))

            # A fresh batch immediately takes over the freed slot.
            where[newcomer] = cleanup
            tokens[newcomer] = self._make_token(newcomer, colors[newcomer])
            tokens[newcomer].move_to(slots[cleanup])
            self.play(FadeIn(tokens[newcomer], scale=0.7))
            self.wait(0.25)

        flow_note = Text(
            "Pipeline never drains — constant throughput",
            font_size=15, color=GREEN,
        ).next_to(states, DOWN, buff=0.55)
        self.play(Write(flow_note))
        self.wait(1.5)
        self.play(FadeOut(flow_note))

        # Only tokens still on screen are left in `tokens` at this point.
        self.play(*[FadeOut(token) for token in tokens.values()])

    # ═══════════════════════════════════════════════════
    # 6. Position-Grouped Decode highlight
    # ═══════════════════════════════════════════════════
    def _show_position_grouped_decode(self, states):
        """Gather four tokens at the Decode state and highlight it."""
        decode = states[STATE_NAMES.index("Decode")]
        d_pos = decode.get_center()
        offsets = [
            RIGHT * 1.2 + UP * 0.45,
            RIGHT * 1.2,
            RIGHT * 2.5 + UP * 0.45,
            RIGHT * 2.5,
        ]
        d_tokens = []
        for i, (offset, color) in enumerate(zip(offsets, BATCH_COLORS)):
            token = self._make_token(f"T{i}", color).move_to(d_pos + offset)
            self.play(FadeIn(token, scale=0.6), run_time=0.2)
            d_tokens.append(token)

        ring = SurroundingRectangle(decode, color=YELLOW, buff=0.12, stroke_width=3)
        ring_txt = Text(
            "Position-Grouped Batching\nSame decode position → single matmul",
            font_size=14, color=YELLOW, line_spacing=0.6,
        ).next_to(decode, DOWN, buff=0.5)
        self.play(Create(ring), Write(ring_txt))
        self.wait(2.0)
        self.play(FadeOut(ring), FadeOut(ring_txt), *[FadeOut(t) for t in d_tokens])

    # ═══════════════════════════════════════════════════
    # 7. O(1) Bitmask Slot Allocation
    # ═══════════════════════════════════════════════════
    @staticmethod
    def _bit_row(set_bits) -> VGroup:
        """Return a row of ``N_BITS`` squares; indices in ``set_bits`` are filled."""
        row = VGroup(*[
            Square(
                side_length=BIT_SIDE, color=GRAY, stroke_width=1.2,
                fill_opacity=0.5 if i in set_bits else 0.0,
            )
            for i in range(N_BITS)
        ])
        return row.arrange(RIGHT, buff=0.06)

    def _show_bitmask_allocation(self, states):
        """Show ``occupied_mask`` turning into ``free_slots`` by inverting every bit.

        The panel is laid out to the right of the state column, level with
        it, with the row label above the bits. Stacked underneath the column
        the bit row would fall below the bottom of the default frame.
        """
        bitmask_title = Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE)
        bitmask_desc = Text(
            "free_slots = ~occupied_mask (one-clock op)", font_size=15, color=GRAY,
        )
        occupied_lbl = Text("occupied_mask", font_size=11, color=RED)
        bits_group = self._bit_row(OCCUPIED_BITS)

        panel = VGroup(bitmask_title, bitmask_desc, occupied_lbl, bits_group)
        panel.arrange(DOWN, buff=0.2)
        occupied_lbl.align_to(bits_group, LEFT)
        panel.to_edge(RIGHT, buff=0.5).set_y(states.get_y())

        self.play(Write(bitmask_title), Write(bitmask_desc))
        self.wait(1.5)
        self.play(Create(bits_group), Write(occupied_lbl))

        # Flip to ~occupied: same row geometry, complementary fill pattern.
        flipped = self._bit_row(set(range(N_BITS)) - OCCUPIED_BITS).move_to(bits_group)
        free_lbl = Text("free_slots", font_size=11, color=GREEN)
        free_lbl.move_to(occupied_lbl, aligned_edge=LEFT)

        self.play(Transform(bits_group, flipped), Transform(occupied_lbl, free_lbl))
        self.wait(1.2)
        self.play(FadeOut(bits_group), FadeOut(occupied_lbl),
                  FadeOut(bitmask_title), FadeOut(bitmask_desc))

    # ═══════════════════════════════════════════════════
    # 8. Throughput comparison with animated bars
    # ═══════════════════════════════════════════════════
    def _show_throughput(self, keep):
        """Clear the stage and grow one bar per entry of ``THROUGHPUT``.

        ``keep`` holds the mobjects (title and underline) that must survive
        the clean-up. Every bar is created at its final width — proportional
        to its speed-up, the best one filling the whole outline — because
        ``GrowFromEdge`` animates *towards* the mobject's current size.
        """
        # One pass over what is actually on screen: each mobject is faded
        # exactly once and nothing already removed is touched again.
        leftovers = [m for m in self.mobjects if all(m is not k for k in keep)]
        if leftovers:
            self.play(*[FadeOut(m) for m in leftovers])
        self.wait(0.3)

        bar_max_w = 4.5
        bar_h = 0.5
        best = max(speedup for _, speedup, _ in THROUGHPUT)

        # Arrange the outlines first so that every row shares one left edge.
        outlines = VGroup(*[
            Rectangle(width=bar_max_w, height=bar_h, color=color, stroke_width=1.2)
            for _, _, color in THROUGHPUT
        ]).arrange(DOWN, buff=0.7)

        labels, fills, values = VGroup(), VGroup(), VGroup()
        for outline, (label_text, speedup, color) in zip(outlines, THROUGHPUT):
            fill = Rectangle(
                width=bar_max_w * speedup / best, height=bar_h,
                color=color, fill_opacity=0.7, stroke_width=0,
            )
            fills.add(fill.move_to(outline).align_to(outline, LEFT))
            labels.add(
                Text(label_text, font_size=22, color=color).next_to(outline, LEFT, buff=0.3)
            )
            values.add(
                Text(f"{speedup:.1f}x", font_size=22, color=color)
                .next_to(outline, RIGHT, buff=0.3)
            )

        # Centre the whole chart as one unit; relative positions are kept.
        chart = VGroup(outlines, fills, labels, values).center()
        compare_title = Text("Throughput", font_size=30, color=BLUE)
        compare_title.next_to(chart, UP, buff=0.6)

        self.play(Write(compare_title))
        self.play(*[Create(o) for o in outlines], *[Write(lbl) for lbl in labels])

        # Same run time for every bar, so a longer bar visibly grows faster.
        for fill in fills:
            self.play(GrowFromEdge(fill, LEFT), rate_func=linear, run_time=0.6)
            self.wait(0.3)

        self.play(*[Write(v) for v in values])
        self.wait(2.5)

        self.play(*[FadeOut(m) for m in self.mobjects])