fix: bottom spacing, remove specs card, full formula first with no ellipsis in steps

This commit is contained in:
ViperEkura 2026-05-07 09:28:07 +08:00
parent eeaf0a5a16
commit e7d736a3b0
1 changed files with 45 additions and 72 deletions

View File

@ -26,9 +26,9 @@ class Transformer(Scene):
# ── Layout ── # ── Layout ──
inp = Text("x (hidden states)", font_size=15, color=GRAY) inp = Text("x (hidden states)", font_size=15, color=GRAY)
inp.move_to(UP * 2.8) inp.move_to(UP * 3.2)
y1 = 1.5 y1 = 1.9
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW) q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
k_grp = mk("K Projection\n1536 → 4×64", YELLOW) k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
v_grp = mk("V Projection\n1536 → 4×64", YELLOW) v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
@ -36,17 +36,17 @@ class Transformer(Scene):
k_grp.move_to(UP * y1) k_grp.move_to(UP * y1)
v_grp.move_to(RIGHT * 3.0 + UP * y1) v_grp.move_to(RIGHT * 3.0 + UP * y1)
y2 = 0.0 y2 = 0.4
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10) repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
repeat_grp.move_to(UP * y2) repeat_grp.move_to(UP * y2)
y3 = -1.6 y3 = -1.2
sdpa_grp = mk( sdpa_grp = mk(
"Scaled Dot-Product\nAttention Q·K/√d", BLUE, 2.8, 0.74, 10 "Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
) )
sdpa_grp.move_to(UP * y3) sdpa_grp.move_to(UP * y3)
y4 = -3.0 y4 = -2.6
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10) o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
o_grp.move_to(UP * y4) o_grp.move_to(UP * y4)
@ -172,6 +172,7 @@ class Transformer(Scene):
gqa_t.next_to(gqa_h, RIGHT, buff=0.5) gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
self.play(Create(gqa_h), Write(gqa_t)) self.play(Create(gqa_h), Write(gqa_t))
self.wait(1.8) self.wait(1.8)
self.play(FadeOut(gqa_h), FadeOut(gqa_t))
# ── Repeat KV highlight ── # ── Repeat KV highlight ──
kv_h = SurroundingRectangle( kv_h = SurroundingRectangle(
@ -190,96 +191,68 @@ class Transformer(Scene):
*[FadeOut(g) for g in all_boxes], *[FadeOut(g) for g in all_boxes],
FadeOut(all_lines), FadeOut(all_lines),
FadeOut(kv_h), FadeOut(kv_t), FadeOut(kv_h), FadeOut(kv_t),
FadeOut(gqa_h), FadeOut(gqa_t),
FadeOut(inp), FadeOut(out), FadeOut(title), FadeOut(inp), FadeOut(out), FadeOut(title),
) )
# ── Specs card ──
st = Text("Model Specifications", font_size=36, color=BLUE)
st.to_edge(UP, buff=0.5)
rows_data = [
("Parameters", "~1.0B"),
("Layers", "24 × DecoderBlock"),
("Hidden Dim", "1536"),
("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"),
("Head Dim", "64"),
("FFN Dim", "4608 (SwiGLU)"),
("Max Length", "2048"),
("Precision", "bfloat16"),
]
table = VGroup()
for label, value in rows_data:
row = VGroup(
Text(label + ":", font_size=15, color=GRAY),
Text(value, font_size=15, color=WHITE),
).arrange(RIGHT, buff=0.4, aligned_edge=LEFT)
table.add(row)
table.arrange(DOWN, buff=0.1, aligned_edge=LEFT)
table.next_to(st, DOWN, buff=0.4)
self.play(Write(st), Write(table))
self.wait(2)
self.play(FadeOut(st), FadeOut(table))
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# 12. Q / K / V — what do they mean? # 12. Scaled Dot-Product Attention — full formula + breakdown
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE) qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
qkv_title.to_edge(UP, buff=0.35) qkv_title.to_edge(UP, buff=0.35)
self.play(Write(qkv_title)) self.play(Write(qkv_title))
self.wait(0.2)
q_txt = Text("Q = Query", font_size=24, color=YELLOW) # Full formula first, stays on screen
k_txt = Text("K = Key", font_size=24, color=ORANGE)
v_txt = Text("V = Value", font_size=24, color=GREEN)
qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
qkv_labels.next_to(qkv_title, DOWN, buff=0.6)
self.play(Write(qkv_labels))
q_desc = Text("\"what am I looking for?\"", font_size=12, color=YELLOW) \
.next_to(q_txt, DOWN, buff=0.12)
k_desc = Text("\"what do I have?\"", font_size=12, color=ORANGE) \
.next_to(k_txt, DOWN, buff=0.12)
v_desc = Text("\"what do I contribute?\"", font_size=12, color=GREEN) \
.next_to(v_txt, DOWN, buff=0.12)
self.play(Write(q_desc), Write(k_desc), Write(v_desc))
self.wait(2.0)
self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))
# ── Full formula ──
full_eq = MathTex( full_eq = MathTex(
r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V", r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
font_size=36, color=WHITE, font_size=36, color=WHITE,
) )
full_eq.next_to(qkv_title, DOWN, buff=0.6) full_eq.next_to(qkv_title, DOWN, buff=0.5)
self.play(Write(full_eq)) self.play(Write(full_eq))
self.wait(1.2) self.wait(0.6)
# ── Step-by-step decomposition ── # Q / K / V — brief meanings
q_txt = Text("Q = Query", font_size=18, color=YELLOW)
k_txt = Text("K = Key", font_size=18, color=ORANGE)
v_txt = Text("V = Value", font_size=18, color=GREEN)
qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
qkv_labels.next_to(full_eq, DOWN, buff=0.5)
self.play(Write(qkv_labels))
q_desc = Text("\"what am I looking for?\"", font_size=11, color=YELLOW) \
.next_to(q_txt, DOWN, buff=0.10)
k_desc = Text("\"what do I have?\"", font_size=11, color=ORANGE) \
.next_to(k_txt, DOWN, buff=0.10)
v_desc = Text("\"what do I contribute?\"", font_size=11, color=GREEN) \
.next_to(v_txt, DOWN, buff=0.10)
self.play(Write(q_desc), Write(k_desc), Write(v_desc))
self.wait(2.0)
self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))
# Step-by-step decomposition — full formula stays visible above
steps = [ steps = [
(MathTex(r"\text{(1) } S = QK^\top", font_size=28, color=YELLOW), (MathTex(r"\text{(1) } S = QK^\top", font_size=26, color=YELLOW),
Text("score matrix — pairwise token similarity", font_size=13, color=GRAY)), Text("score matrix — pairwise token similarity", font_size=12, color=GRAY)),
(MathTex(r"\text{(2) } S / \sqrt{d_k}", font_size=28, color=ORANGE), (MathTex(r"\text{(2) } S \mathbin{/} \sqrt{d_k}", font_size=26, color=ORANGE),
Text("scale — prevents gradient explosion", font_size=13, color=GRAY)), Text("scale — prevents gradient explosion", font_size=12, color=GRAY)),
(MathTex(r"\text{(3) } \operatorname{softmax}(\cdots)", font_size=28, color=GREEN), (MathTex(r"\text{(3) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right)", font_size=26, color=GREEN),
Text("normalize — each row sums to 1 (probability)", font_size=13, color=GRAY)), Text("normalize — each row sums to 1 (probability)", font_size=12, color=GRAY)),
(MathTex(r"\text{(4) } \cdots \cdot V", font_size=28, color=BLUE), (MathTex(r"\text{(4) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right) \cdot V", font_size=26, color=BLUE),
Text("weighted sum — aggregate values by attention", font_size=13, color=GRAY)), Text("weighted sum — aggregate values by attention", font_size=12, color=GRAY)),
] ]
step_group = VGroup() step_group = VGroup()
step_descs = VGroup() steps_mobj = VGroup()
for eq, desc in steps: for eq, desc in steps:
sg = VGroup(eq, desc).arrange(DOWN, buff=0.08) sg = VGroup(eq, desc).arrange(DOWN, buff=0.06)
step_group.add(sg) step_group.add(sg)
step_descs.add(desc) steps_mobj.add(eq)
step_group.arrange(DOWN, buff=0.25, aligned_edge=LEFT) step_group.arrange(DOWN, buff=0.22, aligned_edge=LEFT)
step_group.next_to(full_eq, DOWN, buff=0.7) step_group.next_to(full_eq, DOWN, buff=0.6)
self.play(FadeOut(full_eq))
for sg in step_group: for sg in step_group:
self.play(Write(sg), run_time=0.3) self.play(Write(sg), run_time=0.3)
self.wait(2.0) self.wait(2.5)
self.play(FadeOut(step_group)) self.play(FadeOut(step_group), FadeOut(full_eq))
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# 13. Attention score heatmap — concrete example # 13. Attention score heatmap — concrete example