refactor continuous_batching: 4-lane pipeline (PENDING/PREFILL/DECODE/FINISHED), remove bogus Trans arrow, Refill=admission per AstrAI arch

This commit is contained in:
ViperEkura 2026-05-18 15:29:40 +08:00
parent 4f14d09fe3
commit 12d587aa92
1 changed files with 86 additions and 93 deletions

View File

@ -8,8 +8,7 @@ from manim import *
Text.set_default(font="Times New Roman")
# ── palette ──
PHASE_COLORS = {
PAL = {
"Cleanup": GRAY,
"Refill": ORANGE,
"Prefill": BLUE,
@ -30,19 +29,23 @@ class ContinuousBatching(Scene):
bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15)
self.play(Create(bar))
LANE_W, LANE_H = 3.8, 0.95
X_P, X_A, X_F = -4.5, 0.0, 4.5
# ── layout config ──
LANE_W, LANE_H = 2.6, 0.95
X_P, X_Pr, X_D, X_F = -4.95, -1.65, 1.65, 4.95
YL = 0.3
P_CLR, A_CLR, F_CLR = GRAY, BLUE, RED
P_CLR, F_CLR = GRAY, RED
# Build one pipeline-lane widget: a rounded box with a title and a smaller
# subtitle stacked inside it.
#   x     -- x coordinate the finished group is moved to
#   label -- title text (font size 20, coloured with clr)
#   clr   -- colour passed to both the box and the title text
#   sub   -- subtitle text (font size 10, LIGHT_GRAY)
# Returns VGroup(box, inner) positioned at [x, YL, 0]; LANE_W, LANE_H and YL
# are read from the enclosing scope.
def lane(x, label, clr, sub):
# NOTE(review): two `box = RoundedRectangle(...)` statements with the same
# arguments follow; this looks like the removed and added sides of a diff
# rendered without +/- markers — confirm against the real file.
box = RoundedRectangle(width=LANE_W, height=LANE_H, corner_radius=0.12,
color=clr, fill_opacity=0.10, stroke_width=2.2)
box = RoundedRectangle(
width=LANE_W, height=LANE_H, corner_radius=0.12,
color=clr, fill_opacity=0.10, stroke_width=2.2,
)
t = Text(label, font_size=20, color=clr)
s = Text(sub, font_size=10, color=LIGHT_GRAY)
# Stack title above subtitle, then centre the pair on the box.
inner = VGroup(t, s).arrange(DOWN, buff=0.04).move_to(box)
return VGroup(box, inner).move_to([x, YL, 0])
# ── FSM label (unchanged) ──
fsm_states = VGroup()
for label, clr in [("Refill", ORANGE), ("", LIGHT_GRAY),
("Prefill", BLUE), ("", LIGHT_GRAY),
@ -53,30 +56,31 @@ class ContinuousBatching(Scene):
fsm_states.arrange(RIGHT, buff=0.06)
fsm_states.next_to(bar, DOWN, buff=0.3)
# ── 4-state pipeline lanes ──
pend_lane = lane(X_P, "PENDING", P_CLR, "waiting queue")
act_lane = lane(X_A, "ACTIVE", A_CLR, "Prefill")
pref_lane = lane(X_Pr, "PREFILL", BLUE, "first token")
dec_lane = lane(X_D, "DECODE", YELLOW, "per-token gen")
fin_lane = lane(X_F, "FINISHED", F_CLR, "sequence done")
lane_group = VGroup(pend_lane, act_lane, fin_lane)
lane_group = VGroup(pend_lane, pref_lane, dec_lane, fin_lane)
# ── arrows ──
ea = Arrow(pend_lane.get_left() + LEFT * 0.9, pend_lane.get_left(),
color=GREEN, stroke_width=2.5,
max_tip_length_to_length_ratio=0.15)
el = Text("New Req", font_size=11, color=GREEN)
el.next_to(ea, UP, buff=0.04)
ca = Arrow(act_lane.get_right(), fin_lane.get_left(),
color=GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.15)
cl = Text("Cleanup", font_size=10, color=GRAY)
cl.next_to(ca, UP, buff=0.04)
ra = Arrow(pend_lane.get_right(), act_lane.get_left(),
ra = Arrow(pend_lane.get_right(), pref_lane.get_left(),
color=ORANGE, buff=0.06,
max_tip_length_to_length_ratio=0.15)
rl = Text("Refill", font_size=10, color=ORANGE)
rl.next_to(ra, UP, buff=0.04)
ca = Arrow(act_lane.get_right(), fin_lane.get_left(),
ta = Arrow(pref_lane.get_right(), dec_lane.get_left(),
color=LIGHT_GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.15)
ca = Arrow(dec_lane.get_right(), fin_lane.get_left(),
color=GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.15)
cl = Text("Cleanup", font_size=10, color=GRAY)
@ -89,10 +93,11 @@ class ContinuousBatching(Scene):
xl.next_to(xa, UP, buff=0.04)
self.play(Write(fsm_states))
self.play(Create(pend_lane), Create(act_lane), Create(fin_lane))
self.play(Create(pend_lane), Create(pref_lane), Create(dec_lane), Create(fin_lane))
self.wait(0.3)
self.play(Create(ea), Write(el),
Create(ra), Write(rl),
Create(ta),
Create(ca), Write(cl),
Create(xa), Write(xl))
self.wait(0.5)
@ -101,8 +106,10 @@ class ContinuousBatching(Scene):
TOK_W, TOK_H = 0.58, 0.38
# Build one request "token card": a small rounded card with the request name
# centred on it and a one-line status caption placed underneath.
#   name  -- text drawn on the card (font size 13)
#   col   -- colour passed to the card, the name and the caption
#   state -- state string shown in the caption
#   n_tok -- value shown after the state in the caption, followed by "t"
# Returns VGroup(VGroup(card, t), info) arranged top-to-bottom; TOK_W and
# TOK_H are read from the enclosing scope.
def mk_tok(name, col, state, n_tok):
# NOTE(review): two `card = RoundedRectangle(...)` statements with the same
# arguments follow; this looks like the removed and added sides of a diff
# rendered without +/- markers — confirm against the real file.
card = RoundedRectangle(width=TOK_W, height=TOK_H, corner_radius=0.06,
color=col, fill_opacity=0.38, stroke_width=1.6)
card = RoundedRectangle(
width=TOK_W, height=TOK_H, corner_radius=0.06,
color=col, fill_opacity=0.38, stroke_width=1.6,
)
t = Text(name, font_size=13, color=col).move_to(card)
# Caption text is "<state> <n_tok>t".
info = Text(f"{state} {n_tok}t", font_size=7, color=col)
return VGroup(VGroup(card, t), info).arrange(DOWN, buff=0.03)
@ -112,9 +119,10 @@ class ContinuousBatching(Scene):
sx = x - (n - 1) * sp / 2
return [np.array([sx + i * sp, -1.5, 0]) for i in range(n)]
P_SLOTS = slots(X_P, 3)
A_SLOTS = slots(X_A, 3)
F_SLOTS = slots(X_F, 2)
P_SLOTS = slots(X_P, 2) # G, F
Pr_SLOTS = slots(X_Pr, 1) # E
D_SLOTS = slots(X_D, 3) # D, A, B
F_SLOTS = slots(X_F, 1) # C
tok = {}
def add(name, col, lane_slots, idx, state, n):
@ -123,10 +131,10 @@ class ContinuousBatching(Scene):
add("G", BATCH_COLORS[6], P_SLOTS, 0, "PENDING", 0)
add("F", BATCH_COLORS[5], P_SLOTS, 1, "PENDING", 0)
add("E", BATCH_COLORS[4], P_SLOTS, 2, "PENDING", 0)
add("D", BATCH_COLORS[3], A_SLOTS, 0, "DECODE", 5)
add("A", BATCH_COLORS[0], A_SLOTS, 1, "DECODE", 9)
add("B", BATCH_COLORS[1], A_SLOTS, 2, "DECODE", 13)
add("E", BATCH_COLORS[4], Pr_SLOTS, 0, "PREFILL", 128)
add("D", BATCH_COLORS[3], D_SLOTS, 0, "DECODE", 5)
add("A", BATCH_COLORS[0], D_SLOTS, 1, "DECODE", 9)
add("B", BATCH_COLORS[1], D_SLOTS, 2, "DECODE", 13)
add("C", BATCH_COLORS[2], F_SLOTS, 0, "FINISHED", 16)
for t in tok.values():
@ -138,12 +146,11 @@ class ContinuousBatching(Scene):
# ═══════════════════════════════════════════════════
# 7. Position-Grouped Decode highlight
# ═══════════════════════════════════════════════════
# show multiple tokens grouped at Decode
ring = SurroundingRectangle(act_lane, color=YELLOW, buff=0.12, stroke_width=3)
ring = SurroundingRectangle(dec_lane, color=YELLOW, buff=0.12, stroke_width=3)
ring_txt = Text(
"Position-Grouped Batching\nSame decode position → single matmul",
font_size=14, color=YELLOW, line_spacing=0.6,
).next_to(act_lane, DOWN, buff=0.5)
).next_to(dec_lane, DOWN, buff=0.5)
self.play(Create(ring), Write(ring_txt))
self.wait(2.0)
self.play(FadeOut(ring), FadeOut(ring_txt))
@ -151,33 +158,40 @@ class ContinuousBatching(Scene):
# ═══════════════════════════════════════════════════
# 8. O(1) Bitmask Slot Allocation
# ═══════════════════════════════════════════════════
bitmask_title = Text("O(1) Slot Allocation via Bitmask",
font_size=22, color=ORANGE).next_to(lane_group, DOWN, buff=0.75)
bitmask_desc = Text("free_slots = ~occupied_mask (one-clock op)",
font_size=15, color=GRAY).next_to(bitmask_title, DOWN, buff=0.15)
bitmask_title = Text(
"O(1) Slot Allocation via Bitmask",
font_size=22, color=ORANGE,
).next_to(lane_group, DOWN, buff=0.75)
bitmask_desc = Text(
"free_slots = ~occupied_mask (one-clock op)",
font_size=15, color=GRAY,
).next_to(bitmask_title, DOWN, buff=0.15)
self.play(Write(bitmask_title), Write(bitmask_desc))
self.wait(1.5)
# animate bitmask bits flipping
bits_group = VGroup()
bit_size = 0.18
for i in range(16):
square = Square(side_length=bit_size * 2, color=GRAY,
fill_opacity=0.0, stroke_width=1.2)
square = Square(
side_length=bit_size * 2, color=GRAY,
fill_opacity=0.0, stroke_width=1.2,
)
if i in (2, 5, 9, 13):
square.set_fill(GRAY, opacity=0.5)
bits_group.add(square)
bits_group.arrange(RIGHT, buff=0.06)
bits_group.next_to(bitmask_desc, DOWN, buff=0.3)
occupied_lbl = Text("occupied_mask", font_size=11, color=RED).next_to(bits_group, LEFT, buff=0.4)
occupied_lbl = Text("occupied_mask", font_size=11, color=RED) \
.next_to(bits_group, LEFT, buff=0.4)
self.play(Create(bits_group), Write(occupied_lbl))
# flip to ~occupied
flipped = VGroup()
for i, sq in enumerate(bits_group):
copy_sq = Square(side_length=bit_size * 2, color=GRAY,
fill_opacity=0.0, stroke_width=1.2).move_to(sq)
copy_sq = Square(
side_length=bit_size * 2, color=GRAY,
fill_opacity=0.0, stroke_width=1.2,
).move_to(sq)
if i not in (2, 5, 9, 13):
copy_sq.set_fill(GRAY, opacity=0.5)
flipped.add(copy_sq)
@ -190,7 +204,6 @@ class ContinuousBatching(Scene):
self.play(FadeOut(bits_group), FadeOut(occupied_lbl),
FadeOut(bitmask_title), FadeOut(bitmask_desc))
# ═══════════════════════════════════════════════════
# 9. Gantt timeline comparison — Static vs Continuous
# ═══════════════════════════════════════════════════
@ -199,16 +212,15 @@ class ContinuousBatching(Scene):
)
self.wait(0.2)
# ── layout constants ──
CELL = 0.44 # width per time tick
BH = 0.32 # bar height
BGAP = 0.10 # gap between rows
ROW = BH + BGAP # 0.42 — row pitch
TICKS = 12 # time columns
PANEL_W = TICKS * CELL # 5.28
L_OX = -5.8 # left-panel origin x
R_OX = 1.0 # right-panel origin x
GY = 2.0 # gantt top y
CELL = 0.44
BH = 0.32
BGAP = 0.10
ROW = BH + BGAP
TICKS = 12
PANEL_W = TICKS * CELL
L_OX = -5.8
R_OX = 1.0
GY = 2.0
def gbox(ox, y, start, span, color, fill=0.75):
x = ox + start * CELL
@ -254,9 +266,10 @@ class ContinuousBatching(Scene):
# ── Left: Static Batching ──
s_title = Text("Static Batching", font_size=26, color=RED)
s_title.move_to([L_OX + PANEL_W / 2, GY + 0.65, 0])
s_note = Text("requests wait → batch together → all run same length · GPU idle gaps",
font_size=13, color=RED) \
.move_to([L_OX + PANEL_W / 2, -1.6, 0])
s_note = Text(
"requests wait → batch together → all run same length · GPU idle gaps",
font_size=13, color=RED,
).move_to([L_OX + PANEL_W / 2, -1.6, 0])
self.play(Write(s_title))
self.wait(0.25)
@ -267,7 +280,6 @@ class ContinuousBatching(Scene):
gpu_l.move_to([L_OX - 0.55, GY - ROW, 0])
self.play(Write(gpu_l))
# Static GPU: idle [0-2], batch 1 [2-6], idle [6-8], batch 2 [8-12]
s_y_gpu = GY - ROW
s_gpu_idle1 = gbox(L_OX, s_y_gpu, 0, 2, RED, 0.45)
s_gpu_batch1 = gbox(L_OX, s_y_gpu, 2, 4, GREEN)
@ -277,23 +289,19 @@ class ContinuousBatching(Scene):
for seg in s_gpu_bars:
self.play(GrowFromEdge(seg, LEFT), run_time=0.09)
# IDLE labels over the red idle strips
s_idle1 = Text("IDLE", font_size=10, color=RED) \
.move_to([L_OX + 1 * CELL, s_y_gpu, 0])
s_idle2 = Text("IDLE", font_size=10, color=RED) \
.move_to([L_OX + 7 * CELL, s_y_gpu, 0])
self.play(Write(s_idle1), Write(s_idle2))
# Same 6 requests as continuous — but scheduled in batches
# D, E, F grouped into one batch (gated by F's arrival at t=8)
# (name, color, wait_start, wait_end, run_start, run_end)
s_req_defs = [
("A", ORANGE, 0, 2, 2, 6), # arrives t=0, waits for C → batch 1
("B", BLUE, 1, 2, 2, 6), # arrives t=1, waits for C
("C", PINK, 2, 2, 2, 6), # arrives t=2, no wait
("D", ORANGE, 4, 8, 8, 12), # arrives t=4, waits for F → batch 2
("E", BLUE, 6, 8, 8, 12), # arrives t=6, waits for F
("F", PINK, 8, 8, 8, 12), # arrives t=8, no wait
("A", ORANGE, 0, 2, 2, 6),
("B", BLUE, 1, 2, 2, 6),
("C", PINK, 2, 2, 2, 6),
("D", ORANGE, 4, 8, 8, 12),
("E", BLUE, 6, 8, 8, 12),
("F", PINK, 8, 8, 8, 12),
]
s_bars = []
for i, (name, col, ws, we, rs, re) in enumerate(s_req_defs):
@ -312,9 +320,8 @@ class ContinuousBatching(Scene):
s_bars.extend(items)
self.play(*anims, run_time=0.09)
# batch boxes — connect GPU busy segments to the requests they serve
s_y_last3 = s_y_gpu - 3 * ROW # Req C is the 3rd request row
s_y_last6 = s_y_gpu - 6 * ROW # Req F is the 6th request row
s_y_last3 = s_y_gpu - 3 * ROW
s_y_last6 = s_y_gpu - 6 * ROW
b1_rect, b1_lbl = batch_box(L_OX, s_y_gpu, s_y_last3, 2, 4, RED, "Batch 1")
b2_rect, b2_lbl = batch_box(L_OX, s_y_gpu, s_y_last6, 8, 4, RED, "Batch 2")
self.play(Create(b1_rect), Write(b1_lbl))
@ -324,9 +331,10 @@ class ContinuousBatching(Scene):
# ── Right: Continuous Batching ──
c_title = Text("Continuous Batching", font_size=26, color=GREEN)
c_title.move_to([R_OX + PANEL_W / 2, GY + 0.65, 0])
c_note = Text("no waiting · no padding · GPU never idle",
font_size=13, color=GREEN) \
.move_to([R_OX + PANEL_W / 2, -1.6, 0])
c_note = Text(
"no waiting · no padding · GPU never idle",
font_size=13, color=GREEN,
).move_to([R_OX + PANEL_W / 2, -1.6, 0])
self.play(Write(c_title))
self.wait(0.25)
@ -338,11 +346,9 @@ class ContinuousBatching(Scene):
cgpu_l.move_to([R_OX - 0.55, c_y_gpu, 0])
self.play(Write(cgpu_l))
# Continuous GPU: busy all 12 ticks (pipeline never drains)
c_gpu = gbox(R_OX, c_y_gpu, 0, 12, GREEN, 0.75)
self.play(GrowFromEdge(c_gpu, LEFT), run_time=0.5)
# Same 6 requests — start immediately, no wait, staggered naturally
c_reqs = [
("A", ORANGE, 0, 4),
("B", BLUE, 1, 4),
@ -362,13 +368,11 @@ class ContinuousBatching(Scene):
self.play(FadeIn(lbl), GrowFromEdge(bar_rect, LEFT), run_time=0.09)
self.wait(0.3)
# continuous box — GPU always serving
c_y_last = c_y_gpu - c_n_reqs * ROW
c_box_rect, c_box_lbl = batch_box(R_OX, c_y_gpu, c_y_last, 0, 12, GREEN, "Always Serving")
self.play(Create(c_box_rect), Write(c_box_lbl))
self.wait(1.0)
# count annotation
s_count = Text("6 reqs · 2 batches · GPU idle gaps",
font_size=16, color=RED) \
.next_to(s_gpu_batch1, DOWN, buff=1.0).align_to(s_gpu_batch1, LEFT)
@ -381,7 +385,6 @@ class ContinuousBatching(Scene):
self.wait(2.5)
self.play(FadeOut(s_count), FadeOut(c_count))
# ── Fade out gantt ──
gantt_mobs = [
title, bar, s_title, s_note, c_title, c_note,
gpu_l, cgpu_l, s_idle1, s_idle2, st_axis, ct_axis,
@ -394,22 +397,17 @@ class ContinuousBatching(Scene):
# ═══════════════════════════════════════════════════
# 10. Throughput comparison with animated bars
# ═══════════════════════════════════════════════════
# ---- title ----
compare_title = Text("Throughput Comparison", font_size=30, color=BLUE)
self.play(Write(compare_title))
self.wait(0.2)
self.play(compare_title.animate.to_edge(UP).scale(0.55))
self.wait(0.2)
# ---- bar config ----
bar_max_w = 5.0
bar_h = 0.55
row_gap = 0.8
ratio = 1.0 / 3.4
# ---- Static Batching row ----
s_label = Text("Static Batching", font_size=24, color=RED)
s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5)
s_bar_rect = Rectangle(
@ -418,7 +416,6 @@ class ContinuousBatching(Scene):
)
s_num = Text("1.0x", font_size=24, color=RED)
# ---- Continuous Batching row ----
c_label = Text("Continuous Batching", font_size=24, color=GREEN)
c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5)
c_bar_rect = Rectangle(
@ -427,13 +424,11 @@ class ContinuousBatching(Scene):
)
c_num = Text("3.4x", font_size=24, color=GREEN)
# position rects first, then align bars
s_rect.move_to(ORIGIN + UP * (row_gap / 2 + bar_h / 2))
c_rect.move_to(ORIGIN + DOWN * (row_gap / 2 + bar_h / 2))
s_bar_rect.align_to(s_rect, LEFT).align_to(s_rect, UP)
c_bar_rect.align_to(c_rect, LEFT).align_to(c_rect, UP)
# labels left, nums right
s_label.next_to(s_rect, LEFT, buff=0.4)
c_label.next_to(c_rect, LEFT, buff=0.4)
s_num.next_to(s_rect, RIGHT, buff=0.4)
@ -445,13 +440,11 @@ class ContinuousBatching(Scene):
)
self.wait(0.3)
# grow bars
self.play(GrowFromEdge(s_bar_rect, LEFT), rate_func=linear, run_time=0.6)
self.wait(0.3)
self.play(GrowFromEdge(c_bar_rect, LEFT), rate_func=linear, run_time=0.6)
self.wait(0.3)
# show values
self.play(Write(s_num), Write(c_num))
self.wait(2.5)