fix transformer: GQA text overflow, heatmap sizing, auto-regressive pos labels

- Shrink GQA title (42→34) to fit screen
- Move GQA annotation from left overflow to right-bottom of V box
- Enlarge heatmap cells (0.52→0.65) and labels (12→14, 9→10), lift grid up
- Remove Repeat KV section (shorten scene ~2s)
- Add position labels to auto-regressive token sequence
- Add layer stack effect behind transformer block
- Upgrade font sizes and spacing throughout for readability
This commit is contained in:
ViperEkura 2026-05-25 19:19:48 +08:00
parent d471cfa276
commit 9de0bad3d4
1 changed files with 53 additions and 50 deletions

View File

@ -15,12 +15,12 @@ class Transformer(Scene):
"""Animates the GQA attention mechanism with orthogonal connection lines."""
def construct(self):
title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
title = Text("Grouped-Query Attention (GQA)", font_size=34, color=BLUE)
title.to_edge(UP, buff=0.35)
self.play(Write(title))
# ── Helper: box ──
def mk(name, color, w=2.6, h=0.72, fs=10):
def mk(name, color, w=3.0, h=0.85, fs=13):
box = Rectangle(
width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
)
@ -28,32 +28,32 @@ class Transformer(Scene):
return VGroup(box, lbl)
# ── Layout ──
inp = Text("x (hidden states)", font_size=15, color=GRAY)
inp = Text("x (hidden states)", font_size=20, color=GRAY)
inp.move_to(UP * 2.5)
y1 = 1.6
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
q_grp.move_to(LEFT * 3.0 + UP * y1)
q_grp.move_to(LEFT * 3.6 + UP * y1)
k_grp.move_to(UP * y1)
v_grp.move_to(RIGHT * 3.0 + UP * y1)
v_grp.move_to(RIGHT * 3.6 + UP * y1)
y2 = 0.4
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.8, 0.80)
repeat_grp.move_to(UP * y2)
y3 = -1.0
sdpa_grp = mk(
"Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
"Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 3.2, 0.85,
)
sdpa_grp.move_to(UP * y3)
y4 = -2.2
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.6, 0.80)
o_grp.move_to(UP * y4)
out = Text("x' (hidden states)", font_size=15, color=GRAY)
out = Text("x' (hidden states)", font_size=20, color=GRAY)
out.next_to(o_grp, DOWN, buff=0.4)
# ── Animate boxes ──
@ -169,31 +169,18 @@ class Transformer(Scene):
VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
)
gqa_t = Text(
"GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
font_size=13, color=YELLOW,
"GQA 6:1\n24 Q-heads → 4 KV-heads\nKV cache -83%",
font_size=11, color=YELLOW,
)
gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
gqa_t.next_to(v_grp, DOWN, buff=0.4).shift(RIGHT * 0.4)
self.play(Create(gqa_h), Write(gqa_t))
self.wait(1.8)
self.play(FadeOut(gqa_h), FadeOut(gqa_t))
# ── Repeat KV highlight ──
kv_h = SurroundingRectangle(
VGroup(k_grp, v_grp), color=GREEN, buff=0.12
)
kv_t = Text(
"repeat_kv(): broadcast\n4 heads → 24 heads",
font_size=12, color=GREEN,
)
kv_t.next_to(kv_h, RIGHT, buff=0.5)
self.play(Create(kv_h), Write(kv_t))
self.wait(1.5)
# ── Fade all ──
self.play(
*[FadeOut(g) for g in all_boxes],
FadeOut(all_lines),
FadeOut(kv_h), FadeOut(kv_t),
FadeOut(inp), FadeOut(out), FadeOut(title),
)
@ -268,11 +255,11 @@ class Transformer(Scene):
tokens = ["<s>", "The", "cat", "sat", "on", "the", "mat"]
n = len(tokens)
cell_size = 0.52
gap = 0.04
cell_size = 0.65
gap = 0.05
grid_high = n * cell_size + (n - 1) * gap
grid_left = -grid_high / 2
grid_top = 1.4
grid_top = 1.7
# pre-mask raw scores (QK^T / sqrt(d_k)) — random-varied, distance-biased
pre_scores = [
@ -321,7 +308,7 @@ class Transformer(Scene):
# row labels (query) on the left
row_lbls = VGroup()
for i, tok in enumerate(tokens):
lbl = Text(tok, font_size=12, color=GRAY)
lbl = Text(tok, font_size=14, color=GRAY)
y = grid_top - i * (cell_size + gap) - cell_size / 2
lbl.next_to([grid_left - 0.15, y, 0], LEFT, buff=0.08)
row_lbls.add(lbl)
@ -332,7 +319,7 @@ class Transformer(Scene):
# column labels (key) on top
col_lbls = VGroup()
for j, tok in enumerate(tokens):
lbl = Text(tok, font_size=9, color=GRAY).rotate(PI / 6)
lbl = Text(tok, font_size=10, color=GRAY).rotate(PI / 6)
x = grid_left + j * (cell_size + gap) + cell_size / 2
lbl.next_to([x, grid_top + 0.06, 0], UP, buff=0.04)
col_lbls.add(lbl)
@ -379,10 +366,10 @@ class Transformer(Scene):
# ═══════════════════════════════════════════════════
# Auto-regressive Generation Demo (v2: full I/O pipeline)
def tok_card(text, fill=DARK_BLUE, stroke=GRAY):
t = Text(text, font_size=10, color=WHITE)
t = Text(text, font_size=14, color=WHITE)
box = RoundedRectangle(
width=t.width + 0.2, height=t.height + 0.1,
corner_radius=0.04, fill_color=fill, fill_opacity=0.5,
width=t.width + 0.3, height=t.height + 0.16,
corner_radius=0.06, fill_color=fill, fill_opacity=0.5,
stroke_color=stroke, stroke_width=0.5,
)
t.move_to(box)
@ -399,11 +386,10 @@ class Transformer(Scene):
BLK_H = 0.28
TFR_H = 0.55
CX_BLK = -1.0
Y_EMB = 1.40
Y_NORM1 = 0.95
Y_TFR = 0.32
Y_NORM2 = -0.25
Y_HEAD = -0.70
Y_EMB = 1.95
Y_TFR = 1.10
Y_NORM2 = 0.35
Y_HEAD = -0.15
def mkblk(w, h, y, txt, color, fs=10):
b = RoundedRectangle(width=w, height=h, corner_radius=0.06,
@ -413,11 +399,19 @@ class Transformer(Scene):
return VGroup(b, l).move_to([CX_BLK, y, 0])
emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9)
norm1 = mkblk(BLK_W, BLK_H, Y_NORM1, "RMS Norm", GREEN, 9)
tfr_node = mkblk(2.4, TFR_H, Y_TFR, "Transformer Block\n× 24", PURPLE, 9)
tfr_node = mkblk(2.4, TFR_H, Y_TFR, "Transformer Block\n× 24", PURPLE, 12)
norm2 = mkblk(BLK_W, BLK_H, Y_NORM2, "RMS Norm", GREEN, 9)
head = mkblk(BLK_W, BLK_H, Y_HEAD, "LM Head", RED, 9)
pipeline = VGroup(emb_node, norm1, tfr_node, norm2, head)
pipeline = VGroup(emb_node, tfr_node, norm2, head)
# Stack effect for ×24 layers behind Transformer block
layer_stack = VGroup()
for i in range(3, 0, -1):
shadow = RoundedRectangle(
width=2.4, height=TFR_H, corner_radius=0.06,
stroke_color=BLUE_D, stroke_width=1.0, fill_opacity=0,
).move_to([CX_BLK + i * 0.04, Y_TFR - i * 0.04, 0])
layer_stack.add(shadow)
# Arrows between blocks
arrows = VGroup()
@ -446,9 +440,9 @@ class Transformer(Scene):
kv_group = VGroup(cache_lbl, k_hdr, v_hdr, kv_k, kv_v)
# ── Distribution builder ──
def build_dist(probs, y_center, max_w=2.8):
def build_dist(probs, y_center, max_w=3.0):
bars = VGroup(); lbls = VGroup()
bh = 0.12; bg = 0.02; lx = CX_BLK - max_w / 2
bh = 0.18; bg = 0.03; lx = CX_BLK - max_w / 2
items = list(probs.items())
n = len(items)
y_top = y_center + (n * bh + (n - 1) * bg) / 2
@ -459,7 +453,7 @@ class Transformer(Scene):
fill_color=interpolate_color(BLUE, RED, pct / 100),
fill_opacity=0.85, stroke_color=LIGHT_GRAY, stroke_width=0.3)
bar.move_to([lx + w / 2, y, 0])
lbl = Text(f"{tok} {pct}%", font_size=7, color=WHITE)
lbl = Text(f"{tok} {pct}%", font_size=10, color=WHITE)
lbl.next_to(bar, RIGHT, buff=0.05)
bars.add(bar); lbls.add(lbl)
return VGroup(bars, lbls)
@ -472,22 +466,28 @@ class Transformer(Scene):
{"the": 60, "a": 12, "top": 8, "floor": 5, "<unk>": 15},
{"mat": 50, "table": 15, "chair": 8, "floor": 6, "<unk>": 21},
]
Y_DIST = -1.45
Y_DIST = -1.65
# ── Token sequence row ──
SX = -4.0; Y_SEQ = 2.0; GAP = 0.55
SX = -4.0; Y_SEQ = 2.4; GAP = 0.70
seq = VGroup()
sos = tok_card("<s>", BLUE, BLUE).move_to([SX, Y_SEQ, 0])
seq.add(sos)
# Position labels under each token
pos_lbls = VGroup()
pos0 = Text("0", font_size=9, color=DARK_GRAY).next_to(sos, DOWN, buff=0.08)
pos_lbls.add(pos0)
# Step label (below everything)
step_lbl = Text("Step 0 — [<s>] → ?", font_size=9, color=GRAY)
step_lbl.move_to([0, -2.1, 0])
# ── Show static elements ──
self.play(FadeIn(pipeline), FadeIn(arrows))
self.play(FadeIn(layer_stack))
self.play(FadeIn(kv_group))
self.play(FadeIn(sos))
self.play(FadeIn(sos), Write(pos0))
self.play(Write(step_lbl))
self.wait(0.5)
@ -506,13 +506,12 @@ class Transformer(Scene):
self.play(FadeIn(in_arr, scale=0.5), run_time=0.1)
# 3. Cascade through pipeline
pipes = [emb_node, norm1, tfr_node, norm2, head]
self.play(
*[p[0].animate.set_fill_opacity(0.6) for p in pipes],
*[p[0].animate.set_fill_opacity(0.6) for p in pipeline],
run_time=0.12
)
self.play(
*[p[0].animate.set_fill_opacity(0.25) for p in pipes],
*[p[0].animate.set_fill_opacity(0.25) for p in pipeline],
run_time=0.1
)
self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08)
@ -539,12 +538,16 @@ class Transformer(Scene):
# 7. Token rises to join sequence at top
self.play(pred.animate.move_to([target_x, Y_SEQ, 0]), run_time=0.3)
pos_lbl = Text(str(i), font_size=9, color=DARK_GRAY)
pos_lbl.next_to(pred, DOWN, buff=0.08)
self.play(
pred[0].animate.set_fill(DARK_BLUE, 0.5).set_stroke(GRAY, 0.5),
pred[1].animate.set_color(WHITE),
Write(pos_lbl),
FadeOut(sample_hl), FadeOut(samp_lbl),
)
seq.add(pred)
pos_lbls.add(pos_lbl)
# 8. KV Cache: add K,V of this predicted token
kv_x = KV_X + i * (kv_size + kv_gap)