redesign continuous batching: simplify to lane-based FSM with Prefill state
- Remove FSM 4-state cycle, tick animations, Refill/Prefill/Decode arrows - Show PENDING/ACTIVE/FINISHED lanes with Refill->Cleanup flow - Add FSM state row (Refill->Prefill->Decode->Cleanup) - ACTIVE lane shows single Prefill state label - architecture: drop '(1B params)' from title
This commit is contained in:
parent
c9f290c3c8
commit
bb0c32b032
|
|
@ -72,7 +72,7 @@ class Architecture(Scene):
|
|||
"with page-table-indirected access.",
|
||||
"Per-task page tables map logical pages",
|
||||
"to physical pages — O(1) alloc/free."],
|
||||
["Transformer Model (1B params)",
|
||||
["Transformer Model",
|
||||
"Decoder-only Transformer with",
|
||||
"Grouped-Query Attention (GQA 6:1).",
|
||||
"RoPE rotary encoding, SwiGLU",
|
||||
|
|
|
|||
|
|
@ -30,101 +30,74 @@ class ContinuousBatching(Scene):
|
|||
bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15)
|
||||
self.play(Create(bar))
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# 1. Queue-style pipeline (PENDING—queue / RUNNING—batch / FINISHED—done)
|
||||
# System phases: Refill, Prefill, Decode, Cleanup
|
||||
# ═══════════════════════════════════════════════════
|
||||
LANE_W, LANE_H = 3.8, 0.95
|
||||
X_P, X_R, X_F = -4.5, 0.0, 4.5 # lane x-centers
|
||||
P_CLR, R_CLR, F_CLR = GRAY, BLUE, RED
|
||||
X_P, X_A, X_F = -4.5, 0.0, 4.5
|
||||
YL = 0.3
|
||||
P_CLR, A_CLR, F_CLR = GRAY, BLUE, RED
|
||||
|
||||
def build_lane(x, label, clr, subtitle):
|
||||
lane = RoundedRectangle(width=LANE_W, height=LANE_H, corner_radius=0.12,
|
||||
def lane(x, label, clr, sub):
|
||||
box = RoundedRectangle(width=LANE_W, height=LANE_H, corner_radius=0.12,
|
||||
color=clr, fill_opacity=0.10, stroke_width=2.2)
|
||||
t = Text(label, font_size=20, color=clr)
|
||||
sub = Text(subtitle, font_size=10, color=LIGHT_GRAY)
|
||||
inner = VGroup(t, sub).arrange(DOWN, buff=0.04)
|
||||
inner.move_to(lane.get_center())
|
||||
grp = VGroup(lane, inner).move_to([x, 0.3, 0])
|
||||
return grp
|
||||
s = Text(sub, font_size=10, color=LIGHT_GRAY)
|
||||
inner = VGroup(t, s).arrange(DOWN, buff=0.04).move_to(box)
|
||||
return VGroup(box, inner).move_to([x, YL, 0])
|
||||
|
||||
pend_lane = build_lane(X_P, "PENDING", P_CLR, "waiting queue")
|
||||
run_lane = build_lane(X_R, "RUNNING", R_CLR, "active batch")
|
||||
fin_lane = build_lane(X_F, "FINISHED", F_CLR, "sequence done")
|
||||
req_group = VGroup(pend_lane, run_lane, fin_lane)
|
||||
fsm_states = VGroup()
|
||||
for label, clr in [("Refill", ORANGE), ("→", LIGHT_GRAY),
|
||||
("Prefill", BLUE), ("→", LIGHT_GRAY),
|
||||
("Decode", YELLOW), ("→", LIGHT_GRAY),
|
||||
("Cleanup", GRAY)]:
|
||||
t = Text(label, font_size=13, color=clr)
|
||||
fsm_states.add(t)
|
||||
fsm_states.arrange(RIGHT, buff=0.06)
|
||||
fsm_states.next_to(bar, DOWN, buff=0.3)
|
||||
|
||||
# ── transition arrows (system phase labels) ──
|
||||
ref_arrow = Arrow(
|
||||
pend_lane.get_right(), run_lane.get_left(),
|
||||
color=ORANGE, buff=0.06,
|
||||
max_tip_length_to_length_ratio=0.15,
|
||||
)
|
||||
ref_lbl = Text("Refill PENDING→RUNNING", font_size=10, color=ORANGE)
|
||||
ref_lbl.next_to(ref_arrow, UP, buff=0.04)
|
||||
pend_lane = lane(X_P, "PENDING", P_CLR, "waiting queue")
|
||||
act_lane = lane(X_A, "ACTIVE", A_CLR, "Prefill")
|
||||
fin_lane = lane(X_F, "FINISHED", F_CLR, "sequence done")
|
||||
lane_group = VGroup(pend_lane, act_lane, fin_lane)
|
||||
|
||||
cln_arrow = Arrow(
|
||||
run_lane.get_right(), fin_lane.get_left(),
|
||||
color=GRAY, buff=0.06,
|
||||
max_tip_length_to_length_ratio=0.15,
|
||||
)
|
||||
cln_lbl = Text("Cleanup RUNNING→FINISHED", font_size=10, color=GRAY)
|
||||
cln_lbl.next_to(cln_arrow, UP, buff=0.04)
|
||||
|
||||
dec_arrow = CurvedArrow(
|
||||
run_lane.get_top() + UP * 0.3 + LEFT * 0.35,
|
||||
run_lane.get_top() + UP * 0.3 + RIGHT * 0.35,
|
||||
color=YELLOW, angle=PI,
|
||||
)
|
||||
dec_lbl = Text("Decode per token", font_size=10, color=YELLOW)
|
||||
dec_lbl.next_to(dec_arrow, UP, buff=0.04)
|
||||
|
||||
pre_arrow = Arrow(
|
||||
run_lane.get_left() + LEFT * 0.5 + DOWN * 0.5,
|
||||
run_lane.get_left() + LEFT * 0.05,
|
||||
color=BLUE, stroke_width=1.5,
|
||||
max_tip_length_to_length_ratio=0.14,
|
||||
)
|
||||
pre_lbl = Text("Prefill once on first entry", font_size=10, color=BLUE)
|
||||
pre_lbl.next_to(pre_arrow, DOWN, buff=0.04)
|
||||
|
||||
entry_arrow = Arrow(
|
||||
pend_lane.get_left() + LEFT * 0.9,
|
||||
pend_lane.get_left(),
|
||||
ea = Arrow(pend_lane.get_left() + LEFT * 0.9, pend_lane.get_left(),
|
||||
color=GREEN, stroke_width=2.5,
|
||||
max_tip_length_to_length_ratio=0.15,
|
||||
)
|
||||
entry_lbl = Text("New Req", font_size=11, color=GREEN)
|
||||
entry_lbl.next_to(entry_arrow, UP, buff=0.04)
|
||||
max_tip_length_to_length_ratio=0.15)
|
||||
el = Text("New Req", font_size=11, color=GREEN)
|
||||
el.next_to(ea, UP, buff=0.04)
|
||||
|
||||
exit_arrow = Arrow(
|
||||
fin_lane.get_right(),
|
||||
fin_lane.get_right() + RIGHT * 0.9,
|
||||
ca = Arrow(act_lane.get_right(), fin_lane.get_left(),
|
||||
color=GRAY, buff=0.06,
|
||||
max_tip_length_to_length_ratio=0.15)
|
||||
cl = Text("Cleanup", font_size=10, color=GRAY)
|
||||
cl.next_to(ca, UP, buff=0.04)
|
||||
|
||||
ra = Arrow(pend_lane.get_right(), act_lane.get_left(),
|
||||
color=ORANGE, buff=0.06,
|
||||
max_tip_length_to_length_ratio=0.15)
|
||||
rl = Text("Refill", font_size=10, color=ORANGE)
|
||||
rl.next_to(ra, UP, buff=0.04)
|
||||
|
||||
ca = Arrow(act_lane.get_right(), fin_lane.get_left(),
|
||||
color=GRAY, buff=0.06,
|
||||
max_tip_length_to_length_ratio=0.15)
|
||||
cl = Text("Cleanup", font_size=10, color=GRAY)
|
||||
cl.next_to(ca, UP, buff=0.04)
|
||||
|
||||
xa = Arrow(fin_lane.get_right(), fin_lane.get_right() + RIGHT * 0.9,
|
||||
color=RED, stroke_width=2.5,
|
||||
max_tip_length_to_length_ratio=0.15,
|
||||
)
|
||||
exit_lbl = Text("Exit", font_size=11, color=RED)
|
||||
exit_lbl.next_to(exit_arrow, UP, buff=0.04)
|
||||
max_tip_length_to_length_ratio=0.15)
|
||||
xl = Text("Exit", font_size=11, color=RED)
|
||||
xl.next_to(xa, UP, buff=0.04)
|
||||
|
||||
# ── show lanes ──
|
||||
self.play(Create(pend_lane))
|
||||
self.wait(0.25)
|
||||
self.play(Create(run_lane))
|
||||
self.wait(0.25)
|
||||
self.play(Create(fin_lane))
|
||||
self.play(Write(fsm_states))
|
||||
self.play(Create(pend_lane), Create(act_lane), Create(fin_lane))
|
||||
self.wait(0.3)
|
||||
self.play(
|
||||
Create(entry_arrow), Write(entry_lbl),
|
||||
Create(ref_arrow), Write(ref_lbl),
|
||||
Create(dec_arrow), Write(dec_lbl),
|
||||
Create(pre_arrow), Write(pre_lbl),
|
||||
Create(cln_arrow), Write(cln_lbl),
|
||||
Create(exit_arrow), Write(exit_lbl),
|
||||
)
|
||||
self.wait(1.0)
|
||||
self.play(Create(ea), Write(el),
|
||||
Create(ra), Write(rl),
|
||||
Create(ca), Write(cl),
|
||||
Create(xa), Write(xl))
|
||||
self.wait(0.5)
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# 2. Token flow demo (dynamic state + sequence length)
|
||||
# ═══════════════════════════════════════════════════
|
||||
# ── Tokens ──
|
||||
TOK_W, TOK_H = 0.58, 0.38
|
||||
|
||||
def mk_tok(name, col, state, n_tok):
|
||||
|
|
@ -134,129 +107,31 @@ class ContinuousBatching(Scene):
|
|||
info = Text(f"{state} {n_tok}t", font_size=7, color=col)
|
||||
return VGroup(VGroup(card, t), info).arrange(DOWN, buff=0.03)
|
||||
|
||||
# slot positions inside each lane
|
||||
def slots(x, n):
|
||||
sp = LANE_W * 0.72 / max(n, 1)
|
||||
sx = x - (n - 1) * sp / 2
|
||||
return [np.array([sx + i * sp, -1.5, 0]) for i in range(n)]
|
||||
|
||||
P_SLOTS = slots(X_P, 3)
|
||||
R_SLOTS = slots(X_R, 3)
|
||||
A_SLOTS = slots(X_A, 3)
|
||||
F_SLOTS = slots(X_F, 2)
|
||||
|
||||
# initial tokens distributed across the 3 lanes
|
||||
tok = {}
|
||||
def add(name, col, lane_slots, idx, state, n):
|
||||
t = mk_tok(name, col, state, n).move_to(lane_slots[idx])
|
||||
tok[name] = t
|
||||
return t
|
||||
|
||||
add("G", BATCH_COLORS[6], P_SLOTS, 0, "PENDING", 0)
|
||||
add("F", BATCH_COLORS[5], P_SLOTS, 1, "PENDING", 0)
|
||||
add("E", BATCH_COLORS[4], P_SLOTS, 2, "PENDING", 0)
|
||||
|
||||
add("D", BATCH_COLORS[3], R_SLOTS, 0, "RUNNING", 5)
|
||||
add("A", BATCH_COLORS[0], R_SLOTS, 1, "RUNNING", 9)
|
||||
add("B", BATCH_COLORS[1], R_SLOTS, 2, "RUNNING", 13)
|
||||
|
||||
add("D", BATCH_COLORS[3], A_SLOTS, 0, "DECODE", 5)
|
||||
add("A", BATCH_COLORS[0], A_SLOTS, 1, "DECODE", 9)
|
||||
add("B", BATCH_COLORS[1], A_SLOTS, 2, "DECODE", 13)
|
||||
add("C", BATCH_COLORS[2], F_SLOTS, 0, "FINISHED", 16)
|
||||
|
||||
for t in tok.values():
|
||||
self.play(FadeIn(t, scale=0.7), run_time=0.18)
|
||||
self.wait(0.4)
|
||||
|
||||
note = Text(
|
||||
"Refill: PENDING→RUNNING · Prefill: process prompt (once) · Decode: count += 1 each step · Cleanup: RUNNING→FINISHED→exit",
|
||||
font_size=12, color=WHITE,
|
||||
)
|
||||
note.next_to(req_group, DOWN, buff=0.8)
|
||||
self.play(Write(note))
|
||||
self.wait(2.5)
|
||||
self.play(FadeOut(note))
|
||||
|
||||
# ── Tick 1: C exits, B→FINISHED, E→RUNNING (Refill+Prefill), all RUNNING count+1 ──
|
||||
# decode: D, A, B each +1
|
||||
d2 = mk_tok("D", BATCH_COLORS[3], "RUNNING", 6).move_to(R_SLOTS[0])
|
||||
a2 = mk_tok("A", BATCH_COLORS[0], "RUNNING", 10).move_to(R_SLOTS[1])
|
||||
b_fin = mk_tok("B", BATCH_COLORS[1], "FINISHED", 14).move_to(F_SLOTS[1])
|
||||
e_run = mk_tok("E", BATCH_COLORS[4], "RUNNING", 4).move_to(R_SLOTS[2])
|
||||
h_pen = mk_tok("H", BATCH_COLORS[7], "PENDING", 0).move_to(P_SLOTS[2])
|
||||
|
||||
self.play(
|
||||
ReplacementTransform(tok["D"], d2),
|
||||
ReplacementTransform(tok["A"], a2),
|
||||
ReplacementTransform(tok["B"], b_fin),
|
||||
ReplacementTransform(tok["E"], e_run),
|
||||
FadeOut(tok["C"], scale=0.5),
|
||||
FadeIn(h_pen, scale=0.7),
|
||||
)
|
||||
tok.update({"D": d2, "A": a2, "B": b_fin, "E": e_run, "H": h_pen})
|
||||
del tok["C"]
|
||||
self.wait(0.3)
|
||||
|
||||
# ── Tick 2: B exits, A→FINISHED, F→RUNNING, D count+1, E count+1 ──
|
||||
d3 = mk_tok("D", BATCH_COLORS[3], "RUNNING", 7).move_to(R_SLOTS[0])
|
||||
a_fin = mk_tok("A", BATCH_COLORS[0], "FINISHED", 11).move_to(F_SLOTS[0])
|
||||
e3 = mk_tok("E", BATCH_COLORS[4], "RUNNING", 5).move_to(R_SLOTS[1])
|
||||
f_run = mk_tok("F", BATCH_COLORS[5], "RUNNING", 4).move_to(R_SLOTS[2])
|
||||
i_pen = mk_tok("I", BATCH_COLORS[0], "PENDING", 0).move_to(P_SLOTS[2])
|
||||
|
||||
self.play(
|
||||
ReplacementTransform(tok["D"], d3),
|
||||
ReplacementTransform(tok["A"], a_fin),
|
||||
ReplacementTransform(tok["E"], e3),
|
||||
ReplacementTransform(tok["F"], f_run),
|
||||
FadeOut(tok["B"], scale=0.5),
|
||||
FadeIn(i_pen, scale=0.7),
|
||||
)
|
||||
tok.update({"D": d3, "A": a_fin, "E": e3, "F": f_run, "I": i_pen})
|
||||
del tok["B"]
|
||||
self.wait(0.3)
|
||||
|
||||
# ── Tick 3: A exits, D→FINISHED, G→RUNNING, E count+1, F count+1 ──
|
||||
d_fin = mk_tok("D", BATCH_COLORS[3], "FINISHED", 8).move_to(F_SLOTS[1])
|
||||
e4 = mk_tok("E", BATCH_COLORS[4], "RUNNING", 6).move_to(R_SLOTS[0])
|
||||
f3 = mk_tok("F", BATCH_COLORS[5], "RUNNING", 5).move_to(R_SLOTS[1])
|
||||
g_run = mk_tok("G", BATCH_COLORS[6], "RUNNING", 4).move_to(R_SLOTS[2])
|
||||
j_pen = mk_tok("J", BATCH_COLORS[1], "PENDING", 0).move_to(P_SLOTS[2])
|
||||
|
||||
self.play(
|
||||
ReplacementTransform(tok["D"], d_fin),
|
||||
ReplacementTransform(tok["E"], e4),
|
||||
ReplacementTransform(tok["F"], f3),
|
||||
ReplacementTransform(tok["G"], g_run),
|
||||
FadeOut(tok["A"], scale=0.5),
|
||||
FadeIn(j_pen, scale=0.7),
|
||||
)
|
||||
tok.update({"D": d_fin, "E": e4, "F": f3, "G": g_run, "J": j_pen})
|
||||
del tok["A"]
|
||||
self.wait(0.3)
|
||||
|
||||
# ── Tick 4: D exits, E→FINISHED, H→RUNNING, F count+1, G count+1 ──
|
||||
e_fin = mk_tok("E", BATCH_COLORS[4], "FINISHED", 7).move_to(F_SLOTS[0])
|
||||
f4 = mk_tok("F", BATCH_COLORS[5], "RUNNING", 6).move_to(R_SLOTS[0])
|
||||
g3 = mk_tok("G", BATCH_COLORS[6], "RUNNING", 5).move_to(R_SLOTS[1])
|
||||
h_run = mk_tok("H", BATCH_COLORS[7], "RUNNING", 4).move_to(R_SLOTS[2])
|
||||
k_pen = mk_tok("K", BATCH_COLORS[2], "PENDING", 0).move_to(P_SLOTS[2])
|
||||
|
||||
self.play(
|
||||
ReplacementTransform(tok["E"], e_fin),
|
||||
ReplacementTransform(tok["F"], f4),
|
||||
ReplacementTransform(tok["G"], g3),
|
||||
ReplacementTransform(tok["H"], h_run),
|
||||
FadeOut(tok["D"], scale=0.5),
|
||||
FadeIn(k_pen, scale=0.7),
|
||||
)
|
||||
tok.update({"E": e_fin, "F": f4, "G": g3, "H": h_run, "K": k_pen})
|
||||
del tok["D"]
|
||||
self.wait(0.4)
|
||||
|
||||
flow_note = Text("Pipeline never drains · all 3 states active at once",
|
||||
font_size=14, color=GREEN)
|
||||
flow_note.next_to(req_group, DOWN, buff=0.8)
|
||||
self.play(Write(flow_note))
|
||||
self.wait(1.5)
|
||||
self.play(FadeOut(flow_note))
|
||||
self.wait(2.0)
|
||||
|
||||
self.play(*[FadeOut(t) for t in tok.values()])
|
||||
|
||||
|
|
@ -264,11 +139,11 @@ class ContinuousBatching(Scene):
|
|||
# 7. Position-Grouped Decode highlight
|
||||
# ═══════════════════════════════════════════════════
|
||||
# show multiple tokens grouped at Decode
|
||||
ring = SurroundingRectangle(run_lane, color=YELLOW, buff=0.12, stroke_width=3)
|
||||
ring = SurroundingRectangle(act_lane, color=YELLOW, buff=0.12, stroke_width=3)
|
||||
ring_txt = Text(
|
||||
"Position-Grouped Batching\nSame decode position → single matmul",
|
||||
font_size=14, color=YELLOW, line_spacing=0.6,
|
||||
).next_to(run_lane, DOWN, buff=0.5)
|
||||
).next_to(act_lane, DOWN, buff=0.5)
|
||||
self.play(Create(ring), Write(ring_txt))
|
||||
self.wait(2.0)
|
||||
self.play(FadeOut(ring), FadeOut(ring_txt))
|
||||
|
|
@ -277,7 +152,7 @@ class ContinuousBatching(Scene):
|
|||
# 8. O(1) Bitmask Slot Allocation
|
||||
# ═══════════════════════════════════════════════════
|
||||
bitmask_title = Text("O(1) Slot Allocation via Bitmask",
|
||||
font_size=22, color=ORANGE).next_to(req_group, DOWN, buff=0.75)
|
||||
font_size=22, color=ORANGE).next_to(lane_group, DOWN, buff=0.75)
|
||||
bitmask_desc = Text("free_slots = ~occupied_mask (one-clock op)",
|
||||
font_size=15, color=GRAY).next_to(bitmask_title, DOWN, buff=0.15)
|
||||
self.play(Write(bitmask_title), Write(bitmask_desc))
|
||||
|
|
|
|||
Loading…
Reference in New Issue