refactor continuous_batching: 4-lane pipeline (PENDING/PREFILL/DECODE/FINISHED), remove bogus Trans arrow, Refill=admission per AstrAI arch

This commit is contained in:
ViperEkura 2026-05-18 15:29:40 +08:00
parent 4f14d09fe3
commit 12d587aa92
1 changed files with 86 additions and 93 deletions

View File

@ -8,8 +8,7 @@ from manim import *
Text.set_default(font="Times New Roman") Text.set_default(font="Times New Roman")
# ── palette ── PAL = {
PHASE_COLORS = {
"Cleanup": GRAY, "Cleanup": GRAY,
"Refill": ORANGE, "Refill": ORANGE,
"Prefill": BLUE, "Prefill": BLUE,
@ -30,22 +29,26 @@ class ContinuousBatching(Scene):
bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15) bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15)
self.play(Create(bar)) self.play(Create(bar))
LANE_W, LANE_H = 3.8, 0.95 # ── layout config ──
X_P, X_A, X_F = -4.5, 0.0, 4.5 LANE_W, LANE_H = 2.6, 0.95
X_P, X_Pr, X_D, X_F = -4.95, -1.65, 1.65, 4.95
YL = 0.3 YL = 0.3
P_CLR, A_CLR, F_CLR = GRAY, BLUE, RED P_CLR, F_CLR = GRAY, RED
def lane(x, label, clr, sub): def lane(x, label, clr, sub):
box = RoundedRectangle(width=LANE_W, height=LANE_H, corner_radius=0.12, box = RoundedRectangle(
color=clr, fill_opacity=0.10, stroke_width=2.2) width=LANE_W, height=LANE_H, corner_radius=0.12,
color=clr, fill_opacity=0.10, stroke_width=2.2,
)
t = Text(label, font_size=20, color=clr) t = Text(label, font_size=20, color=clr)
s = Text(sub, font_size=10, color=LIGHT_GRAY) s = Text(sub, font_size=10, color=LIGHT_GRAY)
inner = VGroup(t, s).arrange(DOWN, buff=0.04).move_to(box) inner = VGroup(t, s).arrange(DOWN, buff=0.04).move_to(box)
return VGroup(box, inner).move_to([x, YL, 0]) return VGroup(box, inner).move_to([x, YL, 0])
# ── FSM label (unchanged) ──
fsm_states = VGroup() fsm_states = VGroup()
for label, clr in [("Refill", ORANGE), ("", LIGHT_GRAY), for label, clr in [("Refill", ORANGE), ("", LIGHT_GRAY),
("Prefill", BLUE), ("", LIGHT_GRAY), ("Prefill", BLUE), ("", LIGHT_GRAY),
("Decode", YELLOW), ("", LIGHT_GRAY), ("Decode", YELLOW), ("", LIGHT_GRAY),
("Cleanup", GRAY)]: ("Cleanup", GRAY)]:
t = Text(label, font_size=13, color=clr) t = Text(label, font_size=13, color=clr)
@ -53,30 +56,31 @@ class ContinuousBatching(Scene):
fsm_states.arrange(RIGHT, buff=0.06) fsm_states.arrange(RIGHT, buff=0.06)
fsm_states.next_to(bar, DOWN, buff=0.3) fsm_states.next_to(bar, DOWN, buff=0.3)
pend_lane = lane(X_P, "PENDING", P_CLR, "waiting queue") # ── 4-state pipeline lanes ──
act_lane = lane(X_A, "ACTIVE", A_CLR, "Prefill") pend_lane = lane(X_P, "PENDING", P_CLR, "waiting queue")
fin_lane = lane(X_F, "FINISHED", F_CLR, "sequence done") pref_lane = lane(X_Pr, "PREFILL", BLUE, "first token")
lane_group = VGroup(pend_lane, act_lane, fin_lane) dec_lane = lane(X_D, "DECODE", YELLOW, "per-token gen")
fin_lane = lane(X_F, "FINISHED", F_CLR, "sequence done")
lane_group = VGroup(pend_lane, pref_lane, dec_lane, fin_lane)
# ── arrows ──
ea = Arrow(pend_lane.get_left() + LEFT * 0.9, pend_lane.get_left(), ea = Arrow(pend_lane.get_left() + LEFT * 0.9, pend_lane.get_left(),
color=GREEN, stroke_width=2.5, color=GREEN, stroke_width=2.5,
max_tip_length_to_length_ratio=0.15) max_tip_length_to_length_ratio=0.15)
el = Text("New Req", font_size=11, color=GREEN) el = Text("New Req", font_size=11, color=GREEN)
el.next_to(ea, UP, buff=0.04) el.next_to(ea, UP, buff=0.04)
ca = Arrow(act_lane.get_right(), fin_lane.get_left(), ra = Arrow(pend_lane.get_right(), pref_lane.get_left(),
color=GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.15)
cl = Text("Cleanup", font_size=10, color=GRAY)
cl.next_to(ca, UP, buff=0.04)
ra = Arrow(pend_lane.get_right(), act_lane.get_left(),
color=ORANGE, buff=0.06, color=ORANGE, buff=0.06,
max_tip_length_to_length_ratio=0.15) max_tip_length_to_length_ratio=0.15)
rl = Text("Refill", font_size=10, color=ORANGE) rl = Text("Refill", font_size=10, color=ORANGE)
rl.next_to(ra, UP, buff=0.04) rl.next_to(ra, UP, buff=0.04)
ca = Arrow(act_lane.get_right(), fin_lane.get_left(), ta = Arrow(pref_lane.get_right(), dec_lane.get_left(),
color=LIGHT_GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.15)
ca = Arrow(dec_lane.get_right(), fin_lane.get_left(),
color=GRAY, buff=0.06, color=GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.15) max_tip_length_to_length_ratio=0.15)
cl = Text("Cleanup", font_size=10, color=GRAY) cl = Text("Cleanup", font_size=10, color=GRAY)
@ -89,10 +93,11 @@ class ContinuousBatching(Scene):
xl.next_to(xa, UP, buff=0.04) xl.next_to(xa, UP, buff=0.04)
self.play(Write(fsm_states)) self.play(Write(fsm_states))
self.play(Create(pend_lane), Create(act_lane), Create(fin_lane)) self.play(Create(pend_lane), Create(pref_lane), Create(dec_lane), Create(fin_lane))
self.wait(0.3) self.wait(0.3)
self.play(Create(ea), Write(el), self.play(Create(ea), Write(el),
Create(ra), Write(rl), Create(ra), Write(rl),
Create(ta),
Create(ca), Write(cl), Create(ca), Write(cl),
Create(xa), Write(xl)) Create(xa), Write(xl))
self.wait(0.5) self.wait(0.5)
@ -101,8 +106,10 @@ class ContinuousBatching(Scene):
TOK_W, TOK_H = 0.58, 0.38 TOK_W, TOK_H = 0.58, 0.38
def mk_tok(name, col, state, n_tok): def mk_tok(name, col, state, n_tok):
card = RoundedRectangle(width=TOK_W, height=TOK_H, corner_radius=0.06, card = RoundedRectangle(
color=col, fill_opacity=0.38, stroke_width=1.6) width=TOK_W, height=TOK_H, corner_radius=0.06,
color=col, fill_opacity=0.38, stroke_width=1.6,
)
t = Text(name, font_size=13, color=col).move_to(card) t = Text(name, font_size=13, color=col).move_to(card)
info = Text(f"{state} {n_tok}t", font_size=7, color=col) info = Text(f"{state} {n_tok}t", font_size=7, color=col)
return VGroup(VGroup(card, t), info).arrange(DOWN, buff=0.03) return VGroup(VGroup(card, t), info).arrange(DOWN, buff=0.03)
@ -112,22 +119,23 @@ class ContinuousBatching(Scene):
sx = x - (n - 1) * sp / 2 sx = x - (n - 1) * sp / 2
return [np.array([sx + i * sp, -1.5, 0]) for i in range(n)] return [np.array([sx + i * sp, -1.5, 0]) for i in range(n)]
P_SLOTS = slots(X_P, 3) P_SLOTS = slots(X_P, 2) # G, F
A_SLOTS = slots(X_A, 3) Pr_SLOTS = slots(X_Pr, 1) # E
F_SLOTS = slots(X_F, 2) D_SLOTS = slots(X_D, 3) # D, A, B
F_SLOTS = slots(X_F, 1) # C
tok = {} tok = {}
def add(name, col, lane_slots, idx, state, n): def add(name, col, lane_slots, idx, state, n):
t = mk_tok(name, col, state, n).move_to(lane_slots[idx]) t = mk_tok(name, col, state, n).move_to(lane_slots[idx])
tok[name] = t tok[name] = t
add("G", BATCH_COLORS[6], P_SLOTS, 0, "PENDING", 0) add("G", BATCH_COLORS[6], P_SLOTS, 0, "PENDING", 0)
add("F", BATCH_COLORS[5], P_SLOTS, 1, "PENDING", 0) add("F", BATCH_COLORS[5], P_SLOTS, 1, "PENDING", 0)
add("E", BATCH_COLORS[4], P_SLOTS, 2, "PENDING", 0) add("E", BATCH_COLORS[4], Pr_SLOTS, 0, "PREFILL", 128)
add("D", BATCH_COLORS[3], A_SLOTS, 0, "DECODE", 5) add("D", BATCH_COLORS[3], D_SLOTS, 0, "DECODE", 5)
add("A", BATCH_COLORS[0], A_SLOTS, 1, "DECODE", 9) add("A", BATCH_COLORS[0], D_SLOTS, 1, "DECODE", 9)
add("B", BATCH_COLORS[1], A_SLOTS, 2, "DECODE", 13) add("B", BATCH_COLORS[1], D_SLOTS, 2, "DECODE", 13)
add("C", BATCH_COLORS[2], F_SLOTS, 0, "FINISHED", 16) add("C", BATCH_COLORS[2], F_SLOTS, 0, "FINISHED", 16)
for t in tok.values(): for t in tok.values():
self.play(FadeIn(t, scale=0.7), run_time=0.18) self.play(FadeIn(t, scale=0.7), run_time=0.18)
@ -138,12 +146,11 @@ class ContinuousBatching(Scene):
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# 7. Position-Grouped Decode highlight # 7. Position-Grouped Decode highlight
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# show multiple tokens grouped at Decode ring = SurroundingRectangle(dec_lane, color=YELLOW, buff=0.12, stroke_width=3)
ring = SurroundingRectangle(act_lane, color=YELLOW, buff=0.12, stroke_width=3)
ring_txt = Text( ring_txt = Text(
"Position-Grouped Batching\nSame decode position → single matmul", "Position-Grouped Batching\nSame decode position → single matmul",
font_size=14, color=YELLOW, line_spacing=0.6, font_size=14, color=YELLOW, line_spacing=0.6,
).next_to(act_lane, DOWN, buff=0.5) ).next_to(dec_lane, DOWN, buff=0.5)
self.play(Create(ring), Write(ring_txt)) self.play(Create(ring), Write(ring_txt))
self.wait(2.0) self.wait(2.0)
self.play(FadeOut(ring), FadeOut(ring_txt)) self.play(FadeOut(ring), FadeOut(ring_txt))
@ -151,33 +158,40 @@ class ContinuousBatching(Scene):
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# 8. O(1) Bitmask Slot Allocation # 8. O(1) Bitmask Slot Allocation
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
bitmask_title = Text("O(1) Slot Allocation via Bitmask", bitmask_title = Text(
font_size=22, color=ORANGE).next_to(lane_group, DOWN, buff=0.75) "O(1) Slot Allocation via Bitmask",
bitmask_desc = Text("free_slots = ~occupied_mask (one-clock op)", font_size=22, color=ORANGE,
font_size=15, color=GRAY).next_to(bitmask_title, DOWN, buff=0.15) ).next_to(lane_group, DOWN, buff=0.75)
bitmask_desc = Text(
"free_slots = ~occupied_mask (one-clock op)",
font_size=15, color=GRAY,
).next_to(bitmask_title, DOWN, buff=0.15)
self.play(Write(bitmask_title), Write(bitmask_desc)) self.play(Write(bitmask_title), Write(bitmask_desc))
self.wait(1.5) self.wait(1.5)
# animate bitmask bits flipping
bits_group = VGroup() bits_group = VGroup()
bit_size = 0.18 bit_size = 0.18
for i in range(16): for i in range(16):
square = Square(side_length=bit_size * 2, color=GRAY, square = Square(
fill_opacity=0.0, stroke_width=1.2) side_length=bit_size * 2, color=GRAY,
fill_opacity=0.0, stroke_width=1.2,
)
if i in (2, 5, 9, 13): if i in (2, 5, 9, 13):
square.set_fill(GRAY, opacity=0.5) square.set_fill(GRAY, opacity=0.5)
bits_group.add(square) bits_group.add(square)
bits_group.arrange(RIGHT, buff=0.06) bits_group.arrange(RIGHT, buff=0.06)
bits_group.next_to(bitmask_desc, DOWN, buff=0.3) bits_group.next_to(bitmask_desc, DOWN, buff=0.3)
occupied_lbl = Text("occupied_mask", font_size=11, color=RED).next_to(bits_group, LEFT, buff=0.4) occupied_lbl = Text("occupied_mask", font_size=11, color=RED) \
.next_to(bits_group, LEFT, buff=0.4)
self.play(Create(bits_group), Write(occupied_lbl)) self.play(Create(bits_group), Write(occupied_lbl))
# flip to ~occupied
flipped = VGroup() flipped = VGroup()
for i, sq in enumerate(bits_group): for i, sq in enumerate(bits_group):
copy_sq = Square(side_length=bit_size * 2, color=GRAY, copy_sq = Square(
fill_opacity=0.0, stroke_width=1.2).move_to(sq) side_length=bit_size * 2, color=GRAY,
fill_opacity=0.0, stroke_width=1.2,
).move_to(sq)
if i not in (2, 5, 9, 13): if i not in (2, 5, 9, 13):
copy_sq.set_fill(GRAY, opacity=0.5) copy_sq.set_fill(GRAY, opacity=0.5)
flipped.add(copy_sq) flipped.add(copy_sq)
@ -190,7 +204,6 @@ class ContinuousBatching(Scene):
self.play(FadeOut(bits_group), FadeOut(occupied_lbl), self.play(FadeOut(bits_group), FadeOut(occupied_lbl),
FadeOut(bitmask_title), FadeOut(bitmask_desc)) FadeOut(bitmask_title), FadeOut(bitmask_desc))
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# 9. Gantt timeline comparison — Static vs Continuous # 9. Gantt timeline comparison — Static vs Continuous
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
@ -199,16 +212,15 @@ class ContinuousBatching(Scene):
) )
self.wait(0.2) self.wait(0.2)
# ── layout constants ── CELL = 0.44
CELL = 0.44 # width per time tick BH = 0.32
BH = 0.32 # bar height BGAP = 0.10
BGAP = 0.10 # gap between rows ROW = BH + BGAP
ROW = BH + BGAP # 0.42 — row pitch TICKS = 12
TICKS = 12 # time columns PANEL_W = TICKS * CELL
PANEL_W = TICKS * CELL # 5.28 L_OX = -5.8
L_OX = -5.8 # left-panel origin x R_OX = 1.0
R_OX = 1.0 # right-panel origin x GY = 2.0
GY = 2.0 # gantt top y
def gbox(ox, y, start, span, color, fill=0.75): def gbox(ox, y, start, span, color, fill=0.75):
x = ox + start * CELL x = ox + start * CELL
@ -254,9 +266,10 @@ class ContinuousBatching(Scene):
# ── Left: Static Batching ── # ── Left: Static Batching ──
s_title = Text("Static Batching", font_size=26, color=RED) s_title = Text("Static Batching", font_size=26, color=RED)
s_title.move_to([L_OX + PANEL_W / 2, GY + 0.65, 0]) s_title.move_to([L_OX + PANEL_W / 2, GY + 0.65, 0])
s_note = Text("requests wait → batch together → all run same length · GPU idle gaps", s_note = Text(
font_size=13, color=RED) \ "requests wait → batch together → all run same length · GPU idle gaps",
.move_to([L_OX + PANEL_W / 2, -1.6, 0]) font_size=13, color=RED,
).move_to([L_OX + PANEL_W / 2, -1.6, 0])
self.play(Write(s_title)) self.play(Write(s_title))
self.wait(0.25) self.wait(0.25)
@ -267,7 +280,6 @@ class ContinuousBatching(Scene):
gpu_l.move_to([L_OX - 0.55, GY - ROW, 0]) gpu_l.move_to([L_OX - 0.55, GY - ROW, 0])
self.play(Write(gpu_l)) self.play(Write(gpu_l))
# Static GPU: idle [0-2], batch 1 [2-6], idle [6-8], batch 2 [8-12]
s_y_gpu = GY - ROW s_y_gpu = GY - ROW
s_gpu_idle1 = gbox(L_OX, s_y_gpu, 0, 2, RED, 0.45) s_gpu_idle1 = gbox(L_OX, s_y_gpu, 0, 2, RED, 0.45)
s_gpu_batch1 = gbox(L_OX, s_y_gpu, 2, 4, GREEN) s_gpu_batch1 = gbox(L_OX, s_y_gpu, 2, 4, GREEN)
@ -277,23 +289,19 @@ class ContinuousBatching(Scene):
for seg in s_gpu_bars: for seg in s_gpu_bars:
self.play(GrowFromEdge(seg, LEFT), run_time=0.09) self.play(GrowFromEdge(seg, LEFT), run_time=0.09)
# IDLE labels over the red idle strips
s_idle1 = Text("IDLE", font_size=10, color=RED) \ s_idle1 = Text("IDLE", font_size=10, color=RED) \
.move_to([L_OX + 1 * CELL, s_y_gpu, 0]) .move_to([L_OX + 1 * CELL, s_y_gpu, 0])
s_idle2 = Text("IDLE", font_size=10, color=RED) \ s_idle2 = Text("IDLE", font_size=10, color=RED) \
.move_to([L_OX + 7 * CELL, s_y_gpu, 0]) .move_to([L_OX + 7 * CELL, s_y_gpu, 0])
self.play(Write(s_idle1), Write(s_idle2)) self.play(Write(s_idle1), Write(s_idle2))
# Same 6 requests as continuous — but scheduled in batches
# D, E, F grouped into one batch (gated by F's arrival at t=8)
# (name, color, wait_start, wait_end, run_start, run_end)
s_req_defs = [ s_req_defs = [
("A", ORANGE, 0, 2, 2, 6), # arrives t=0, waits for C → batch 1 ("A", ORANGE, 0, 2, 2, 6),
("B", BLUE, 1, 2, 2, 6), # arrives t=1, waits for C ("B", BLUE, 1, 2, 2, 6),
("C", PINK, 2, 2, 2, 6), # arrives t=2, no wait ("C", PINK, 2, 2, 2, 6),
("D", ORANGE, 4, 8, 8, 12), # arrives t=4, waits for F → batch 2 ("D", ORANGE, 4, 8, 8, 12),
("E", BLUE, 6, 8, 8, 12), # arrives t=6, waits for F ("E", BLUE, 6, 8, 8, 12),
("F", PINK, 8, 8, 8, 12), # arrives t=8, no wait ("F", PINK, 8, 8, 8, 12),
] ]
s_bars = [] s_bars = []
for i, (name, col, ws, we, rs, re) in enumerate(s_req_defs): for i, (name, col, ws, we, rs, re) in enumerate(s_req_defs):
@ -312,9 +320,8 @@ class ContinuousBatching(Scene):
s_bars.extend(items) s_bars.extend(items)
self.play(*anims, run_time=0.09) self.play(*anims, run_time=0.09)
# batch boxes — connect GPU busy segments to the requests they serve s_y_last3 = s_y_gpu - 3 * ROW
s_y_last3 = s_y_gpu - 3 * ROW # Req C is the 3rd request row s_y_last6 = s_y_gpu - 6 * ROW
s_y_last6 = s_y_gpu - 6 * ROW # Req F is the 6th request row
b1_rect, b1_lbl = batch_box(L_OX, s_y_gpu, s_y_last3, 2, 4, RED, "Batch 1") b1_rect, b1_lbl = batch_box(L_OX, s_y_gpu, s_y_last3, 2, 4, RED, "Batch 1")
b2_rect, b2_lbl = batch_box(L_OX, s_y_gpu, s_y_last6, 8, 4, RED, "Batch 2") b2_rect, b2_lbl = batch_box(L_OX, s_y_gpu, s_y_last6, 8, 4, RED, "Batch 2")
self.play(Create(b1_rect), Write(b1_lbl)) self.play(Create(b1_rect), Write(b1_lbl))
@ -324,9 +331,10 @@ class ContinuousBatching(Scene):
# ── Right: Continuous Batching ── # ── Right: Continuous Batching ──
c_title = Text("Continuous Batching", font_size=26, color=GREEN) c_title = Text("Continuous Batching", font_size=26, color=GREEN)
c_title.move_to([R_OX + PANEL_W / 2, GY + 0.65, 0]) c_title.move_to([R_OX + PANEL_W / 2, GY + 0.65, 0])
c_note = Text("no waiting · no padding · GPU never idle", c_note = Text(
font_size=13, color=GREEN) \ "no waiting · no padding · GPU never idle",
.move_to([R_OX + PANEL_W / 2, -1.6, 0]) font_size=13, color=GREEN,
).move_to([R_OX + PANEL_W / 2, -1.6, 0])
self.play(Write(c_title)) self.play(Write(c_title))
self.wait(0.25) self.wait(0.25)
@ -338,11 +346,9 @@ class ContinuousBatching(Scene):
cgpu_l.move_to([R_OX - 0.55, c_y_gpu, 0]) cgpu_l.move_to([R_OX - 0.55, c_y_gpu, 0])
self.play(Write(cgpu_l)) self.play(Write(cgpu_l))
# Continuous GPU: busy all 12 ticks (pipeline never drains)
c_gpu = gbox(R_OX, c_y_gpu, 0, 12, GREEN, 0.75) c_gpu = gbox(R_OX, c_y_gpu, 0, 12, GREEN, 0.75)
self.play(GrowFromEdge(c_gpu, LEFT), run_time=0.5) self.play(GrowFromEdge(c_gpu, LEFT), run_time=0.5)
# Same 6 requests — start immediately, no wait, staggered naturally
c_reqs = [ c_reqs = [
("A", ORANGE, 0, 4), ("A", ORANGE, 0, 4),
("B", BLUE, 1, 4), ("B", BLUE, 1, 4),
@ -362,13 +368,11 @@ class ContinuousBatching(Scene):
self.play(FadeIn(lbl), GrowFromEdge(bar_rect, LEFT), run_time=0.09) self.play(FadeIn(lbl), GrowFromEdge(bar_rect, LEFT), run_time=0.09)
self.wait(0.3) self.wait(0.3)
# continuous box — GPU always serving
c_y_last = c_y_gpu - c_n_reqs * ROW c_y_last = c_y_gpu - c_n_reqs * ROW
c_box_rect, c_box_lbl = batch_box(R_OX, c_y_gpu, c_y_last, 0, 12, GREEN, "Always Serving") c_box_rect, c_box_lbl = batch_box(R_OX, c_y_gpu, c_y_last, 0, 12, GREEN, "Always Serving")
self.play(Create(c_box_rect), Write(c_box_lbl)) self.play(Create(c_box_rect), Write(c_box_lbl))
self.wait(1.0) self.wait(1.0)
# count annotation
s_count = Text("6 reqs · 2 batches · GPU idle gaps", s_count = Text("6 reqs · 2 batches · GPU idle gaps",
font_size=16, color=RED) \ font_size=16, color=RED) \
.next_to(s_gpu_batch1, DOWN, buff=1.0).align_to(s_gpu_batch1, LEFT) .next_to(s_gpu_batch1, DOWN, buff=1.0).align_to(s_gpu_batch1, LEFT)
@ -381,7 +385,6 @@ class ContinuousBatching(Scene):
self.wait(2.5) self.wait(2.5)
self.play(FadeOut(s_count), FadeOut(c_count)) self.play(FadeOut(s_count), FadeOut(c_count))
# ── Fade out gantt ──
gantt_mobs = [ gantt_mobs = [
title, bar, s_title, s_note, c_title, c_note, title, bar, s_title, s_note, c_title, c_note,
gpu_l, cgpu_l, s_idle1, s_idle2, st_axis, ct_axis, gpu_l, cgpu_l, s_idle1, s_idle2, st_axis, ct_axis,
@ -394,22 +397,17 @@ class ContinuousBatching(Scene):
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# 10. Throughput comparison with animated bars # 10. Throughput comparison with animated bars
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# ---- title ----
compare_title = Text("Throughput Comparison", font_size=30, color=BLUE) compare_title = Text("Throughput Comparison", font_size=30, color=BLUE)
self.play(Write(compare_title)) self.play(Write(compare_title))
self.wait(0.2) self.wait(0.2)
self.play(compare_title.animate.to_edge(UP).scale(0.55)) self.play(compare_title.animate.to_edge(UP).scale(0.55))
self.wait(0.2) self.wait(0.2)
# ---- bar config ----
bar_max_w = 5.0 bar_max_w = 5.0
bar_h = 0.55 bar_h = 0.55
row_gap = 0.8 row_gap = 0.8
ratio = 1.0 / 3.4 ratio = 1.0 / 3.4
# ---- Static Batching row ----
s_label = Text("Static Batching", font_size=24, color=RED) s_label = Text("Static Batching", font_size=24, color=RED)
s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5) s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5)
s_bar_rect = Rectangle( s_bar_rect = Rectangle(
@ -418,7 +416,6 @@ class ContinuousBatching(Scene):
) )
s_num = Text("1.0x", font_size=24, color=RED) s_num = Text("1.0x", font_size=24, color=RED)
# ---- Continuous Batching row ----
c_label = Text("Continuous Batching", font_size=24, color=GREEN) c_label = Text("Continuous Batching", font_size=24, color=GREEN)
c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5) c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5)
c_bar_rect = Rectangle( c_bar_rect = Rectangle(
@ -427,13 +424,11 @@ class ContinuousBatching(Scene):
) )
c_num = Text("3.4x", font_size=24, color=GREEN) c_num = Text("3.4x", font_size=24, color=GREEN)
# position rects first, then align bars
s_rect.move_to(ORIGIN + UP * (row_gap / 2 + bar_h / 2)) s_rect.move_to(ORIGIN + UP * (row_gap / 2 + bar_h / 2))
c_rect.move_to(ORIGIN + DOWN * (row_gap / 2 + bar_h / 2)) c_rect.move_to(ORIGIN + DOWN * (row_gap / 2 + bar_h / 2))
s_bar_rect.align_to(s_rect, LEFT).align_to(s_rect, UP) s_bar_rect.align_to(s_rect, LEFT).align_to(s_rect, UP)
c_bar_rect.align_to(c_rect, LEFT).align_to(c_rect, UP) c_bar_rect.align_to(c_rect, LEFT).align_to(c_rect, UP)
# labels left, nums right
s_label.next_to(s_rect, LEFT, buff=0.4) s_label.next_to(s_rect, LEFT, buff=0.4)
c_label.next_to(c_rect, LEFT, buff=0.4) c_label.next_to(c_rect, LEFT, buff=0.4)
s_num.next_to(s_rect, RIGHT, buff=0.4) s_num.next_to(s_rect, RIGHT, buff=0.4)
@ -445,13 +440,11 @@ class ContinuousBatching(Scene):
) )
self.wait(0.3) self.wait(0.3)
# grow bars
self.play(GrowFromEdge(s_bar_rect, LEFT), rate_func=linear, run_time=0.6) self.play(GrowFromEdge(s_bar_rect, LEFT), rate_func=linear, run_time=0.6)
self.wait(0.3) self.wait(0.3)
self.play(GrowFromEdge(c_bar_rect, LEFT), rate_func=linear, run_time=0.6) self.play(GrowFromEdge(c_bar_rect, LEFT), rate_func=linear, run_time=0.6)
self.wait(0.3) self.wait(0.3)
# show values
self.play(Write(s_num), Write(c_num)) self.play(Write(s_num), Write(c_num))
self.wait(2.5) self.wait(2.5)