fix transformer: GQA text overflow, heatmap sizing, auto-regressive pos labels
- Shrink the GQA title font (42→34) so it fits on screen
- Move the GQA annotation from its overflowing position to the bottom-right of the V box
- Enlarge heatmap cells (0.52→0.65) and labels (row labels 12→14, column labels 9→10), and raise the grid
- Remove the Repeat KV highlight step (shortens the scene by ~2s)
- Add position labels to the auto-regressive token sequence
- Add a layer-stack effect behind the Transformer block
- Increase font sizes and spacing throughout for readability
This commit is contained in:
parent
d471cfa276
commit
9de0bad3d4
103
transformer.py
103
transformer.py
|
|
@ -15,12 +15,12 @@ class Transformer(Scene):
|
||||||
"""Animates the GQA attention mechanism with orthogonal connection lines."""
|
"""Animates the GQA attention mechanism with orthogonal connection lines."""
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
|
title = Text("Grouped-Query Attention (GQA)", font_size=34, color=BLUE)
|
||||||
title.to_edge(UP, buff=0.35)
|
title.to_edge(UP, buff=0.35)
|
||||||
self.play(Write(title))
|
self.play(Write(title))
|
||||||
|
|
||||||
# ── Helper: box ──
|
# ── Helper: box ──
|
||||||
def mk(name, color, w=2.6, h=0.72, fs=10):
|
def mk(name, color, w=3.0, h=0.85, fs=13):
|
||||||
box = Rectangle(
|
box = Rectangle(
|
||||||
width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
|
width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
|
||||||
)
|
)
|
||||||
|
|
@ -28,32 +28,32 @@ class Transformer(Scene):
|
||||||
return VGroup(box, lbl)
|
return VGroup(box, lbl)
|
||||||
|
|
||||||
# ── Layout ──
|
# ── Layout ──
|
||||||
inp = Text("x (hidden states)", font_size=15, color=GRAY)
|
inp = Text("x (hidden states)", font_size=20, color=GRAY)
|
||||||
inp.move_to(UP * 2.5)
|
inp.move_to(UP * 2.5)
|
||||||
|
|
||||||
y1 = 1.6
|
y1 = 1.6
|
||||||
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
|
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
|
||||||
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
|
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
|
||||||
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
|
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
|
||||||
q_grp.move_to(LEFT * 3.0 + UP * y1)
|
q_grp.move_to(LEFT * 3.6 + UP * y1)
|
||||||
k_grp.move_to(UP * y1)
|
k_grp.move_to(UP * y1)
|
||||||
v_grp.move_to(RIGHT * 3.0 + UP * y1)
|
v_grp.move_to(RIGHT * 3.6 + UP * y1)
|
||||||
|
|
||||||
y2 = 0.4
|
y2 = 0.4
|
||||||
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
|
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.8, 0.80)
|
||||||
repeat_grp.move_to(UP * y2)
|
repeat_grp.move_to(UP * y2)
|
||||||
|
|
||||||
y3 = -1.0
|
y3 = -1.0
|
||||||
sdpa_grp = mk(
|
sdpa_grp = mk(
|
||||||
"Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
|
"Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 3.2, 0.85,
|
||||||
)
|
)
|
||||||
sdpa_grp.move_to(UP * y3)
|
sdpa_grp.move_to(UP * y3)
|
||||||
|
|
||||||
y4 = -2.2
|
y4 = -2.2
|
||||||
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
|
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.6, 0.80)
|
||||||
o_grp.move_to(UP * y4)
|
o_grp.move_to(UP * y4)
|
||||||
|
|
||||||
out = Text("x' (hidden states)", font_size=15, color=GRAY)
|
out = Text("x' (hidden states)", font_size=20, color=GRAY)
|
||||||
out.next_to(o_grp, DOWN, buff=0.4)
|
out.next_to(o_grp, DOWN, buff=0.4)
|
||||||
|
|
||||||
# ── Animate boxes ──
|
# ── Animate boxes ──
|
||||||
|
|
@ -169,31 +169,18 @@ class Transformer(Scene):
|
||||||
VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
|
VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
|
||||||
)
|
)
|
||||||
gqa_t = Text(
|
gqa_t = Text(
|
||||||
"GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
|
"GQA 6:1\n24 Q-heads → 4 KV-heads\nKV cache -83%",
|
||||||
font_size=13, color=YELLOW,
|
font_size=11, color=YELLOW,
|
||||||
)
|
)
|
||||||
gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
|
gqa_t.next_to(v_grp, DOWN, buff=0.4).shift(RIGHT * 0.4)
|
||||||
self.play(Create(gqa_h), Write(gqa_t))
|
self.play(Create(gqa_h), Write(gqa_t))
|
||||||
self.wait(1.8)
|
self.wait(1.8)
|
||||||
self.play(FadeOut(gqa_h), FadeOut(gqa_t))
|
self.play(FadeOut(gqa_h), FadeOut(gqa_t))
|
||||||
|
|
||||||
# ── Repeat KV highlight ──
|
|
||||||
kv_h = SurroundingRectangle(
|
|
||||||
VGroup(k_grp, v_grp), color=GREEN, buff=0.12
|
|
||||||
)
|
|
||||||
kv_t = Text(
|
|
||||||
"repeat_kv(): broadcast\n4 heads → 24 heads",
|
|
||||||
font_size=12, color=GREEN,
|
|
||||||
)
|
|
||||||
kv_t.next_to(kv_h, RIGHT, buff=0.5)
|
|
||||||
self.play(Create(kv_h), Write(kv_t))
|
|
||||||
self.wait(1.5)
|
|
||||||
|
|
||||||
# ── Fade all ──
|
# ── Fade all ──
|
||||||
self.play(
|
self.play(
|
||||||
*[FadeOut(g) for g in all_boxes],
|
*[FadeOut(g) for g in all_boxes],
|
||||||
FadeOut(all_lines),
|
FadeOut(all_lines),
|
||||||
FadeOut(kv_h), FadeOut(kv_t),
|
|
||||||
FadeOut(inp), FadeOut(out), FadeOut(title),
|
FadeOut(inp), FadeOut(out), FadeOut(title),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -268,11 +255,11 @@ class Transformer(Scene):
|
||||||
|
|
||||||
tokens = ["<s>", "The", "cat", "sat", "on", "the", "mat"]
|
tokens = ["<s>", "The", "cat", "sat", "on", "the", "mat"]
|
||||||
n = len(tokens)
|
n = len(tokens)
|
||||||
cell_size = 0.52
|
cell_size = 0.65
|
||||||
gap = 0.04
|
gap = 0.05
|
||||||
grid_high = n * cell_size + (n - 1) * gap
|
grid_high = n * cell_size + (n - 1) * gap
|
||||||
grid_left = -grid_high / 2
|
grid_left = -grid_high / 2
|
||||||
grid_top = 1.4
|
grid_top = 1.7
|
||||||
|
|
||||||
# pre-mask raw scores (QK^T / sqrt(d_k)) — random-varied, distance-biased
|
# pre-mask raw scores (QK^T / sqrt(d_k)) — random-varied, distance-biased
|
||||||
pre_scores = [
|
pre_scores = [
|
||||||
|
|
@ -321,7 +308,7 @@ class Transformer(Scene):
|
||||||
# row labels (query) on the left
|
# row labels (query) on the left
|
||||||
row_lbls = VGroup()
|
row_lbls = VGroup()
|
||||||
for i, tok in enumerate(tokens):
|
for i, tok in enumerate(tokens):
|
||||||
lbl = Text(tok, font_size=12, color=GRAY)
|
lbl = Text(tok, font_size=14, color=GRAY)
|
||||||
y = grid_top - i * (cell_size + gap) - cell_size / 2
|
y = grid_top - i * (cell_size + gap) - cell_size / 2
|
||||||
lbl.next_to([grid_left - 0.15, y, 0], LEFT, buff=0.08)
|
lbl.next_to([grid_left - 0.15, y, 0], LEFT, buff=0.08)
|
||||||
row_lbls.add(lbl)
|
row_lbls.add(lbl)
|
||||||
|
|
@ -332,7 +319,7 @@ class Transformer(Scene):
|
||||||
# column labels (key) on top
|
# column labels (key) on top
|
||||||
col_lbls = VGroup()
|
col_lbls = VGroup()
|
||||||
for j, tok in enumerate(tokens):
|
for j, tok in enumerate(tokens):
|
||||||
lbl = Text(tok, font_size=9, color=GRAY).rotate(PI / 6)
|
lbl = Text(tok, font_size=10, color=GRAY).rotate(PI / 6)
|
||||||
x = grid_left + j * (cell_size + gap) + cell_size / 2
|
x = grid_left + j * (cell_size + gap) + cell_size / 2
|
||||||
lbl.next_to([x, grid_top + 0.06, 0], UP, buff=0.04)
|
lbl.next_to([x, grid_top + 0.06, 0], UP, buff=0.04)
|
||||||
col_lbls.add(lbl)
|
col_lbls.add(lbl)
|
||||||
|
|
@ -379,10 +366,10 @@ class Transformer(Scene):
|
||||||
# ═══════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════
|
||||||
# Auto-regressive Generation Demo (v2: full I/O pipeline)
|
# Auto-regressive Generation Demo (v2: full I/O pipeline)
|
||||||
def tok_card(text, fill=DARK_BLUE, stroke=GRAY):
|
def tok_card(text, fill=DARK_BLUE, stroke=GRAY):
|
||||||
t = Text(text, font_size=10, color=WHITE)
|
t = Text(text, font_size=14, color=WHITE)
|
||||||
box = RoundedRectangle(
|
box = RoundedRectangle(
|
||||||
width=t.width + 0.2, height=t.height + 0.1,
|
width=t.width + 0.3, height=t.height + 0.16,
|
||||||
corner_radius=0.04, fill_color=fill, fill_opacity=0.5,
|
corner_radius=0.06, fill_color=fill, fill_opacity=0.5,
|
||||||
stroke_color=stroke, stroke_width=0.5,
|
stroke_color=stroke, stroke_width=0.5,
|
||||||
)
|
)
|
||||||
t.move_to(box)
|
t.move_to(box)
|
||||||
|
|
@ -399,11 +386,10 @@ class Transformer(Scene):
|
||||||
BLK_H = 0.28
|
BLK_H = 0.28
|
||||||
TFR_H = 0.55
|
TFR_H = 0.55
|
||||||
CX_BLK = -1.0
|
CX_BLK = -1.0
|
||||||
Y_EMB = 1.40
|
Y_EMB = 1.95
|
||||||
Y_NORM1 = 0.95
|
Y_TFR = 1.10
|
||||||
Y_TFR = 0.32
|
Y_NORM2 = 0.35
|
||||||
Y_NORM2 = -0.25
|
Y_HEAD = -0.15
|
||||||
Y_HEAD = -0.70
|
|
||||||
|
|
||||||
def mkblk(w, h, y, txt, color, fs=10):
|
def mkblk(w, h, y, txt, color, fs=10):
|
||||||
b = RoundedRectangle(width=w, height=h, corner_radius=0.06,
|
b = RoundedRectangle(width=w, height=h, corner_radius=0.06,
|
||||||
|
|
@ -413,11 +399,19 @@ class Transformer(Scene):
|
||||||
return VGroup(b, l).move_to([CX_BLK, y, 0])
|
return VGroup(b, l).move_to([CX_BLK, y, 0])
|
||||||
|
|
||||||
emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9)
|
emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9)
|
||||||
norm1 = mkblk(BLK_W, BLK_H, Y_NORM1, "RMS Norm", GREEN, 9)
|
tfr_node = mkblk(2.4, TFR_H, Y_TFR, "Transformer Block\n× 24", PURPLE, 12)
|
||||||
tfr_node = mkblk(2.4, TFR_H, Y_TFR, "Transformer Block\n× 24", PURPLE, 9)
|
|
||||||
norm2 = mkblk(BLK_W, BLK_H, Y_NORM2, "RMS Norm", GREEN, 9)
|
norm2 = mkblk(BLK_W, BLK_H, Y_NORM2, "RMS Norm", GREEN, 9)
|
||||||
head = mkblk(BLK_W, BLK_H, Y_HEAD, "LM Head", RED, 9)
|
head = mkblk(BLK_W, BLK_H, Y_HEAD, "LM Head", RED, 9)
|
||||||
pipeline = VGroup(emb_node, norm1, tfr_node, norm2, head)
|
pipeline = VGroup(emb_node, tfr_node, norm2, head)
|
||||||
|
|
||||||
|
# Stack effect for ×24 layers behind Transformer block
|
||||||
|
layer_stack = VGroup()
|
||||||
|
for i in range(3, 0, -1):
|
||||||
|
shadow = RoundedRectangle(
|
||||||
|
width=2.4, height=TFR_H, corner_radius=0.06,
|
||||||
|
stroke_color=BLUE_D, stroke_width=1.0, fill_opacity=0,
|
||||||
|
).move_to([CX_BLK + i * 0.04, Y_TFR - i * 0.04, 0])
|
||||||
|
layer_stack.add(shadow)
|
||||||
|
|
||||||
# Arrows between blocks
|
# Arrows between blocks
|
||||||
arrows = VGroup()
|
arrows = VGroup()
|
||||||
|
|
@ -446,9 +440,9 @@ class Transformer(Scene):
|
||||||
kv_group = VGroup(cache_lbl, k_hdr, v_hdr, kv_k, kv_v)
|
kv_group = VGroup(cache_lbl, k_hdr, v_hdr, kv_k, kv_v)
|
||||||
|
|
||||||
# ── Distribution builder ──
|
# ── Distribution builder ──
|
||||||
def build_dist(probs, y_center, max_w=2.8):
|
def build_dist(probs, y_center, max_w=3.0):
|
||||||
bars = VGroup(); lbls = VGroup()
|
bars = VGroup(); lbls = VGroup()
|
||||||
bh = 0.12; bg = 0.02; lx = CX_BLK - max_w / 2
|
bh = 0.18; bg = 0.03; lx = CX_BLK - max_w / 2
|
||||||
items = list(probs.items())
|
items = list(probs.items())
|
||||||
n = len(items)
|
n = len(items)
|
||||||
y_top = y_center + (n * bh + (n - 1) * bg) / 2
|
y_top = y_center + (n * bh + (n - 1) * bg) / 2
|
||||||
|
|
@ -459,7 +453,7 @@ class Transformer(Scene):
|
||||||
fill_color=interpolate_color(BLUE, RED, pct / 100),
|
fill_color=interpolate_color(BLUE, RED, pct / 100),
|
||||||
fill_opacity=0.85, stroke_color=LIGHT_GRAY, stroke_width=0.3)
|
fill_opacity=0.85, stroke_color=LIGHT_GRAY, stroke_width=0.3)
|
||||||
bar.move_to([lx + w / 2, y, 0])
|
bar.move_to([lx + w / 2, y, 0])
|
||||||
lbl = Text(f"{tok} {pct}%", font_size=7, color=WHITE)
|
lbl = Text(f"{tok} {pct}%", font_size=10, color=WHITE)
|
||||||
lbl.next_to(bar, RIGHT, buff=0.05)
|
lbl.next_to(bar, RIGHT, buff=0.05)
|
||||||
bars.add(bar); lbls.add(lbl)
|
bars.add(bar); lbls.add(lbl)
|
||||||
return VGroup(bars, lbls)
|
return VGroup(bars, lbls)
|
||||||
|
|
@ -472,22 +466,28 @@ class Transformer(Scene):
|
||||||
{"the": 60, "a": 12, "top": 8, "floor": 5, "<unk>": 15},
|
{"the": 60, "a": 12, "top": 8, "floor": 5, "<unk>": 15},
|
||||||
{"mat": 50, "table": 15, "chair": 8, "floor": 6, "<unk>": 21},
|
{"mat": 50, "table": 15, "chair": 8, "floor": 6, "<unk>": 21},
|
||||||
]
|
]
|
||||||
Y_DIST = -1.45
|
Y_DIST = -1.65
|
||||||
|
|
||||||
# ── Token sequence row ──
|
# ── Token sequence row ──
|
||||||
SX = -4.0; Y_SEQ = 2.0; GAP = 0.55
|
SX = -4.0; Y_SEQ = 2.4; GAP = 0.70
|
||||||
seq = VGroup()
|
seq = VGroup()
|
||||||
sos = tok_card("<s>", BLUE, BLUE).move_to([SX, Y_SEQ, 0])
|
sos = tok_card("<s>", BLUE, BLUE).move_to([SX, Y_SEQ, 0])
|
||||||
seq.add(sos)
|
seq.add(sos)
|
||||||
|
|
||||||
|
# Position labels under each token
|
||||||
|
pos_lbls = VGroup()
|
||||||
|
pos0 = Text("0", font_size=9, color=DARK_GRAY).next_to(sos, DOWN, buff=0.08)
|
||||||
|
pos_lbls.add(pos0)
|
||||||
|
|
||||||
# Step label (below everything)
|
# Step label (below everything)
|
||||||
step_lbl = Text("Step 0 — [<s>] → ?", font_size=9, color=GRAY)
|
step_lbl = Text("Step 0 — [<s>] → ?", font_size=9, color=GRAY)
|
||||||
step_lbl.move_to([0, -2.1, 0])
|
step_lbl.move_to([0, -2.1, 0])
|
||||||
|
|
||||||
# ── Show static elements ──
|
# ── Show static elements ──
|
||||||
self.play(FadeIn(pipeline), FadeIn(arrows))
|
self.play(FadeIn(pipeline), FadeIn(arrows))
|
||||||
|
self.play(FadeIn(layer_stack))
|
||||||
self.play(FadeIn(kv_group))
|
self.play(FadeIn(kv_group))
|
||||||
self.play(FadeIn(sos))
|
self.play(FadeIn(sos), Write(pos0))
|
||||||
self.play(Write(step_lbl))
|
self.play(Write(step_lbl))
|
||||||
self.wait(0.5)
|
self.wait(0.5)
|
||||||
|
|
||||||
|
|
@ -506,13 +506,12 @@ class Transformer(Scene):
|
||||||
self.play(FadeIn(in_arr, scale=0.5), run_time=0.1)
|
self.play(FadeIn(in_arr, scale=0.5), run_time=0.1)
|
||||||
|
|
||||||
# 3. Cascade through pipeline
|
# 3. Cascade through pipeline
|
||||||
pipes = [emb_node, norm1, tfr_node, norm2, head]
|
|
||||||
self.play(
|
self.play(
|
||||||
*[p[0].animate.set_fill_opacity(0.6) for p in pipes],
|
*[p[0].animate.set_fill_opacity(0.6) for p in pipeline],
|
||||||
run_time=0.12
|
run_time=0.12
|
||||||
)
|
)
|
||||||
self.play(
|
self.play(
|
||||||
*[p[0].animate.set_fill_opacity(0.25) for p in pipes],
|
*[p[0].animate.set_fill_opacity(0.25) for p in pipeline],
|
||||||
run_time=0.1
|
run_time=0.1
|
||||||
)
|
)
|
||||||
self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08)
|
self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08)
|
||||||
|
|
@ -539,12 +538,16 @@ class Transformer(Scene):
|
||||||
|
|
||||||
# 7. Token rises to join sequence at top
|
# 7. Token rises to join sequence at top
|
||||||
self.play(pred.animate.move_to([target_x, Y_SEQ, 0]), run_time=0.3)
|
self.play(pred.animate.move_to([target_x, Y_SEQ, 0]), run_time=0.3)
|
||||||
|
pos_lbl = Text(str(i), font_size=9, color=DARK_GRAY)
|
||||||
|
pos_lbl.next_to(pred, DOWN, buff=0.08)
|
||||||
self.play(
|
self.play(
|
||||||
pred[0].animate.set_fill(DARK_BLUE, 0.5).set_stroke(GRAY, 0.5),
|
pred[0].animate.set_fill(DARK_BLUE, 0.5).set_stroke(GRAY, 0.5),
|
||||||
pred[1].animate.set_color(WHITE),
|
pred[1].animate.set_color(WHITE),
|
||||||
|
Write(pos_lbl),
|
||||||
FadeOut(sample_hl), FadeOut(samp_lbl),
|
FadeOut(sample_hl), FadeOut(samp_lbl),
|
||||||
)
|
)
|
||||||
seq.add(pred)
|
seq.add(pred)
|
||||||
|
pos_lbls.add(pos_lbl)
|
||||||
|
|
||||||
# 8. KV Cache: add K,V of this predicted token
|
# 8. KV Cache: add K,V of this predicted token
|
||||||
kv_x = KV_X + i * (kv_size + kv_gap)
|
kv_x = KV_X + i * (kv_size + kv_gap)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue