diff --git a/continuous_batching.py b/continuous_batching.py index 73d4319..7edd482 100644 --- a/continuous_batching.py +++ b/continuous_batching.py @@ -1,7 +1,13 @@ -"""AstrAI promo: Continuous Batching — state-machine driven batch rotation. +"""AstrAI promo: Continuous Batching — Static contrast → Queue → State-machine pipeline. -Shows a 4-state FSM (Cleanup → Refill → Prefill → Decode → Loop → Cleanup) -with coloured batch tokens flowing through states, entering & leaving continuously. +Sections: + 0. Title + 1. Problem: Static Batching — requests padded, serial, GPU idle + 2. Solution: Continuous Batching — pipeline + waiting queue + 3. State-machine rotation with queue feeding in (in/out flow) + 4. Position-Grouped Decode highlight + 5. O(1) Bitmask Slot Allocation + 6. Throughput comparison (animated bars) """ from manim import * @@ -21,15 +27,134 @@ class ContinuousBatching(Scene): # ═══════════════════════════════════════════════════ # 0. Title # ═══════════════════════════════════════════════════ - title = Text("Continuous Batching", font_size=48, color=BLUE) + title = Text("Batching", font_size=48, color=BLUE) self.play(Write(title)) - self.wait(0.4) + self.wait(0.5) self.play(title.animate.to_edge(UP).scale(0.55)) - bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15) - self.play(Create(bar)) + top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15) + self.play(Create(top_bar)) # ═══════════════════════════════════════════════════ - # 1. Build state-machine layout (vertical, 4 states) + # 1. 
Problem: Static Batching + # ═══════════════════════════════════════════════════ + subtitle = Text("Static Batching", font_size=32, color=RED).next_to(top_bar, DOWN, buff=0.5) + self.play(Write(subtitle)) + self.wait(0.4) + + # three requests with different lengths + req_w_base = 0.35 + req_h = 0.4 + req_gap = 0.6 + + reqs = [] + lengths = [8, 12, 5] # different token lengths + colors_req = [ORANGE, BLUE, PINK] + for i, (n, c) in enumerate(zip(lengths, colors_req)): + w = req_w_base * n + bar = Rectangle(width=w, height=req_h, color=c, fill_opacity=0.5, stroke_width=1.5) + lbl = Text(f"Req {i+1}", font_size=18, color=c) + grp = VGroup(bar, lbl) + lbl.next_to(bar, UP, buff=0.1) + reqs.append(grp) + + VGroup(*reqs).arrange(DOWN, buff=0.55).move_to(ORIGIN).shift(LEFT * 1.5 + DOWN * 0.2) + + for r in reqs: + self.play(FadeIn(r, scale=0.9), run_time=0.3) + self.wait(0.5) + + # label: different prompt lengths + diff_note = Text("Different prompt lengths", font_size=16, color=LIGHT_GRAY) \ + .next_to(reqs[2], DOWN, buff=0.4) + self.play(Write(diff_note)) + self.wait(0.8) + + # animate padding — stretch all bars to max length + max_w = req_w_base * max(lengths) + pad_lines = VGroup() + for r in reqs: + cur_w = r[0].get_width() + if cur_w < max_w: + pad = Rectangle(width=max_w - cur_w, height=req_h, + fill_opacity=0.25, stroke_width=0, color=GRAY) + pad.next_to(r[0], RIGHT, buff=0) + pad_lines.add(pad) + + self.play(*[Create(p) for p in pad_lines]) + pad_note = Text("Padded to longest → waste", + font_size=15, color=RED).next_to(diff_note, DOWN, buff=0.15) + self.play(Write(pad_note)) + self.wait(0.8) + + # clear notes + self.play(FadeOut(diff_note), FadeOut(pad_note), *[FadeOut(p) for p in pad_lines]) + + # move all reqs to centre and show serial processing with idle gaps + serial_slots = [ + ORIGIN + LEFT * 2.5 + DOWN * 1.0, + ORIGIN + DOWN * 1.0, + ORIGIN + RIGHT * 2.5 + DOWN * 1.0, + ] + batch_box = Rectangle(width=max_w + 0.4, height=req_h + 0.25, + color=RED, 
stroke_width=1.5, stroke_opacity=0.5) + + # play the three requests one after another, each in its own batch slot + + for i in range(3): + batch_box.move_to(serial_slots[i]) + # no instant move here: the animation below slides the request into the slot + self.play( + reqs[i].animate.move_to(batch_box.get_center()).set_opacity(1), + Create(batch_box), + run_time=0.5, + ) + # show GPU active label + if i == 0: + gpu_active = Text("GPU Busy", font_size=14, color=GREEN) \ + .next_to(batch_box, UP, buff=0.15) + self.play(Write(gpu_active)) + self.wait(0.25 if i < 2 else 0.1) + + # idle gap between batches + if i < 2: + gpu_idle = Text("Idle...", font_size=13, color=GRAY) \ + .next_to(batch_box, RIGHT, buff=0.4).shift(UP * 0.1) + self.play(Write(gpu_idle)) + self.wait(0.6) + self.play(FadeOut(gpu_idle)) + else: + self.play(FadeOut(gpu_active)) + + self.play(FadeOut(batch_box)) + reqs[i].set_opacity(0.35) + + # big X over static batching + cross = VGroup( + Line(LEFT * 2 + UP * 1.5, RIGHT * 2 + DOWN * 1.5, color=RED, stroke_width=5), + Line(LEFT * 2 + DOWN * 1.5, RIGHT * 2 + UP * 1.5, color=RED, stroke_width=5), + ) + reason = Text("Serial → GPU idle 60%+ of the time", font_size=16, color=RED) \ + .next_to(cross, DOWN, buff=0.3) + self.play(Create(cross), Write(reason)) + self.wait(1.5) + + # clear everything for transition + self.play( + FadeOut(cross), FadeOut(reason), FadeOut(subtitle), + *[FadeOut(r) for r in reqs], + ) + + # ═══════════════════════════════════════════════════ + # 2. Transition: Continuous Batching — the solution + # ═══════════════════════════════════════════════════ + # morph the generic "Batching" title into the section-specific one, in place + new_title = Text("Continuous Batching", font_size=48, color=BLUE).to_edge(UP) \ + .scale(0.55).move_to(title) + self.play(Transform(title, new_title)) + self.wait(0.5) + + # ═══════════════════════════════════════════════════ + # 3. 
Build state-machine + waiting queue # ═══════════════════════════════════════════════════ state_names = ["Cleanup", "Refill", "Prefill", "Decode"] @@ -44,7 +169,7 @@ class ContinuousBatching(Scene): states.add(VGroup(box, lbl)) states.arrange(DOWN, buff=0.3) - states.shift(LEFT * 3.8 + DOWN * 0.5) + states.shift(LEFT * 3.5 + DOWN * 0.4) for i in range(1, 4): a = Arrow( @@ -59,7 +184,7 @@ class ContinuousBatching(Scene): if i > 0: self.play(Create(trans_arrows[i - 1])) - # loop arrow — Decode returns to Cleanup (multiturn decoding) + # loop arrow loop = CurvedArrow( states[-1].get_right() + RIGHT * 0.2, states[0].get_right() + RIGHT * 0.2, @@ -67,47 +192,80 @@ class ContinuousBatching(Scene): ) loop_lbl = Text("per token", font_size=11, color=GRAY).next_to(loop, RIGHT, buff=0.08) self.play(Create(loop), Write(loop_lbl)) - self.wait(0.4) + self.wait(0.3) # ═══════════════════════════════════════════════════ - # 2. Boot tokens — initial batches placed at mid-cycle + # 4. Waiting Queue # ═══════════════════════════════════════════════════ - def make_token(name: str, col: str) -> VGroup: + queue_box = Rectangle(width=1.6, height=2.5, color=LIGHT_GRAY, + stroke_width=1.5, fill_opacity=0.06) + queue_label = Text("Queue", font_size=14, color=LIGHT_GRAY) + queue_label.next_to(queue_box, UP, buff=0.12) + queue_item = VGroup(queue_box, queue_label) + queue_item.next_to(states[1], RIGHT, buff=2.0) # Refill level + + self.play(Create(queue_box), Write(queue_label)) + + # pending requests in queue + def make_req_token(name, col): card = RoundedRectangle(width=0.65, height=0.38, corner_radius=0.08, color=col, fill_opacity=0.35, stroke_width=1.8) txt = Text(name, font_size=13, color=col) return VGroup(card, txt) + queue_entries = VGroup( + make_req_token("X", BATCH_COLORS[2]), + make_req_token("Y", BATCH_COLORS[1]), + make_req_token("Z", BATCH_COLORS[0]), + ) + queue_entries.arrange(DOWN, buff=0.25) + queue_entries.move_to(queue_box) + + for q in queue_entries: + 
self.play(FadeIn(q, scale=0.7), run_time=0.2) + + # arrow from queue into Refill + queue_arrow = Arrow( + queue_box.get_left() + LEFT * 0.05, + states[1].get_right() + RIGHT * 0.15, + color=ORANGE, buff=0.06, + max_tip_length_to_length_ratio=0.2, + ) + queue_arrow_lbl = Text("admit", font_size=11, color=ORANGE) \ + .next_to(queue_arrow, UP, buff=0.05) + self.play(Create(queue_arrow), Write(queue_arrow_lbl)) + self.wait(0.5) + + # ═══════════════════════════════════════════════════ + # 5. Token flow — pipeline + queue + # ═══════════════════════════════════════════════════ tokens = { - "A": make_token("A", BATCH_COLORS[0]), - "B": make_token("B", BATCH_COLORS[1]), - "C": make_token("C", BATCH_COLORS[2]), + "A": make_req_token("A", BATCH_COLORS[0]), + "B": make_req_token("B", BATCH_COLORS[1]), + "C": make_req_token("C", BATCH_COLORS[2]), } - # all three at consecutive stages, Prefill is the entry point - tokens["A"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill - tokens["B"].move_to(states[3]).shift(RIGHT * 1.5) # Decode - tokens["C"].move_to(states[0]).shift(RIGHT * 1.5) # Cleanup + tokens["A"].move_to(states[2]).shift(RIGHT * 1.2) # Prefill + tokens["B"].move_to(states[3]).shift(RIGHT * 1.2) # Decode + tokens["C"].move_to(states[0]).shift(RIGHT * 1.2) # Cleanup for t in tokens.values(): - self.play(FadeIn(t, scale=0.7), run_time=0.25) - self.wait(0.2) + self.play(FadeIn(t, scale=0.7), run_time=0.2) + self.wait(0.3) note = Text("Every request starts at Prefill", font_size=16, color=WHITE) \ - .next_to(states, DOWN, buff=0.55) + .next_to(states, DOWN, buff=0.6) self.play(Write(note)) - self.wait(1.0) + self.wait(0.8) self.play(FadeOut(note)) - # ═══════════════════════════════════════════════════ - # 3. 
Tick 1 — advance, C exits, new D enters at Prefill - # ═══════════════════════════════════════════════════ slots = [ - states[0].get_center() + RIGHT * 1.5, # Cleanup - states[1].get_center() + RIGHT * 1.5, # Refill - states[2].get_center() + RIGHT * 1.5, # Prefill - states[3].get_center() + RIGHT * 1.5, # Decode + states[0].get_center() + RIGHT * 1.2, # Cleanup + states[1].get_center() + RIGHT * 1.2, # Refill + states[2].get_center() + RIGHT * 1.2, # Prefill + states[3].get_center() + RIGHT * 1.2, # Decode ] + # ---- Tick 1: advance, C exits, X pulled from queue → Prefill ---- self.play( tokens["A"].animate.move_to(slots[3]), # Prefill → Decode tokens["B"].animate.move_to(slots[0]), # Decode → Cleanup @@ -115,17 +273,20 @@ class ContinuousBatching(Scene): ) self.wait(0.3) - # C (now at Refill) exits after completing the loop - # new D enters at Prefill + # C exits, X leaves queue → Prefill + q_first = queue_entries[0] self.play(FadeOut(tokens["C"], scale=0.6)) - tokens["D"] = make_token("D", BATCH_COLORS[3]) - tokens["D"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry - self.play(FadeIn(tokens["D"], scale=0.7)) - self.wait(0.25) + tokens["D"] = make_req_token("X", BATCH_COLORS[2]) + tokens["D"].move_to(states[2]).shift(RIGHT * 1.2) # Prefill + self.play( + FadeOut(q_first, scale=0.5), + FadeIn(tokens["D"], scale=0.7), + ) + queue_entries.remove(q_first) + queue_entries.arrange(DOWN, buff=0.25).move_to(queue_box) + self.wait(0.3) - # ═══════════════════════════════════════════════════ - # 4. 
Tick 2 — advance, B exits, new E enters at Prefill - # ═══════════════════════════════════════════════════ + # ---- Tick 2: advance, B exits, Y from queue → Prefill ---- self.play( tokens["D"].animate.move_to(slots[3]), # Prefill → Decode tokens["A"].animate.move_to(slots[0]), # Decode → Cleanup @@ -133,67 +294,74 @@ class ContinuousBatching(Scene): ) self.wait(0.3) + q_first = queue_entries[0] self.play(FadeOut(tokens["B"], scale=0.6)) - tokens["E"] = make_token("E", BATCH_COLORS[4]) - tokens["E"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry - self.play(FadeIn(tokens["E"], scale=0.7)) - self.wait(0.25) + tokens["E"] = make_req_token("Y", BATCH_COLORS[1]) + tokens["E"].move_to(states[2]).shift(RIGHT * 1.2) + self.play( + FadeOut(q_first, scale=0.5), + FadeIn(tokens["E"], scale=0.7), + ) + queue_entries.remove(q_first) + queue_entries.arrange(DOWN, buff=0.25).move_to(queue_box) + self.wait(0.3) - # ═══════════════════════════════════════════════════ - # 5. Tick 3 — advance, A exits, new F enters at Prefill - # ═══════════════════════════════════════════════════ + # ---- Tick 3: advance, A exits, Z from queue → Prefill ---- self.play( tokens["E"].animate.move_to(slots[3]), # Prefill → Decode tokens["D"].animate.move_to(slots[0]), # Decode → Cleanup tokens["A"].animate.move_to(slots[1]), # Cleanup → Refill ) - self.wait(0.25) + self.wait(0.3) + q_first = queue_entries[0] self.play(FadeOut(tokens["A"], scale=0.6)) - tokens["F"] = make_token("F", BATCH_COLORS[5]) - tokens["F"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry - self.play(FadeIn(tokens["F"], scale=0.7)) - self.wait(0.25) - - # ═══════════════════════════════════════════════════ - # 6. 
Tick 4 — advance, F exits, new G enters at Prefill - # ═══════════════════════════════════════════════════ + tokens["F"] = make_req_token("Z", BATCH_COLORS[0]) + tokens["F"].move_to(states[2]).shift(RIGHT * 1.2) self.play( - tokens["F"].animate.move_to(slots[3]), # Prefill → Decode - tokens["E"].animate.move_to(slots[0]), # Decode → Cleanup - tokens["D"].animate.move_to(slots[1]), # Cleanup → Refill + FadeOut(q_first, scale=0.5), + FadeIn(tokens["F"], scale=0.7), ) - self.wait(0.25) + queue_entries.remove(q_first) + # queue is now empty — show new request W arriving + self.wait(0.3) - self.play(FadeOut(tokens["D"], scale=0.6)) - tokens["G"] = make_token("G", BATCH_COLORS[6]) - tokens["G"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry - self.play(FadeIn(tokens["G"], scale=0.7)) - self.wait(0.35) + # a brand-new request W flies into the empty queue + w_token = make_req_token("W", BATCH_COLORS[4]) + w_token.next_to(queue_box, RIGHT, buff=0.8) + self.play(w_token.animate.move_to(queue_box), run_time=0.6) + queue_entries.add(w_token) + self.wait(0.2) - # drop note: constant throughput, all enter at Prefill - flow_note = Text("All requests enter at Prefill — pipeline never drains", - font_size=15, color=GREEN).next_to(states, DOWN, buff=0.55) + # note: continuous throughput + flow_note = Text( + "Pipeline never drains — new requests always entering", + font_size=15, color=GREEN, + ).next_to(states, DOWN, buff=0.6) self.play(Write(flow_note)) self.wait(1.5) self.play(FadeOut(flow_note)) - # clear tokens - self.play(*[FadeOut(t) for t in tokens.values()]) + # clean up (only tokens still on screen; exited ones would be re-added by play) + self.play( + *[FadeOut(t) for t in tokens.values() if t in self.mobjects], + FadeOut(queue_arrow), FadeOut(queue_arrow_lbl), + FadeOut(queue_box), FadeOut(queue_label), + *[FadeOut(q) for q in queue_entries], + ) # ═══════════════════════════════════════════════════ # 6. 
Position-Grouped Decode highlight # ═══════════════════════════════════════════════════ - # show multiple tokens grouped at Decode d_pos = states[3].get_center() d_tokens = [ - make_token("T" + str(i), BATCH_COLORS[i]) for i in range(4) + make_req_token("T" + str(i), BATCH_COLORS[i]) for i in range(4) ] positions = [ - d_pos + RIGHT * 1.2 + UP * 0.45, - d_pos + RIGHT * 1.2, - d_pos + RIGHT * 2.5 + UP * 0.45, - d_pos + RIGHT * 2.5, + d_pos + RIGHT * 1.0 + UP * 0.45, + d_pos + RIGHT * 1.0, + d_pos + RIGHT * 2.3 + UP * 0.45, + d_pos + RIGHT * 2.3, ] for i in range(4): d_tokens[i].move_to(positions[i]) @@ -219,7 +387,6 @@ class ContinuousBatching(Scene): self.play(Write(bitmask_title), Write(bitmask_desc)) self.wait(1.5) - # animate bitmask bits flipping bits_group = VGroup() bit_size = 0.18 for i in range(16): @@ -234,7 +401,6 @@ class ContinuousBatching(Scene): occupied_lbl = Text("occupied_mask", font_size=11, color=RED).next_to(bits_group, LEFT, buff=0.4) self.play(Create(bits_group), Write(occupied_lbl)) - # flip to ~occupied flipped = VGroup() for i, sq in enumerate(bits_group): copy_sq = Square(side_length=bit_size * 2, color=GRAY, @@ -255,7 +421,7 @@ class ContinuousBatching(Scene): # 8. 
Throughput comparison with animated bars # ═══════════════════════════════════════════════════ self.play( - *[FadeOut(m) for m in self.mobjects if m is not title and m is not bar], + *[FadeOut(m) for m in self.mobjects if m is not title and m is not top_bar], FadeOut(loop), FadeOut(loop_lbl), ) for s in states: @@ -265,42 +431,34 @@ class ContinuousBatching(Scene): self.wait(0.3) - # ---- title ---- compare_title = Text("Throughput Comparison", font_size=30, color=BLUE) self.play(Write(compare_title)) self.wait(0.2) self.play(compare_title.animate.to_edge(UP).scale(0.55)) self.wait(0.2) - # ---- bar config ---- bar_max_w = 5.0 bar_h = 0.55 row_gap = 0.8 - - # ratio: static = baseline (1.0), continuous = 3.4x ratio = 1.0 / 3.4 - # ---- Static Batching row ---- s_label = Text("Static Batching", font_size=24, color=RED) s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5) s_bar = Rectangle(width=bar_max_w * ratio, height=bar_h, color=RED, fill_opacity=0.55, stroke_width=0) s_num = Text("1.0x", font_size=24, color=RED) - # ---- Continuous Batching row ---- c_label = Text("Continuous Batching", font_size=24, color=GREEN) c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5) c_bar = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, fill_opacity=0.55, stroke_width=0) c_num = Text("3.4x", font_size=24, color=GREEN) - # position rects first, then align bars s_rect.move_to(ORIGIN + UP * (row_gap / 2 + bar_h / 2)) c_rect.move_to(ORIGIN + DOWN * (row_gap / 2 + bar_h / 2)) s_bar.align_to(s_rect, LEFT).align_to(s_rect, UP) c_bar.align_to(c_rect, LEFT).align_to(c_rect, UP) - # labels left, nums right s_label.next_to(s_rect, LEFT, buff=0.4) c_label.next_to(c_rect, LEFT, buff=0.4) s_num.next_to(s_rect, RIGHT, buff=0.4) @@ -312,13 +470,11 @@ class ContinuousBatching(Scene): ) self.wait(0.3) - # grow bars self.play(GrowFromEdge(s_bar, LEFT), rate_func=linear, run_time=0.6) self.wait(0.3) self.play(GrowFromEdge(c_bar, LEFT), 
rate_func=linear, run_time=0.6) self.wait(0.3) - # show values self.play(Write(s_num), Write(c_num)) self.wait(2.5)