diff --git a/continuous_batching.py b/continuous_batching.py index 73d4319..fda8ada 100644 --- a/continuous_batching.py +++ b/continuous_batching.py @@ -182,7 +182,7 @@ class ContinuousBatching(Scene): self.play(*[FadeOut(t) for t in tokens.values()]) # ═══════════════════════════════════════════════════ - # 6. Position-Grouped Decode highlight + # 7. Position-Grouped Decode highlight # ═══════════════════════════════════════════════════ # show multiple tokens grouped at Decode d_pos = states[3].get_center() @@ -210,7 +210,7 @@ class ContinuousBatching(Scene): *[FadeOut(t) for t in d_tokens]) # ═══════════════════════════════════════════════════ - # 7. O(1) Bitmask Slot Allocation + # 8. O(1) Bitmask Slot Allocation # ═══════════════════════════════════════════════════ bitmask_title = Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE).next_to(states, DOWN, buff=0.75) @@ -251,20 +251,214 @@ class ContinuousBatching(Scene): self.play(FadeOut(bits_group), FadeOut(occupied_lbl), FadeOut(bitmask_title), FadeOut(bitmask_desc)) + # ═══════════════════════════════════════════════════ - # 8. Throughput comparison with animated bars + # 9. Gantt timeline comparison — Static vs Continuous # ═══════════════════════════════════════════════════ self.play( *[FadeOut(m) for m in self.mobjects if m is not title and m is not bar], FadeOut(loop), FadeOut(loop_lbl), ) for s in states: - self.play(FadeOut(s), run_time=0.15) + self.play(FadeOut(s), run_time=0.10) for a in trans_arrows: - self.play(FadeOut(a), run_time=0.15) + self.play(FadeOut(a), run_time=0.10) + self.wait(0.2) + # ── layout constants ── + CELL = 0.44 # width per time tick + BH = 0.32 # bar height + BGAP = 0.10 # gap between rows + ROW = BH + BGAP # 0.42 — row pitch + TICKS = 12 # time columns + PANEL_W = TICKS * CELL # 5.28 + L_OX = -5.8 # left-panel origin x + R_OX = 1.0 # right-panel origin x + GY = 2.0 # gantt top y + + def gbox(ox, y, start, span, color, fill=0.75): + x = ox + start * CELL + w = span * CELL + return Rectangle( + width=w, height=BH, color=color, + fill_opacity=fill, stroke_width=0, + ).move_to([x + w / 2, y, 0]) + + def batch_box(ox, y_gpu, y_last_req, start, span, color, label_txt): + w = span * CELL + top = y_gpu + BH / 2 + 0.06 + bot = y_last_req - BH / 2 - 0.06 + h = top - bot + cx = ox + (start + span / 2) * CELL + cy = (top + bot) / 2 + rect = Rectangle( + width=w, height=h, color=color, + stroke_width=1.8, fill_opacity=0.04, + ) + rect.move_to([cx, cy, 0]) + lbl = Text(label_txt, font_size=12, color=color).next_to(rect, UP, buff=0.06) + return rect, lbl + + def taxis(ox, ty): + line = Line( + [ox, ty, 0], [ox + PANEL_W, ty, 0], + color=GRAY, stroke_width=1.2, + ) + ticks_vg = VGroup() + for t in range(TICKS + 1): + ti = Line(DOWN * 0.06, UP * 0.06, color=GRAY, stroke_width=0.8) + ti.move_to([ox + t * CELL, ty, 0]) + ticks_vg.add(ti) + nums_vg = VGroup() + for t in range(0, TICKS + 1, 3): + n = Text(str(t), font_size=11, color=GRAY).next_to( + [ox + t * CELL, ty, 0], DOWN, buff=0.10, + ) + nums_vg.add(n) + return VGroup(line, ticks_vg, nums_vg) + + # ── Left: Static Batching ── + s_title = Text("Static Batching", font_size=26, color=RED) + s_title.move_to([L_OX + PANEL_W / 2, GY + 0.65, 0]) + s_note = Text("requests wait → batch together → all run same length · GPU idle gaps", + font_size=13, color=RED) \ + .move_to([L_OX + PANEL_W / 2, -1.6, 0]) + self.play(Write(s_title)) + self.wait(0.25) + + st_axis = taxis(L_OX, GY) + self.play(Create(st_axis)) + + gpu_l = Text("GPU", font_size=14, color=WHITE) + gpu_l.move_to([L_OX - 0.55, GY - ROW, 0]) + self.play(Write(gpu_l)) + + # Static GPU: idle [0-2], batch 1 [2-6], batch 2 [6-10], idle [10-12] + s_y_gpu = GY - ROW + s_gpu_idle1 = gbox(L_OX, s_y_gpu, 0, 2, RED, 0.45) + s_gpu_batch1 = gbox(L_OX, s_y_gpu, 2, 4, GREEN) + s_gpu_batch2 = gbox(L_OX, s_y_gpu, 6, 4, GREEN) + s_gpu_idle2 = gbox(L_OX, s_y_gpu, 10, 2, RED, 0.45) + s_gpu_bars = [s_gpu_idle1, s_gpu_batch1, s_gpu_batch2, s_gpu_idle2] + for seg in s_gpu_bars: + self.play(GrowFromEdge(seg, LEFT), run_time=0.09) + + # IDLE labels over the red idle strips + s_idle1 = Text("IDLE", font_size=10, color=RED, weight=BOLD) \ + .move_to([L_OX + 1 * CELL, s_y_gpu, 0]) + s_idle2 = Text("IDLE", font_size=10, color=RED, weight=BOLD) \ + .move_to([L_OX + 11 * CELL, s_y_gpu, 0]) + self.play(Write(s_idle1), Write(s_idle2)) + + # Same 5 requests as continuous — but scheduled in batches + # each gets a gray WAIT bar before its coloured RUN bar + # (name, color, wait_start, wait_end, run_start, run_end) + s_req_defs = [ + ("A", ORANGE, 0, 2, 2, 6), # arrives t=0, waits for C → batch 1 + ("B", BLUE, 1, 2, 2, 6), # arrives t=1, waits for C + ("C", PINK, 2, 2, 2, 6), # arrives t=2, no wait (last to arrive) + ("D", ORANGE, 4, 6, 6, 10), # arrives t=4, waits for batch 1 to free GPU + ("E", BLUE, 6, 6, 6, 10), # arrives t=6, no wait (GPU just freed) + ] + s_bars = [] + for i, (name, col, ws, we, rs, re) in enumerate(s_req_defs): + y = s_y_gpu - (i + 1) * ROW + lbl = Text(f"Req {name}", font_size=12, color=col) + lbl.move_to([L_OX - 0.55, y, 0]) + items = [lbl] + anims = [FadeIn(lbl)] + if we - ws > 0.02: + wbar = gbox(L_OX, y, ws, we - ws, GRAY, 0.28) + items.append(wbar) + anims.append(GrowFromEdge(wbar, LEFT)) + rbar = gbox(L_OX, y, rs, re - rs, col, 0.60) + items.append(rbar) + anims.append(GrowFromEdge(rbar, LEFT)) + s_bars.extend(items) + self.play(*anims, run_time=0.09) + + # batch boxes — connect GPU busy segments to the requests they serve + s_y_last3 = s_y_gpu - 3 * ROW # Req C is the 3rd request row + s_y_last5 = s_y_gpu - 5 * ROW # Req E is the 5th request row + b1_rect, b1_lbl = batch_box(L_OX, s_y_gpu, s_y_last3, 2, 4, RED, "Batch 1") + b2_rect, b2_lbl = batch_box(L_OX, s_y_gpu, s_y_last5, 6, 4, RED, "Batch 2") + self.play(Create(b1_rect), Write(b1_lbl)) + self.play(Create(b2_rect), Write(b2_lbl)) + self.wait(0.8) + + # ── Right: Continuous Batching ── + c_title = Text("Continuous Batching", font_size=26, color=GREEN) + c_title.move_to([R_OX + PANEL_W / 2, GY + 0.65, 0]) + c_note = Text("no waiting · no padding · GPU never idle", + font_size=13, color=GREEN) \ + .move_to([R_OX + PANEL_W / 2, -1.6, 0]) + self.play(Write(c_title)) + self.wait(0.25) + + ct_axis = taxis(R_OX, GY) + self.play(Create(ct_axis)) + + c_y_gpu = GY - ROW + cgpu_l = Text("GPU", font_size=14, color=WHITE) + cgpu_l.move_to([R_OX - 0.55, c_y_gpu, 0]) + self.play(Write(cgpu_l)) + + # Continuous GPU: busy all 12 ticks (pipeline never drains) + c_gpu = gbox(R_OX, c_y_gpu, 0, 12, GREEN, 0.75) + self.play(GrowFromEdge(c_gpu, LEFT), run_time=0.5) + + # Same 5 requests — start immediately, no wait, staggered naturally + c_reqs = [ + ("A", ORANGE, 0, 4), + ("B", BLUE, 1, 4), + ("C", PINK, 2, 4), + ("D", ORANGE, 4, 4), + ("E", BLUE, 6, 4), + ] + c_bars = [] + c_n_reqs = len(c_reqs) + for i, (name, col, start, span) in enumerate(c_reqs): + y = c_y_gpu - (i + 1) * ROW + lbl = Text(f"Req {name}", font_size=12, color=col) + lbl.move_to([R_OX - 0.55, y, 0]) + bar_rect = gbox(R_OX, y, start, span, col, 0.60) + c_bars.extend([lbl, bar_rect]) + self.play(FadeIn(lbl), GrowFromEdge(bar_rect, LEFT), run_time=0.09) self.wait(0.3) + # continuous box — GPU always serving + c_y_last = c_y_gpu - c_n_reqs * ROW + c_box_rect, c_box_lbl = batch_box(R_OX, c_y_gpu, c_y_last, 0, 12, GREEN, "Always Serving") + self.play(Create(c_box_rect), Write(c_box_lbl)) + self.wait(1.0) + + # count annotation + s_count = Text("5 reqs · 2 batches · GPU idle gaps", + font_size=16, color=RED, weight=BOLD) \ + .next_to(s_gpu_batch1, DOWN, buff=1.0).align_to(s_gpu_batch1, LEFT) + c_count = Text("5 reqs · continuous · GPU never idle", + font_size=16, color=GREEN, weight=BOLD) \ + .next_to(c_gpu, DOWN, buff=1.0).align_to(c_gpu, LEFT) + self.play(Write(s_note), Write(c_note)) + self.wait(0.3) + self.play(Write(s_count), Write(c_count)) + self.wait(2.5) + self.play(FadeOut(s_count), FadeOut(c_count)) + + # ── Fade out gantt ── + gantt_mobs = [ + title, bar, s_title, s_note, c_title, c_note, + gpu_l, cgpu_l, s_idle1, s_idle2, st_axis, ct_axis, + *s_gpu_bars, c_gpu, *s_bars, *c_bars, + b1_rect, b1_lbl, b2_rect, b2_lbl, c_box_rect, c_box_lbl, + ] + self.play(*[FadeOut(m) for m in gantt_mobs]) + self.wait(0.2) + + # ═══════════════════════════════════════════════════ + # 10. Throughput comparison with animated bars + # ═══════════════════════════════════════════════════ + # ---- title ---- compare_title = Text("Throughput Comparison", font_size=30, color=BLUE) self.play(Write(compare_title)) @@ -277,28 +471,31 @@ class ContinuousBatching(Scene): bar_h = 0.55 row_gap = 0.8 - # ratio: static = baseline (1.0), continuous = 3.4x ratio = 1.0 / 3.4 # ---- Static Batching row ---- s_label = Text("Static Batching", font_size=24, color=RED) s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5) - s_bar = Rectangle(width=bar_max_w * ratio, height=bar_h, - color=RED, fill_opacity=0.55, stroke_width=0) + s_bar_rect = Rectangle( + width=bar_max_w * ratio, height=bar_h, + color=RED, fill_opacity=0.55, stroke_width=0, + ) s_num = Text("1.0x", font_size=24, color=RED) # ---- Continuous Batching row ---- c_label = Text("Continuous Batching", font_size=24, color=GREEN) c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5) - c_bar = Rectangle(width=bar_max_w, height=bar_h, - color=GREEN, fill_opacity=0.55, stroke_width=0) + c_bar_rect = Rectangle( + width=bar_max_w, height=bar_h, + color=GREEN, fill_opacity=0.55, stroke_width=0, + ) c_num = Text("3.4x", font_size=24, color=GREEN) # position rects first, then align bars s_rect.move_to(ORIGIN + UP * (row_gap / 2 + bar_h / 2)) c_rect.move_to(ORIGIN + DOWN * (row_gap / 2 + bar_h / 2)) - s_bar.align_to(s_rect, LEFT).align_to(s_rect, UP) - c_bar.align_to(c_rect, LEFT).align_to(c_rect, UP) + s_bar_rect.align_to(s_rect, LEFT).align_to(s_rect, UP) + c_bar_rect.align_to(c_rect, LEFT).align_to(c_rect, UP) # labels left, nums right s_label.next_to(s_rect, LEFT, buff=0.4) @@ -313,9 +510,9 @@ class ContinuousBatching(Scene): self.wait(0.3) # grow bars - self.play(GrowFromEdge(s_bar, LEFT), rate_func=linear, run_time=0.6) + self.play(GrowFromEdge(s_bar_rect, LEFT), rate_func=linear, run_time=0.6) self.wait(0.3) - self.play(GrowFromEdge(c_bar, LEFT), rate_func=linear, run_time=0.6) + self.play(GrowFromEdge(c_bar_rect, LEFT), rate_func=linear, run_time=0.6) self.wait(0.3) # show values