fix: bottom spacing, remove specs card, full formula first with no ellipsis in steps

This commit is contained in:
ViperEkura 2026-05-07 09:28:07 +08:00
parent eeaf0a5a16
commit e7d736a3b0
1 changed files with 45 additions and 72 deletions

View File

@ -26,9 +26,9 @@ class Transformer(Scene):
# ── Layout ──
inp = Text("x (hidden states)", font_size=15, color=GRAY)
inp.move_to(UP * 2.8)
inp.move_to(UP * 3.2)
y1 = 1.5
y1 = 1.9
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
@ -36,17 +36,17 @@ class Transformer(Scene):
k_grp.move_to(UP * y1)
v_grp.move_to(RIGHT * 3.0 + UP * y1)
y2 = 0.0
y2 = 0.4
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
repeat_grp.move_to(UP * y2)
y3 = -1.6
y3 = -1.2
sdpa_grp = mk(
"Scaled Dot-Product\nAttention Q·K/√d", BLUE, 2.8, 0.74, 10
"Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
)
sdpa_grp.move_to(UP * y3)
y4 = -3.0
y4 = -2.6
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
o_grp.move_to(UP * y4)
@ -172,6 +172,7 @@ class Transformer(Scene):
gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
self.play(Create(gqa_h), Write(gqa_t))
self.wait(1.8)
self.play(FadeOut(gqa_h), FadeOut(gqa_t))
# ── Repeat KV highlight ──
kv_h = SurroundingRectangle(
@ -190,96 +191,68 @@ class Transformer(Scene):
*[FadeOut(g) for g in all_boxes],
FadeOut(all_lines),
FadeOut(kv_h), FadeOut(kv_t),
FadeOut(gqa_h), FadeOut(gqa_t),
FadeOut(inp), FadeOut(out), FadeOut(title),
)
# ── Specs card ──
st = Text("Model Specifications", font_size=36, color=BLUE)
st.to_edge(UP, buff=0.5)
rows_data = [
("Parameters", "~1.0B"),
("Layers", "24 × DecoderBlock"),
("Hidden Dim", "1536"),
("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"),
("Head Dim", "64"),
("FFN Dim", "4608 (SwiGLU)"),
("Max Length", "2048"),
("Precision", "bfloat16"),
]
table = VGroup()
for label, value in rows_data:
row = VGroup(
Text(label + ":", font_size=15, color=GRAY),
Text(value, font_size=15, color=WHITE),
).arrange(RIGHT, buff=0.4, aligned_edge=LEFT)
table.add(row)
table.arrange(DOWN, buff=0.1, aligned_edge=LEFT)
table.next_to(st, DOWN, buff=0.4)
self.play(Write(st), Write(table))
self.wait(2)
self.play(FadeOut(st), FadeOut(table))
# ═══════════════════════════════════════════════════
# 12. Q / K / V — what do they mean?
# 12. Scaled Dot-Product Attention — full formula + breakdown
# ═══════════════════════════════════════════════════
qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
qkv_title.to_edge(UP, buff=0.35)
self.play(Write(qkv_title))
self.wait(0.2)
q_txt = Text("Q = Query", font_size=24, color=YELLOW)
k_txt = Text("K = Key", font_size=24, color=ORANGE)
v_txt = Text("V = Value", font_size=24, color=GREEN)
qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
qkv_labels.next_to(qkv_title, DOWN, buff=0.6)
self.play(Write(qkv_labels))
q_desc = Text("\"what am I looking for?\"", font_size=12, color=YELLOW) \
.next_to(q_txt, DOWN, buff=0.12)
k_desc = Text("\"what do I have?\"", font_size=12, color=ORANGE) \
.next_to(k_txt, DOWN, buff=0.12)
v_desc = Text("\"what do I contribute?\"", font_size=12, color=GREEN) \
.next_to(v_txt, DOWN, buff=0.12)
self.play(Write(q_desc), Write(k_desc), Write(v_desc))
self.wait(2.0)
self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))
# ── Full formula ──
# Full formula first, stays on screen
full_eq = MathTex(
r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
font_size=36, color=WHITE,
)
full_eq.next_to(qkv_title, DOWN, buff=0.6)
full_eq.next_to(qkv_title, DOWN, buff=0.5)
self.play(Write(full_eq))
self.wait(1.2)
self.wait(0.6)
# ── Step-by-step decomposition ──
# Q / K / V — brief meanings
q_txt = Text("Q = Query", font_size=18, color=YELLOW)
k_txt = Text("K = Key", font_size=18, color=ORANGE)
v_txt = Text("V = Value", font_size=18, color=GREEN)
qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
qkv_labels.next_to(full_eq, DOWN, buff=0.5)
self.play(Write(qkv_labels))
q_desc = Text("\"what am I looking for?\"", font_size=11, color=YELLOW) \
.next_to(q_txt, DOWN, buff=0.10)
k_desc = Text("\"what do I have?\"", font_size=11, color=ORANGE) \
.next_to(k_txt, DOWN, buff=0.10)
v_desc = Text("\"what do I contribute?\"", font_size=11, color=GREEN) \
.next_to(v_txt, DOWN, buff=0.10)
self.play(Write(q_desc), Write(k_desc), Write(v_desc))
self.wait(2.0)
self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))
# Step-by-step decomposition — full formula stays visible above
steps = [
(MathTex(r"\text{(1) } S = QK^\top", font_size=28, color=YELLOW),
Text("score matrix — pairwise token similarity", font_size=13, color=GRAY)),
(MathTex(r"\text{(2) } S / \sqrt{d_k}", font_size=28, color=ORANGE),
Text("scale — prevents gradient explosion", font_size=13, color=GRAY)),
(MathTex(r"\text{(3) } \operatorname{softmax}(\cdots)", font_size=28, color=GREEN),
Text("normalize — each row sums to 1 (probability)", font_size=13, color=GRAY)),
(MathTex(r"\text{(4) } \cdots \cdot V", font_size=28, color=BLUE),
Text("weighted sum — aggregate values by attention", font_size=13, color=GRAY)),
(MathTex(r"\text{(1) } S = QK^\top", font_size=26, color=YELLOW),
Text("score matrix — pairwise token similarity", font_size=12, color=GRAY)),
(MathTex(r"\text{(2) } S \mathbin{/} \sqrt{d_k}", font_size=26, color=ORANGE),
Text("scale — prevents gradient explosion", font_size=12, color=GRAY)),
(MathTex(r"\text{(3) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right)", font_size=26, color=GREEN),
Text("normalize — each row sums to 1 (probability)", font_size=12, color=GRAY)),
(MathTex(r"\text{(4) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right) \cdot V", font_size=26, color=BLUE),
Text("weighted sum — aggregate values by attention", font_size=12, color=GRAY)),
]
step_group = VGroup()
step_descs = VGroup()
steps_mobj = VGroup()
for eq, desc in steps:
sg = VGroup(eq, desc).arrange(DOWN, buff=0.08)
sg = VGroup(eq, desc).arrange(DOWN, buff=0.06)
step_group.add(sg)
step_descs.add(desc)
step_group.arrange(DOWN, buff=0.25, aligned_edge=LEFT)
step_group.next_to(full_eq, DOWN, buff=0.7)
steps_mobj.add(eq)
step_group.arrange(DOWN, buff=0.22, aligned_edge=LEFT)
step_group.next_to(full_eq, DOWN, buff=0.6)
self.play(FadeOut(full_eq))
for sg in step_group:
self.play(Write(sg), run_time=0.3)
self.wait(2.0)
self.play(FadeOut(step_group))
self.wait(2.5)
self.play(FadeOut(step_group), FadeOut(full_eq))
# ═══════════════════════════════════════════════════
# 13. Attention score heatmap — concrete example