refactor: redesign batching FSM as queue pipeline with dynamic task states

- Replace 4 vertical system-phase boxes with 3 horizontal lanes
  (PENDING queue / RUNNING batch / FINISHED done) for accurate
  request lifecycle per scheduler.py:197-200
- System phases (Refill, Prefill, Decode, Cleanup) shown as
  transition labels between lanes
- Tokens placed below lanes with dynamic state badge + cumulative
  token count, updated each tick via ReplacementTransform
- Fix prefix_cache collective FadeOut using self.mobjects sweep
- Remove weight=BOLD across all scenes to prevent text drift
- Adjust GQA y-coordinates for subtitle clearance
This commit is contained in:
ViperEkura 2026-05-07 17:56:17 +08:00
parent c05a432e45
commit 6b0a1dbb5e
4 changed files with 218 additions and 146 deletions

View File

@ -19,7 +19,7 @@ class Architecture(Scene):
def make_box(header, color, bits, src):
b = Rectangle(width=W, height=BH, color=color, fill_opacity=0.1, stroke_width=1.5)
h = Text(header, font_size=16, color=color, weight=BOLD)
h = Text(header, font_size=16, color=color)
items = [h]
for line in bits:
items.append(Text(line, font_size=10, color=WHITE))
@ -80,7 +80,7 @@ class Architecture(Scene):
]
def make_desc(lines, color):
els = [Text(lines[0], font_size=20, color=color, weight=BOLD)]
els = [Text(lines[0], font_size=20, color=color)]
for ln in lines[1:]:
els.append(Text(ln, font_size=14, color=WHITE))
grp = VGroup(*els).arrange(DOWN, buff=0.1, aligned_edge=LEFT)

View File

@ -31,165 +31,242 @@ class ContinuousBatching(Scene):
self.play(Create(bar))
# ═══════════════════════════════════════════════════
# 1. Build state-machine layout (vertical, 4 states)
# 1. Queue-style pipeline (PENDING—queue / RUNNING—batch / FINISHED—done)
# System phases: Refill, Prefill, Decode, Cleanup
# ═══════════════════════════════════════════════════
state_names = ["Cleanup", "Refill", "Prefill", "Decode"]
LANE_W, LANE_H = 3.8, 0.95
X_P, X_R, X_F = -4.5, 0.0, 4.5 # lane x-centers
P_CLR, R_CLR, F_CLR = GRAY, BLUE, RED
states = VGroup()
trans_arrows = VGroup()
for i, name in enumerate(state_names):
box = RoundedRectangle(
width=3.6, height=0.8, corner_radius=0.15,
color=PHASE_COLORS[name], fill_opacity=0.12, stroke_width=2.5,
)
lbl = Text(name, font_size=20, color=PHASE_COLORS[name])
states.add(VGroup(box, lbl))
def build_lane(x, label, clr, subtitle):
lane = RoundedRectangle(width=LANE_W, height=LANE_H, corner_radius=0.12,
color=clr, fill_opacity=0.10, stroke_width=2.2)
t = Text(label, font_size=20, color=clr)
sub = Text(subtitle, font_size=10, color=LIGHT_GRAY)
inner = VGroup(t, sub).arrange(DOWN, buff=0.04)
inner.move_to(lane.get_center())
grp = VGroup(lane, inner).move_to([x, 0.3, 0])
return grp
states.arrange(DOWN, buff=0.3)
states.shift(LEFT * 3.8 + DOWN * 0.5)
pend_lane = build_lane(X_P, "PENDING", P_CLR, "waiting queue")
run_lane = build_lane(X_R, "RUNNING", R_CLR, "active batch")
fin_lane = build_lane(X_F, "FINISHED", F_CLR, "sequence done")
req_group = VGroup(pend_lane, run_lane, fin_lane)
for i in range(1, 4):
a = Arrow(
states[i - 1].get_bottom(), states[i].get_top(),
color=LIGHT_GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.22,
)
trans_arrows.add(a)
for i in range(4):
self.play(Create(states[i]))
if i > 0:
self.play(Create(trans_arrows[i - 1]))
# loop arrow — Decode returns to Cleanup (multiturn decoding)
loop = CurvedArrow(
states[-1].get_right() + RIGHT * 0.2,
states[0].get_right() + RIGHT * 0.2,
color=LIGHT_GRAY, angle=PI / 2,
# ── transition arrows (system phase labels) ──
ref_arrow = Arrow(
pend_lane.get_right(), run_lane.get_left(),
color=ORANGE, buff=0.06,
max_tip_length_to_length_ratio=0.15,
)
loop_lbl = Text("per token", font_size=11, color=GRAY).next_to(loop, RIGHT, buff=0.08)
self.play(Create(loop), Write(loop_lbl))
ref_lbl = Text("Refill PENDING→RUNNING", font_size=10, color=ORANGE)
ref_lbl.next_to(ref_arrow, UP, buff=0.04)
cln_arrow = Arrow(
run_lane.get_right(), fin_lane.get_left(),
color=GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.15,
)
cln_lbl = Text("Cleanup RUNNING→FINISHED", font_size=10, color=GRAY)
cln_lbl.next_to(cln_arrow, UP, buff=0.04)
dec_arrow = CurvedArrow(
run_lane.get_top() + UP * 0.3 + LEFT * 0.35,
run_lane.get_top() + UP * 0.3 + RIGHT * 0.35,
color=YELLOW, angle=PI,
)
dec_lbl = Text("Decode per token", font_size=10, color=YELLOW)
dec_lbl.next_to(dec_arrow, UP, buff=0.04)
pre_arrow = Arrow(
run_lane.get_left() + LEFT * 0.5 + DOWN * 0.5,
run_lane.get_left() + LEFT * 0.05,
color=BLUE, stroke_width=1.5,
max_tip_length_to_length_ratio=0.14,
)
pre_lbl = Text("Prefill once on first entry", font_size=10, color=BLUE)
pre_lbl.next_to(pre_arrow, DOWN, buff=0.04)
entry_arrow = Arrow(
pend_lane.get_left() + LEFT * 0.9,
pend_lane.get_left(),
color=GREEN, stroke_width=2.5,
max_tip_length_to_length_ratio=0.15,
)
entry_lbl = Text("New Req", font_size=11, color=GREEN)
entry_lbl.next_to(entry_arrow, UP, buff=0.04)
exit_arrow = Arrow(
fin_lane.get_right(),
fin_lane.get_right() + RIGHT * 0.9,
color=RED, stroke_width=2.5,
max_tip_length_to_length_ratio=0.15,
)
exit_lbl = Text("Exit", font_size=11, color=RED)
exit_lbl.next_to(exit_arrow, UP, buff=0.04)
# ── show lanes ──
self.play(Create(pend_lane))
self.wait(0.25)
self.play(Create(run_lane))
self.wait(0.25)
self.play(Create(fin_lane))
self.wait(0.3)
self.play(
Create(entry_arrow), Write(entry_lbl),
Create(ref_arrow), Write(ref_lbl),
Create(dec_arrow), Write(dec_lbl),
Create(pre_arrow), Write(pre_lbl),
Create(cln_arrow), Write(cln_lbl),
Create(exit_arrow), Write(exit_lbl),
)
self.wait(1.0)
# ═══════════════════════════════════════════════════
# 2. Token flow demo (dynamic state + sequence length)
# ═══════════════════════════════════════════════════
TOK_W, TOK_H = 0.58, 0.38
def mk_tok(name, col, state, n_tok):
card = RoundedRectangle(width=TOK_W, height=TOK_H, corner_radius=0.06,
color=col, fill_opacity=0.38, stroke_width=1.6)
t = Text(name, font_size=13, color=col).move_to(card)
info = Text(f"{state} {n_tok}t", font_size=7, color=col)
return VGroup(VGroup(card, t), info).arrange(DOWN, buff=0.03)
# slot positions inside each lane
def slots(x, n):
sp = LANE_W * 0.72 / max(n, 1)
sx = x - (n - 1) * sp / 2
return [np.array([sx + i * sp, -1.5, 0]) for i in range(n)]
P_SLOTS = slots(X_P, 3)
R_SLOTS = slots(X_R, 3)
F_SLOTS = slots(X_F, 2)
# initial tokens distributed across the 3 lanes
tok = {}
def add(name, col, lane_slots, idx, state, n):
t = mk_tok(name, col, state, n).move_to(lane_slots[idx])
tok[name] = t
return t
add("G", BATCH_COLORS[6], P_SLOTS, 0, "PENDING", 0)
add("F", BATCH_COLORS[5], P_SLOTS, 1, "PENDING", 0)
add("E", BATCH_COLORS[4], P_SLOTS, 2, "PENDING", 0)
add("D", BATCH_COLORS[3], R_SLOTS, 0, "RUNNING", 5)
add("A", BATCH_COLORS[0], R_SLOTS, 1, "RUNNING", 9)
add("B", BATCH_COLORS[1], R_SLOTS, 2, "RUNNING", 13)
add("C", BATCH_COLORS[2], F_SLOTS, 0, "FINISHED", 16)
for t in tok.values():
self.play(FadeIn(t, scale=0.7), run_time=0.18)
self.wait(0.4)
# ═══════════════════════════════════════════════════
# 2. Boot tokens — initial batches placed at mid-cycle
# ═══════════════════════════════════════════════════
def make_token(name: str, col: str) -> VGroup:
card = RoundedRectangle(width=0.65, height=0.38, corner_radius=0.08,
color=col, fill_opacity=0.35, stroke_width=1.8)
txt = Text(name, font_size=13, color=col)
return VGroup(card, txt)
tokens = {
"A": make_token("A", BATCH_COLORS[0]),
"B": make_token("B", BATCH_COLORS[1]),
"C": make_token("C", BATCH_COLORS[2]),
}
# all three at consecutive stages, Prefill is the entry point
tokens["A"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill
tokens["B"].move_to(states[3]).shift(RIGHT * 1.5) # Decode
tokens["C"].move_to(states[0]).shift(RIGHT * 1.5) # Cleanup
for t in tokens.values():
self.play(FadeIn(t, scale=0.7), run_time=0.25)
self.wait(0.2)
note = Text("Every request starts at Prefill", font_size=16, color=WHITE) \
.next_to(states, DOWN, buff=0.55)
note = Text(
"Refill: PENDING→RUNNING · Prefill: process prompt (once) · Decode: count += 1 each step · Cleanup: RUNNING→FINISHED→exit",
font_size=12, color=WHITE,
)
note.next_to(req_group, DOWN, buff=0.8)
self.play(Write(note))
self.wait(1.0)
self.wait(2.5)
self.play(FadeOut(note))
# ═══════════════════════════════════════════════════
# 3. Tick 1 — advance, C exits, new D enters at Prefill
# ═══════════════════════════════════════════════════
slots = [
states[0].get_center() + RIGHT * 1.5, # Cleanup
states[1].get_center() + RIGHT * 1.5, # Refill
states[2].get_center() + RIGHT * 1.5, # Prefill
states[3].get_center() + RIGHT * 1.5, # Decode
]
# ── Tick 1: C exits, B→FINISHED, E→RUNNING (Refill+Prefill), all RUNNING count+1 ──
# decode: D, A, B each +1
d2 = mk_tok("D", BATCH_COLORS[3], "RUNNING", 6).move_to(R_SLOTS[0])
a2 = mk_tok("A", BATCH_COLORS[0], "RUNNING", 10).move_to(R_SLOTS[1])
b_fin = mk_tok("B", BATCH_COLORS[1], "FINISHED", 14).move_to(F_SLOTS[1])
e_run = mk_tok("E", BATCH_COLORS[4], "RUNNING", 4).move_to(R_SLOTS[2])
h_pen = mk_tok("H", BATCH_COLORS[7], "PENDING", 0).move_to(P_SLOTS[2])
self.play(
tokens["A"].animate.move_to(slots[3]), # Prefill → Decode
tokens["B"].animate.move_to(slots[0]), # Decode → Cleanup
tokens["C"].animate.move_to(slots[1]), # Cleanup → Refill
ReplacementTransform(tok["D"], d2),
ReplacementTransform(tok["A"], a2),
ReplacementTransform(tok["B"], b_fin),
ReplacementTransform(tok["E"], e_run),
FadeOut(tok["C"], scale=0.5),
FadeIn(h_pen, scale=0.7),
)
tok.update({"D": d2, "A": a2, "B": b_fin, "E": e_run, "H": h_pen})
del tok["C"]
self.wait(0.3)
# C (now at Refill) exits after completing the loop
# new D enters at Prefill
self.play(FadeOut(tokens["C"], scale=0.6))
tokens["D"] = make_token("D", BATCH_COLORS[3])
tokens["D"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry
self.play(FadeIn(tokens["D"], scale=0.7))
self.wait(0.25)
# ── Tick 2: B exits, A→FINISHED, F→RUNNING, D count+1, E count+1 ──
d3 = mk_tok("D", BATCH_COLORS[3], "RUNNING", 7).move_to(R_SLOTS[0])
a_fin = mk_tok("A", BATCH_COLORS[0], "FINISHED", 11).move_to(F_SLOTS[0])
e3 = mk_tok("E", BATCH_COLORS[4], "RUNNING", 5).move_to(R_SLOTS[1])
f_run = mk_tok("F", BATCH_COLORS[5], "RUNNING", 4).move_to(R_SLOTS[2])
i_pen = mk_tok("I", BATCH_COLORS[0], "PENDING", 0).move_to(P_SLOTS[2])
# ═══════════════════════════════════════════════════
# 4. Tick 2 — advance, B exits, new E enters at Prefill
# ═══════════════════════════════════════════════════
self.play(
tokens["D"].animate.move_to(slots[3]), # Prefill → Decode
tokens["A"].animate.move_to(slots[0]), # Decode → Cleanup
tokens["B"].animate.move_to(slots[1]), # Cleanup → Refill
ReplacementTransform(tok["D"], d3),
ReplacementTransform(tok["A"], a_fin),
ReplacementTransform(tok["E"], e3),
ReplacementTransform(tok["F"], f_run),
FadeOut(tok["B"], scale=0.5),
FadeIn(i_pen, scale=0.7),
)
tok.update({"D": d3, "A": a_fin, "E": e3, "F": f_run, "I": i_pen})
del tok["B"]
self.wait(0.3)
self.play(FadeOut(tokens["B"], scale=0.6))
tokens["E"] = make_token("E", BATCH_COLORS[4])
tokens["E"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry
self.play(FadeIn(tokens["E"], scale=0.7))
self.wait(0.25)
# ── Tick 3: A exits, D→FINISHED, G→RUNNING, E count+1, F count+1 ──
d_fin = mk_tok("D", BATCH_COLORS[3], "FINISHED", 8).move_to(F_SLOTS[1])
e4 = mk_tok("E", BATCH_COLORS[4], "RUNNING", 6).move_to(R_SLOTS[0])
f3 = mk_tok("F", BATCH_COLORS[5], "RUNNING", 5).move_to(R_SLOTS[1])
g_run = mk_tok("G", BATCH_COLORS[6], "RUNNING", 4).move_to(R_SLOTS[2])
j_pen = mk_tok("J", BATCH_COLORS[1], "PENDING", 0).move_to(P_SLOTS[2])
# ═══════════════════════════════════════════════════
# 5. Tick 3 — advance, A exits, new F enters at Prefill
# ═══════════════════════════════════════════════════
self.play(
tokens["E"].animate.move_to(slots[3]), # Prefill → Decode
tokens["D"].animate.move_to(slots[0]), # Decode → Cleanup
tokens["A"].animate.move_to(slots[1]), # Cleanup → Refill
ReplacementTransform(tok["D"], d_fin),
ReplacementTransform(tok["E"], e4),
ReplacementTransform(tok["F"], f3),
ReplacementTransform(tok["G"], g_run),
FadeOut(tok["A"], scale=0.5),
FadeIn(j_pen, scale=0.7),
)
self.wait(0.25)
tok.update({"D": d_fin, "E": e4, "F": f3, "G": g_run, "J": j_pen})
del tok["A"]
self.wait(0.3)
self.play(FadeOut(tokens["A"], scale=0.6))
tokens["F"] = make_token("F", BATCH_COLORS[5])
tokens["F"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry
self.play(FadeIn(tokens["F"], scale=0.7))
self.wait(0.25)
# ── Tick 4: D exits, E→FINISHED, H→RUNNING, F count+1, G count+1 ──
e_fin = mk_tok("E", BATCH_COLORS[4], "FINISHED", 7).move_to(F_SLOTS[0])
f4 = mk_tok("F", BATCH_COLORS[5], "RUNNING", 6).move_to(R_SLOTS[0])
g3 = mk_tok("G", BATCH_COLORS[6], "RUNNING", 5).move_to(R_SLOTS[1])
h_run = mk_tok("H", BATCH_COLORS[7], "RUNNING", 4).move_to(R_SLOTS[2])
k_pen = mk_tok("K", BATCH_COLORS[2], "PENDING", 0).move_to(P_SLOTS[2])
# ═══════════════════════════════════════════════════
# 6. Tick 4 — advance, F exits, new G enters at Prefill
# ═══════════════════════════════════════════════════
self.play(
tokens["F"].animate.move_to(slots[3]), # Prefill → Decode
tokens["E"].animate.move_to(slots[0]), # Decode → Cleanup
tokens["D"].animate.move_to(slots[1]), # Cleanup → Refill
ReplacementTransform(tok["E"], e_fin),
ReplacementTransform(tok["F"], f4),
ReplacementTransform(tok["G"], g3),
ReplacementTransform(tok["H"], h_run),
FadeOut(tok["D"], scale=0.5),
FadeIn(k_pen, scale=0.7),
)
self.wait(0.25)
tok.update({"E": e_fin, "F": f4, "G": g3, "H": h_run, "K": k_pen})
del tok["D"]
self.wait(0.4)
self.play(FadeOut(tokens["D"], scale=0.6))
tokens["G"] = make_token("G", BATCH_COLORS[6])
tokens["G"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill ← entry
self.play(FadeIn(tokens["G"], scale=0.7))
self.wait(0.35)
# drop note: constant throughput, all enter at Prefill
flow_note = Text("All requests enter at Prefill — pipeline never drains",
font_size=15, color=GREEN).next_to(states, DOWN, buff=0.55)
flow_note = Text("Pipeline never drains · all 3 states active at once",
font_size=14, color=GREEN)
flow_note.next_to(req_group, DOWN, buff=0.8)
self.play(Write(flow_note))
self.wait(1.5)
self.play(FadeOut(flow_note))
# clear tokens
self.play(*[FadeOut(t) for t in tokens.values()])
self.play(*[FadeOut(t) for t in tok.values()])
# ═══════════════════════════════════════════════════
# 7. Position-Grouped Decode highlight
# ═══════════════════════════════════════════════════
# show multiple tokens grouped at Decode
d_pos = states[3].get_center()
d_pos = run_lane.get_center() # RUNNING = Decode state
d_tokens = [
make_token("T" + str(i), BATCH_COLORS[i]) for i in range(4)
mk_tok("T" + str(i), BATCH_COLORS[i], "RUNNING", 5 + i * 2) for i in range(4)
]
positions = [
d_pos + RIGHT * 1.2 + UP * 0.45,
@ -201,11 +278,11 @@ class ContinuousBatching(Scene):
d_tokens[i].move_to(positions[i])
self.play(FadeIn(d_tokens[i], scale=0.6), run_time=0.2)
ring = SurroundingRectangle(states[3], color=YELLOW, buff=0.12, stroke_width=3)
ring = SurroundingRectangle(run_lane, color=YELLOW, buff=0.12, stroke_width=3)
ring_txt = Text(
"Position-Grouped Batching\nSame decode position → single matmul",
font_size=14, color=YELLOW, line_spacing=0.6,
).next_to(states[3], DOWN, buff=0.5)
).next_to(run_lane, DOWN, buff=0.5)
self.play(Create(ring), Write(ring_txt))
self.wait(2.0)
self.play(FadeOut(ring), FadeOut(ring_txt),
@ -215,7 +292,7 @@ class ContinuousBatching(Scene):
# 8. O(1) Bitmask Slot Allocation
# ═══════════════════════════════════════════════════
bitmask_title = Text("O(1) Slot Allocation via Bitmask",
font_size=22, color=ORANGE).next_to(states, DOWN, buff=0.75)
font_size=22, color=ORANGE).next_to(req_group, DOWN, buff=0.75)
bitmask_desc = Text("free_slots = ~occupied_mask (one-clock op)",
font_size=15, color=GRAY).next_to(bitmask_title, DOWN, buff=0.15)
self.play(Write(bitmask_title), Write(bitmask_desc))
@ -259,12 +336,7 @@ class ContinuousBatching(Scene):
# ═══════════════════════════════════════════════════
self.play(
*[FadeOut(m) for m in self.mobjects if m is not title and m is not bar],
FadeOut(loop), FadeOut(loop_lbl),
)
for s in states:
self.play(FadeOut(s), run_time=0.10)
for a in trans_arrows:
self.play(FadeOut(a), run_time=0.10)
self.wait(0.2)
# ── layout constants ──
@ -346,9 +418,9 @@ class ContinuousBatching(Scene):
self.play(GrowFromEdge(seg, LEFT), run_time=0.09)
# IDLE labels over the red idle strips
s_idle1 = Text("IDLE", font_size=10, color=RED, weight=BOLD) \
s_idle1 = Text("IDLE", font_size=10, color=RED) \
.move_to([L_OX + 1 * CELL, s_y_gpu, 0])
s_idle2 = Text("IDLE", font_size=10, color=RED, weight=BOLD) \
s_idle2 = Text("IDLE", font_size=10, color=RED) \
.move_to([L_OX + 11 * CELL, s_y_gpu, 0])
self.play(Write(s_idle1), Write(s_idle2))
@ -436,10 +508,10 @@ class ContinuousBatching(Scene):
# count annotation
s_count = Text("5 reqs · 2 batches · GPU idle gaps",
font_size=16, color=RED, weight=BOLD) \
font_size=16, color=RED) \
.next_to(s_gpu_batch1, DOWN, buff=1.0).align_to(s_gpu_batch1, LEFT)
c_count = Text("5 reqs · continuous · GPU never idle",
font_size=16, color=GREEN, weight=BOLD) \
font_size=16, color=GREEN) \
.next_to(c_gpu, DOWN, buff=1.0).align_to(c_gpu, LEFT)
self.play(Write(s_note), Write(c_note))
self.wait(0.3)

View File

@ -116,4 +116,4 @@ class PrefixCache(Scene):
summary.to_edge(DOWN, buff=0.5)
self.play(Write(summary))
self.wait(2)
self.play(FadeOut(summary), FadeOut(root_grp), FadeOut(title))
self.play(*[FadeOut(m) for m in self.mobjects])

View File

@ -29,9 +29,9 @@ class Transformer(Scene):
# ── Layout ──
inp = Text("x (hidden states)", font_size=15, color=GRAY)
inp.move_to(UP * 2.6)
inp.move_to(UP * 2.5)
y1 = 1.5
y1 = 1.6
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
@ -39,17 +39,17 @@ class Transformer(Scene):
k_grp.move_to(UP * y1)
v_grp.move_to(RIGHT * 3.0 + UP * y1)
y2 = 0.0
y2 = 0.4
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
repeat_grp.move_to(UP * y2)
y3 = -1.6
y3 = -1.0
sdpa_grp = mk(
"Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
)
sdpa_grp.move_to(UP * y3)
y4 = -2.9
y4 = -2.2
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
o_grp.move_to(UP * y4)
@ -63,7 +63,7 @@ class Transformer(Scene):
self.play(FadeIn(g, shift=UP * 0.1), run_time=0.2)
# ── Input trunk → branch → Q/K/V (enter from directly above) ──
trunk_bottom = np.array([0, q_grp.get_top()[1] + 0.35, 0])
trunk_bottom = np.array([0, q_grp.get_top()[1] + 0.2, 0])
trunk = Line(inp.get_bottom(), trunk_bottom, color=GRAY, stroke_width=1.5)
self.play(Create(trunk), run_time=0.15)
@ -325,7 +325,7 @@ class Transformer(Scene):
y = grid_top - i * (cell_size + gap) - cell_size / 2
lbl.next_to([grid_left - 0.15, y, 0], LEFT, buff=0.08)
row_lbls.add(lbl)
q_label = Text("Q", font_size=11, color=WHITE, weight=BOLD)
q_label = Text("Q", font_size=11, color=WHITE)
q_label.move_to(row_lbls[0].get_left() + LEFT * 0.3).shift(UP * 0.15)
self.play(*[Write(l) for l in row_lbls], Write(q_label))
@ -336,7 +336,7 @@ class Transformer(Scene):
x = grid_left + j * (cell_size + gap) + cell_size / 2
lbl.next_to([x, grid_top + 0.06, 0], UP, buff=0.04)
col_lbls.add(lbl)
k_label = Text("K", font_size=11, color=WHITE, weight=BOLD)
k_label = Text("K", font_size=11, color=WHITE)
k_label.next_to(col_lbls[0], UP, buff=0.06)
self.play(*[Write(l) for l in col_lbls], Write(k_label))
self.wait(1.0)
@ -379,7 +379,7 @@ class Transformer(Scene):
# ═══════════════════════════════════════════════════
# Auto-regressive Generation Demo (v2: full I/O pipeline)
def tok_card(text, fill=DARK_BLUE, stroke=GRAY):
t = Text(text, font_size=10, color=WHITE, weight=BOLD)
t = Text(text, font_size=10, color=WHITE)
box = RoundedRectangle(
width=t.width + 0.2, height=t.height + 0.1,
corner_radius=0.04, fill_color=fill, fill_opacity=0.5,
@ -409,7 +409,7 @@ class Transformer(Scene):
b = RoundedRectangle(width=w, height=h, corner_radius=0.06,
fill_color=DARK_BLUE, fill_opacity=0.25,
stroke_color=color, stroke_width=1.5)
l = Text(txt, font_size=fs, color=color, weight=BOLD).move_to(b)
l = Text(txt, font_size=fs, color=color).move_to(b)
return VGroup(b, l).move_to([CX_BLK, y, 0])
emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9)