diff --git a/continuous_batching.py b/continuous_batching.py index e050161..1f1ec8d 100644 --- a/continuous_batching.py +++ b/continuous_batching.py @@ -8,8 +8,7 @@ from manim import * Text.set_default(font="Times New Roman") -# ── palette ── -PHASE_COLORS = { +PAL = { "Cleanup": GRAY, "Refill": ORANGE, "Prefill": BLUE, @@ -30,22 +29,26 @@ class ContinuousBatching(Scene): bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15) self.play(Create(bar)) - LANE_W, LANE_H = 3.8, 0.95 - X_P, X_A, X_F = -4.5, 0.0, 4.5 + # ── layout config ── + LANE_W, LANE_H = 2.6, 0.95 + X_P, X_Pr, X_D, X_F = -4.95, -1.65, 1.65, 4.95 YL = 0.3 - P_CLR, A_CLR, F_CLR = GRAY, BLUE, RED + P_CLR, F_CLR = GRAY, RED def lane(x, label, clr, sub): - box = RoundedRectangle(width=LANE_W, height=LANE_H, corner_radius=0.12, - color=clr, fill_opacity=0.10, stroke_width=2.2) + box = RoundedRectangle( + width=LANE_W, height=LANE_H, corner_radius=0.12, + color=clr, fill_opacity=0.10, stroke_width=2.2, + ) t = Text(label, font_size=20, color=clr) s = Text(sub, font_size=10, color=LIGHT_GRAY) inner = VGroup(t, s).arrange(DOWN, buff=0.04).move_to(box) return VGroup(box, inner).move_to([x, YL, 0]) + # ── FSM label (unchanged) ── fsm_states = VGroup() for label, clr in [("Refill", ORANGE), ("→", LIGHT_GRAY), - ("Prefill", BLUE), ("→", LIGHT_GRAY), + ("Prefill", BLUE), ("→", LIGHT_GRAY), ("Decode", YELLOW), ("→", LIGHT_GRAY), ("Cleanup", GRAY)]: t = Text(label, font_size=13, color=clr) @@ -53,30 +56,31 @@ class ContinuousBatching(Scene): fsm_states.arrange(RIGHT, buff=0.06) fsm_states.next_to(bar, DOWN, buff=0.3) - pend_lane = lane(X_P, "PENDING", P_CLR, "waiting queue") - act_lane = lane(X_A, "ACTIVE", A_CLR, "Prefill") - fin_lane = lane(X_F, "FINISHED", F_CLR, "sequence done") - lane_group = VGroup(pend_lane, act_lane, fin_lane) + # ── 4-state pipeline lanes ── + pend_lane = lane(X_P, "PENDING", P_CLR, "waiting queue") + pref_lane = lane(X_Pr, "PREFILL", BLUE, "first token") + dec_lane = lane(X_D, "DECODE", YELLOW, "per-token gen") + fin_lane = lane(X_F, "FINISHED", F_CLR, "sequence done") + lane_group = VGroup(pend_lane, pref_lane, dec_lane, fin_lane) + # ── arrows ── ea = Arrow(pend_lane.get_left() + LEFT * 0.9, pend_lane.get_left(), color=GREEN, stroke_width=2.5, max_tip_length_to_length_ratio=0.15) el = Text("New Req", font_size=11, color=GREEN) el.next_to(ea, UP, buff=0.04) - ca = Arrow(act_lane.get_right(), fin_lane.get_left(), - color=GRAY, buff=0.06, - max_tip_length_to_length_ratio=0.15) - cl = Text("Cleanup", font_size=10, color=GRAY) - cl.next_to(ca, UP, buff=0.04) - - ra = Arrow(pend_lane.get_right(), act_lane.get_left(), + ra = Arrow(pend_lane.get_right(), pref_lane.get_left(), color=ORANGE, buff=0.06, max_tip_length_to_length_ratio=0.15) rl = Text("Refill", font_size=10, color=ORANGE) rl.next_to(ra, UP, buff=0.04) - ca = Arrow(act_lane.get_right(), fin_lane.get_left(), + ta = Arrow(pref_lane.get_right(), dec_lane.get_left(), + color=LIGHT_GRAY, buff=0.06, + max_tip_length_to_length_ratio=0.15) + + ca = Arrow(dec_lane.get_right(), fin_lane.get_left(), color=GRAY, buff=0.06, max_tip_length_to_length_ratio=0.15) cl = Text("Cleanup", font_size=10, color=GRAY) @@ -89,10 +93,11 @@ class ContinuousBatching(Scene): xl.next_to(xa, UP, buff=0.04) self.play(Write(fsm_states)) - self.play(Create(pend_lane), Create(act_lane), Create(fin_lane)) + self.play(Create(pend_lane), Create(pref_lane), Create(dec_lane), Create(fin_lane)) self.wait(0.3) self.play(Create(ea), Write(el), Create(ra), Write(rl), + Create(ta), Create(ca), Write(cl), Create(xa), Write(xl)) self.wait(0.5) @@ -101,8 +106,10 @@ class ContinuousBatching(Scene): TOK_W, TOK_H = 0.58, 0.38 def mk_tok(name, col, state, n_tok): - card = RoundedRectangle(width=TOK_W, height=TOK_H, corner_radius=0.06, - color=col, fill_opacity=0.38, stroke_width=1.6) + card = RoundedRectangle( + width=TOK_W, height=TOK_H, corner_radius=0.06, + color=col, fill_opacity=0.38, stroke_width=1.6, + ) t = Text(name, font_size=13, color=col).move_to(card) info = Text(f"{state} {n_tok}t", font_size=7, color=col) return VGroup(VGroup(card, t), info).arrange(DOWN, buff=0.03) @@ -112,22 +119,23 @@ class ContinuousBatching(Scene): sx = x - (n - 1) * sp / 2 return [np.array([sx + i * sp, -1.5, 0]) for i in range(n)] - P_SLOTS = slots(X_P, 3) - A_SLOTS = slots(X_A, 3) - F_SLOTS = slots(X_F, 2) + P_SLOTS = slots(X_P, 2) # G, F + Pr_SLOTS = slots(X_Pr, 1) # E + D_SLOTS = slots(X_D, 3) # D, A, B + F_SLOTS = slots(X_F, 1) # C tok = {} def add(name, col, lane_slots, idx, state, n): t = mk_tok(name, col, state, n).move_to(lane_slots[idx]) tok[name] = t - add("G", BATCH_COLORS[6], P_SLOTS, 0, "PENDING", 0) - add("F", BATCH_COLORS[5], P_SLOTS, 1, "PENDING", 0) - add("E", BATCH_COLORS[4], P_SLOTS, 2, "PENDING", 0) - add("D", BATCH_COLORS[3], A_SLOTS, 0, "DECODE", 5) - add("A", BATCH_COLORS[0], A_SLOTS, 1, "DECODE", 9) - add("B", BATCH_COLORS[1], A_SLOTS, 2, "DECODE", 13) - add("C", BATCH_COLORS[2], F_SLOTS, 0, "FINISHED", 16) + add("G", BATCH_COLORS[6], P_SLOTS, 0, "PENDING", 0) + add("F", BATCH_COLORS[5], P_SLOTS, 1, "PENDING", 0) + add("E", BATCH_COLORS[4], Pr_SLOTS, 0, "PREFILL", 128) + add("D", BATCH_COLORS[3], D_SLOTS, 0, "DECODE", 5) + add("A", BATCH_COLORS[0], D_SLOTS, 1, "DECODE", 9) + add("B", BATCH_COLORS[1], D_SLOTS, 2, "DECODE", 13) + add("C", BATCH_COLORS[2], F_SLOTS, 0, "FINISHED", 16) for t in tok.values(): self.play(FadeIn(t, scale=0.7), run_time=0.18) @@ -138,12 +146,11 @@ class ContinuousBatching(Scene): # ═══════════════════════════════════════════════════ # 7. Position-Grouped Decode highlight # ═══════════════════════════════════════════════════ - # show multiple tokens grouped at Decode - ring = SurroundingRectangle(act_lane, color=YELLOW, buff=0.12, stroke_width=3) + ring = SurroundingRectangle(dec_lane, color=YELLOW, buff=0.12, stroke_width=3) ring_txt = Text( "Position-Grouped Batching\nSame decode position → single matmul", font_size=14, color=YELLOW, line_spacing=0.6, - ).next_to(act_lane, DOWN, buff=0.5) + ).next_to(dec_lane, DOWN, buff=0.5) self.play(Create(ring), Write(ring_txt)) self.wait(2.0) self.play(FadeOut(ring), FadeOut(ring_txt)) @@ -151,33 +158,40 @@ class ContinuousBatching(Scene): # ═══════════════════════════════════════════════════ # 8. O(1) Bitmask Slot Allocation # ═══════════════════════════════════════════════════ - bitmask_title = Text("O(1) Slot Allocation via Bitmask", - font_size=22, color=ORANGE).next_to(lane_group, DOWN, buff=0.75) - bitmask_desc = Text("free_slots = ~occupied_mask (one-clock op)", - font_size=15, color=GRAY).next_to(bitmask_title, DOWN, buff=0.15) + bitmask_title = Text( + "O(1) Slot Allocation via Bitmask", + font_size=22, color=ORANGE, + ).next_to(lane_group, DOWN, buff=0.75) + bitmask_desc = Text( + "free_slots = ~occupied_mask (one-clock op)", + font_size=15, color=GRAY, + ).next_to(bitmask_title, DOWN, buff=0.15) self.play(Write(bitmask_title), Write(bitmask_desc)) self.wait(1.5) - # animate bitmask bits flipping bits_group = VGroup() bit_size = 0.18 for i in range(16): - square = Square(side_length=bit_size * 2, color=GRAY, - fill_opacity=0.0, stroke_width=1.2) + square = Square( + side_length=bit_size * 2, color=GRAY, + fill_opacity=0.0, stroke_width=1.2, + ) if i in (2, 5, 9, 13): square.set_fill(GRAY, opacity=0.5) bits_group.add(square) bits_group.arrange(RIGHT, buff=0.06) bits_group.next_to(bitmask_desc, DOWN, buff=0.3) - occupied_lbl = Text("occupied_mask", font_size=11, color=RED).next_to(bits_group, LEFT, buff=0.4) + occupied_lbl = Text("occupied_mask", font_size=11, color=RED) \ + .next_to(bits_group, LEFT, buff=0.4) self.play(Create(bits_group), Write(occupied_lbl)) - # flip to ~occupied flipped = VGroup() for i, sq in enumerate(bits_group): - copy_sq = Square(side_length=bit_size * 2, color=GRAY, - fill_opacity=0.0, stroke_width=1.2).move_to(sq) + copy_sq = Square( + side_length=bit_size * 2, color=GRAY, + fill_opacity=0.0, stroke_width=1.2, + ).move_to(sq) if i not in (2, 5, 9, 13): copy_sq.set_fill(GRAY, opacity=0.5) flipped.add(copy_sq) @@ -190,7 +204,6 @@ class ContinuousBatching(Scene): self.play(FadeOut(bits_group), FadeOut(occupied_lbl), FadeOut(bitmask_title), FadeOut(bitmask_desc)) - # ═══════════════════════════════════════════════════ # 9. Gantt timeline comparison — Static vs Continuous # ═══════════════════════════════════════════════════ @@ -199,16 +212,15 @@ class ContinuousBatching(Scene): ) self.wait(0.2) - # ── layout constants ── - CELL = 0.44 # width per time tick - BH = 0.32 # bar height - BGAP = 0.10 # gap between rows - ROW = BH + BGAP # 0.42 — row pitch - TICKS = 12 # time columns - PANEL_W = TICKS * CELL # 5.28 - L_OX = -5.8 # left-panel origin x - R_OX = 1.0 # right-panel origin x - GY = 2.0 # gantt top y + CELL = 0.44 + BH = 0.32 + BGAP = 0.10 + ROW = BH + BGAP + TICKS = 12 + PANEL_W = TICKS * CELL + L_OX = -5.8 + R_OX = 1.0 + GY = 2.0 def gbox(ox, y, start, span, color, fill=0.75): x = ox + start * CELL @@ -254,9 +266,10 @@ class ContinuousBatching(Scene): # ── Left: Static Batching ── s_title = Text("Static Batching", font_size=26, color=RED) s_title.move_to([L_OX + PANEL_W / 2, GY + 0.65, 0]) - s_note = Text("requests wait → batch together → all run same length · GPU idle gaps", - font_size=13, color=RED) \ - .move_to([L_OX + PANEL_W / 2, -1.6, 0]) + s_note = Text( + "requests wait → batch together → all run same length · GPU idle gaps", + font_size=13, color=RED, + ).move_to([L_OX + PANEL_W / 2, -1.6, 0]) self.play(Write(s_title)) self.wait(0.25) @@ -267,7 +280,6 @@ class ContinuousBatching(Scene): gpu_l.move_to([L_OX - 0.55, GY - ROW, 0]) self.play(Write(gpu_l)) - # Static GPU: idle [0-2], batch 1 [2-6], idle [6-8], batch 2 [8-12] s_y_gpu = GY - ROW s_gpu_idle1 = gbox(L_OX, s_y_gpu, 0, 2, RED, 0.45) s_gpu_batch1 = gbox(L_OX, s_y_gpu, 2, 4, GREEN) @@ -277,23 +289,19 @@ class ContinuousBatching(Scene): for seg in s_gpu_bars: self.play(GrowFromEdge(seg, LEFT), run_time=0.09) - # IDLE labels over the red idle strips s_idle1 = Text("IDLE", font_size=10, color=RED) \ .move_to([L_OX + 1 * CELL, s_y_gpu, 0]) s_idle2 = Text("IDLE", font_size=10, color=RED) \ .move_to([L_OX + 7 * CELL, s_y_gpu, 0]) self.play(Write(s_idle1), Write(s_idle2)) - # Same 6 requests as continuous — but scheduled in batches - # D, E, F grouped into one batch (gated by F's arrival at t=8) - # (name, color, wait_start, wait_end, run_start, run_end) s_req_defs = [ - ("A", ORANGE, 0, 2, 2, 6), # arrives t=0, waits for C → batch 1 - ("B", BLUE, 1, 2, 2, 6), # arrives t=1, waits for C - ("C", PINK, 2, 2, 2, 6), # arrives t=2, no wait - ("D", ORANGE, 4, 8, 8, 12), # arrives t=4, waits for F → batch 2 - ("E", BLUE, 6, 8, 8, 12), # arrives t=6, waits for F - ("F", PINK, 8, 8, 8, 12), # arrives t=8, no wait + ("A", ORANGE, 0, 2, 2, 6), + ("B", BLUE, 1, 2, 2, 6), + ("C", PINK, 2, 2, 2, 6), + ("D", ORANGE, 4, 8, 8, 12), + ("E", BLUE, 6, 8, 8, 12), + ("F", PINK, 8, 8, 8, 12), ] s_bars = [] for i, (name, col, ws, we, rs, re) in enumerate(s_req_defs): @@ -312,9 +320,8 @@ class ContinuousBatching(Scene): s_bars.extend(items) self.play(*anims, run_time=0.09) - # batch boxes — connect GPU busy segments to the requests they serve - s_y_last3 = s_y_gpu - 3 * ROW # Req C is the 3rd request row - s_y_last6 = s_y_gpu - 6 * ROW # Req F is the 6th request row + s_y_last3 = s_y_gpu - 3 * ROW + s_y_last6 = s_y_gpu - 6 * ROW b1_rect, b1_lbl = batch_box(L_OX, s_y_gpu, s_y_last3, 2, 4, RED, "Batch 1") b2_rect, b2_lbl = batch_box(L_OX, s_y_gpu, s_y_last6, 8, 4, RED, "Batch 2") self.play(Create(b1_rect), Write(b1_lbl)) @@ -324,9 +331,10 @@ class ContinuousBatching(Scene): # ── Right: Continuous Batching ── c_title = Text("Continuous Batching", font_size=26, color=GREEN) c_title.move_to([R_OX + PANEL_W / 2, GY + 0.65, 0]) - c_note = Text("no waiting · no padding · GPU never idle", - font_size=13, color=GREEN) \ - .move_to([R_OX + PANEL_W / 2, -1.6, 0]) + c_note = Text( + "no waiting · no padding · GPU never idle", + font_size=13, color=GREEN, + ).move_to([R_OX + PANEL_W / 2, -1.6, 0]) self.play(Write(c_title)) self.wait(0.25) @@ -338,11 +346,9 @@ class ContinuousBatching(Scene): cgpu_l.move_to([R_OX - 0.55, c_y_gpu, 0]) self.play(Write(cgpu_l)) - # Continuous GPU: busy all 12 ticks (pipeline never drains) c_gpu = gbox(R_OX, c_y_gpu, 0, 12, GREEN, 0.75) self.play(GrowFromEdge(c_gpu, LEFT), run_time=0.5) - # Same 6 requests — start immediately, no wait, staggered naturally c_reqs = [ ("A", ORANGE, 0, 4), ("B", BLUE, 1, 4), @@ -362,13 +368,11 @@ class ContinuousBatching(Scene): self.play(FadeIn(lbl), GrowFromEdge(bar_rect, LEFT), run_time=0.09) self.wait(0.3) - # continuous box — GPU always serving c_y_last = c_y_gpu - c_n_reqs * ROW c_box_rect, c_box_lbl = batch_box(R_OX, c_y_gpu, c_y_last, 0, 12, GREEN, "Always Serving") self.play(Create(c_box_rect), Write(c_box_lbl)) self.wait(1.0) - # count annotation s_count = Text("6 reqs · 2 batches · GPU idle gaps", font_size=16, color=RED) \ .next_to(s_gpu_batch1, DOWN, buff=1.0).align_to(s_gpu_batch1, LEFT) @@ -381,7 +385,6 @@ class ContinuousBatching(Scene): self.wait(2.5) self.play(FadeOut(s_count), FadeOut(c_count)) - # ── Fade out gantt ── gantt_mobs = [ title, bar, s_title, s_note, c_title, c_note, gpu_l, cgpu_l, s_idle1, s_idle2, st_axis, ct_axis, @@ -394,22 +397,17 @@ class ContinuousBatching(Scene): # ═══════════════════════════════════════════════════ # 10. Throughput comparison with animated bars # ═══════════════════════════════════════════════════ - - # ---- title ---- compare_title = Text("Throughput Comparison", font_size=30, color=BLUE) self.play(Write(compare_title)) self.wait(0.2) self.play(compare_title.animate.to_edge(UP).scale(0.55)) self.wait(0.2) - # ---- bar config ---- bar_max_w = 5.0 bar_h = 0.55 row_gap = 0.8 - ratio = 1.0 / 3.4 - # ---- Static Batching row ---- s_label = Text("Static Batching", font_size=24, color=RED) s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5) s_bar_rect = Rectangle( @@ -418,7 +416,6 @@ class ContinuousBatching(Scene): ) s_num = Text("1.0x", font_size=24, color=RED) - # ---- Continuous Batching row ---- c_label = Text("Continuous Batching", font_size=24, color=GREEN) c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5) c_bar_rect = Rectangle( @@ -427,13 +424,11 @@ class ContinuousBatching(Scene): ) c_num = Text("3.4x", font_size=24, color=GREEN) - # position rects first, then align bars s_rect.move_to(ORIGIN + UP * (row_gap / 2 + bar_h / 2)) c_rect.move_to(ORIGIN + DOWN * (row_gap / 2 + bar_h / 2)) s_bar_rect.align_to(s_rect, LEFT).align_to(s_rect, UP) c_bar_rect.align_to(c_rect, LEFT).align_to(c_rect, UP) - # labels left, nums right s_label.next_to(s_rect, LEFT, buff=0.4) c_label.next_to(c_rect, LEFT, buff=0.4) s_num.next_to(s_rect, RIGHT, buff=0.4) @@ -445,13 +440,11 @@ class ContinuousBatching(Scene): ) self.wait(0.3) - # grow bars self.play(GrowFromEdge(s_bar_rect, LEFT), rate_func=linear, run_time=0.6) self.wait(0.3) self.play(GrowFromEdge(c_bar_rect, LEFT), rate_func=linear, run_time=0.6) self.wait(0.3) - # show values self.play(Write(s_num), Write(c_num)) self.wait(2.5)