fix: bottom spacing, remove specs card, full formula first with no ellipsis in steps
This commit is contained in:
parent
eeaf0a5a16
commit
e7d736a3b0
117
transformer.py
117
transformer.py
|
|
@ -26,9 +26,9 @@ class Transformer(Scene):
|
||||||
|
|
||||||
# ── Layout ──
|
# ── Layout ──
|
||||||
inp = Text("x (hidden states)", font_size=15, color=GRAY)
|
inp = Text("x (hidden states)", font_size=15, color=GRAY)
|
||||||
inp.move_to(UP * 2.8)
|
inp.move_to(UP * 3.2)
|
||||||
|
|
||||||
y1 = 1.5
|
y1 = 1.9
|
||||||
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
|
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
|
||||||
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
|
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
|
||||||
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
|
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
|
||||||
|
|
@ -36,17 +36,17 @@ class Transformer(Scene):
|
||||||
k_grp.move_to(UP * y1)
|
k_grp.move_to(UP * y1)
|
||||||
v_grp.move_to(RIGHT * 3.0 + UP * y1)
|
v_grp.move_to(RIGHT * 3.0 + UP * y1)
|
||||||
|
|
||||||
y2 = 0.0
|
y2 = 0.4
|
||||||
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
|
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
|
||||||
repeat_grp.move_to(UP * y2)
|
repeat_grp.move_to(UP * y2)
|
||||||
|
|
||||||
y3 = -1.6
|
y3 = -1.2
|
||||||
sdpa_grp = mk(
|
sdpa_grp = mk(
|
||||||
"Scaled Dot-Product\nAttention Q·Kᵀ/√d", BLUE, 2.8, 0.74, 10
|
"Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
|
||||||
)
|
)
|
||||||
sdpa_grp.move_to(UP * y3)
|
sdpa_grp.move_to(UP * y3)
|
||||||
|
|
||||||
y4 = -3.0
|
y4 = -2.6
|
||||||
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
|
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
|
||||||
o_grp.move_to(UP * y4)
|
o_grp.move_to(UP * y4)
|
||||||
|
|
||||||
|
|
@ -172,6 +172,7 @@ class Transformer(Scene):
|
||||||
gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
|
gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
|
||||||
self.play(Create(gqa_h), Write(gqa_t))
|
self.play(Create(gqa_h), Write(gqa_t))
|
||||||
self.wait(1.8)
|
self.wait(1.8)
|
||||||
|
self.play(FadeOut(gqa_h), FadeOut(gqa_t))
|
||||||
|
|
||||||
# ── Repeat KV highlight ──
|
# ── Repeat KV highlight ──
|
||||||
kv_h = SurroundingRectangle(
|
kv_h = SurroundingRectangle(
|
||||||
|
|
@ -190,96 +191,68 @@ class Transformer(Scene):
|
||||||
*[FadeOut(g) for g in all_boxes],
|
*[FadeOut(g) for g in all_boxes],
|
||||||
FadeOut(all_lines),
|
FadeOut(all_lines),
|
||||||
FadeOut(kv_h), FadeOut(kv_t),
|
FadeOut(kv_h), FadeOut(kv_t),
|
||||||
FadeOut(gqa_h), FadeOut(gqa_t),
|
|
||||||
FadeOut(inp), FadeOut(out), FadeOut(title),
|
FadeOut(inp), FadeOut(out), FadeOut(title),
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Specs card ──
|
|
||||||
st = Text("Model Specifications", font_size=36, color=BLUE)
|
|
||||||
st.to_edge(UP, buff=0.5)
|
|
||||||
rows_data = [
|
|
||||||
("Parameters", "~1.0B"),
|
|
||||||
("Layers", "24 × DecoderBlock"),
|
|
||||||
("Hidden Dim", "1536"),
|
|
||||||
("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"),
|
|
||||||
("Head Dim", "64"),
|
|
||||||
("FFN Dim", "4608 (SwiGLU)"),
|
|
||||||
("Max Length", "2048"),
|
|
||||||
("Precision", "bfloat16"),
|
|
||||||
]
|
|
||||||
table = VGroup()
|
|
||||||
for label, value in rows_data:
|
|
||||||
row = VGroup(
|
|
||||||
Text(label + ":", font_size=15, color=GRAY),
|
|
||||||
Text(value, font_size=15, color=WHITE),
|
|
||||||
).arrange(RIGHT, buff=0.4, aligned_edge=LEFT)
|
|
||||||
table.add(row)
|
|
||||||
table.arrange(DOWN, buff=0.1, aligned_edge=LEFT)
|
|
||||||
table.next_to(st, DOWN, buff=0.4)
|
|
||||||
self.play(Write(st), Write(table))
|
|
||||||
self.wait(2)
|
|
||||||
self.play(FadeOut(st), FadeOut(table))
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════
|
||||||
# 12. Q / K / V — what do they mean?
|
# 12. Scaled Dot-Product Attention — full formula + breakdown
|
||||||
# ═══════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════
|
||||||
qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
|
qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
|
||||||
qkv_title.to_edge(UP, buff=0.35)
|
qkv_title.to_edge(UP, buff=0.35)
|
||||||
self.play(Write(qkv_title))
|
self.play(Write(qkv_title))
|
||||||
self.wait(0.2)
|
|
||||||
|
|
||||||
q_txt = Text("Q = Query", font_size=24, color=YELLOW)
|
# Full formula first, stays on screen
|
||||||
k_txt = Text("K = Key", font_size=24, color=ORANGE)
|
|
||||||
v_txt = Text("V = Value", font_size=24, color=GREEN)
|
|
||||||
qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
|
|
||||||
qkv_labels.next_to(qkv_title, DOWN, buff=0.6)
|
|
||||||
self.play(Write(qkv_labels))
|
|
||||||
|
|
||||||
q_desc = Text("\"what am I looking for?\"", font_size=12, color=YELLOW) \
|
|
||||||
.next_to(q_txt, DOWN, buff=0.12)
|
|
||||||
k_desc = Text("\"what do I have?\"", font_size=12, color=ORANGE) \
|
|
||||||
.next_to(k_txt, DOWN, buff=0.12)
|
|
||||||
v_desc = Text("\"what do I contribute?\"", font_size=12, color=GREEN) \
|
|
||||||
.next_to(v_txt, DOWN, buff=0.12)
|
|
||||||
self.play(Write(q_desc), Write(k_desc), Write(v_desc))
|
|
||||||
self.wait(2.0)
|
|
||||||
self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))
|
|
||||||
|
|
||||||
# ── Full formula ──
|
|
||||||
full_eq = MathTex(
|
full_eq = MathTex(
|
||||||
r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
|
r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
|
||||||
font_size=36, color=WHITE,
|
font_size=36, color=WHITE,
|
||||||
)
|
)
|
||||||
full_eq.next_to(qkv_title, DOWN, buff=0.6)
|
full_eq.next_to(qkv_title, DOWN, buff=0.5)
|
||||||
self.play(Write(full_eq))
|
self.play(Write(full_eq))
|
||||||
self.wait(1.2)
|
self.wait(0.6)
|
||||||
|
|
||||||
# ── Step-by-step decomposition ──
|
# Q / K / V — brief meanings
|
||||||
|
q_txt = Text("Q = Query", font_size=18, color=YELLOW)
|
||||||
|
k_txt = Text("K = Key", font_size=18, color=ORANGE)
|
||||||
|
v_txt = Text("V = Value", font_size=18, color=GREEN)
|
||||||
|
qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
|
||||||
|
qkv_labels.next_to(full_eq, DOWN, buff=0.5)
|
||||||
|
self.play(Write(qkv_labels))
|
||||||
|
|
||||||
|
q_desc = Text("\"what am I looking for?\"", font_size=11, color=YELLOW) \
|
||||||
|
.next_to(q_txt, DOWN, buff=0.10)
|
||||||
|
k_desc = Text("\"what do I have?\"", font_size=11, color=ORANGE) \
|
||||||
|
.next_to(k_txt, DOWN, buff=0.10)
|
||||||
|
v_desc = Text("\"what do I contribute?\"", font_size=11, color=GREEN) \
|
||||||
|
.next_to(v_txt, DOWN, buff=0.10)
|
||||||
|
self.play(Write(q_desc), Write(k_desc), Write(v_desc))
|
||||||
|
self.wait(2.0)
|
||||||
|
self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))
|
||||||
|
|
||||||
|
# Step-by-step decomposition — full formula stays visible above
|
||||||
steps = [
|
steps = [
|
||||||
(MathTex(r"\text{(1) } S = QK^\top", font_size=28, color=YELLOW),
|
(MathTex(r"\text{(1) } S = QK^\top", font_size=26, color=YELLOW),
|
||||||
Text("score matrix — pairwise token similarity", font_size=13, color=GRAY)),
|
Text("score matrix — pairwise token similarity", font_size=12, color=GRAY)),
|
||||||
(MathTex(r"\text{(2) } S / \sqrt{d_k}", font_size=28, color=ORANGE),
|
(MathTex(r"\text{(2) } S \mathbin{/} \sqrt{d_k}", font_size=26, color=ORANGE),
|
||||||
Text("scale — prevents gradient explosion", font_size=13, color=GRAY)),
|
Text("scale — prevents gradient explosion", font_size=12, color=GRAY)),
|
||||||
(MathTex(r"\text{(3) } \operatorname{softmax}(\cdots)", font_size=28, color=GREEN),
|
(MathTex(r"\text{(3) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right)", font_size=26, color=GREEN),
|
||||||
Text("normalize — each row sums to 1 (probability)", font_size=13, color=GRAY)),
|
Text("normalize — each row sums to 1 (probability)", font_size=12, color=GRAY)),
|
||||||
(MathTex(r"\text{(4) } \cdots \cdot V", font_size=28, color=BLUE),
|
(MathTex(r"\text{(4) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right) \cdot V", font_size=26, color=BLUE),
|
||||||
Text("weighted sum — aggregate values by attention", font_size=13, color=GRAY)),
|
Text("weighted sum — aggregate values by attention", font_size=12, color=GRAY)),
|
||||||
]
|
]
|
||||||
|
|
||||||
step_group = VGroup()
|
step_group = VGroup()
|
||||||
step_descs = VGroup()
|
steps_mobj = VGroup()
|
||||||
for eq, desc in steps:
|
for eq, desc in steps:
|
||||||
sg = VGroup(eq, desc).arrange(DOWN, buff=0.08)
|
sg = VGroup(eq, desc).arrange(DOWN, buff=0.06)
|
||||||
step_group.add(sg)
|
step_group.add(sg)
|
||||||
step_descs.add(desc)
|
steps_mobj.add(eq)
|
||||||
step_group.arrange(DOWN, buff=0.25, aligned_edge=LEFT)
|
step_group.arrange(DOWN, buff=0.22, aligned_edge=LEFT)
|
||||||
step_group.next_to(full_eq, DOWN, buff=0.7)
|
step_group.next_to(full_eq, DOWN, buff=0.6)
|
||||||
|
|
||||||
self.play(FadeOut(full_eq))
|
|
||||||
for sg in step_group:
|
for sg in step_group:
|
||||||
self.play(Write(sg), run_time=0.3)
|
self.play(Write(sg), run_time=0.3)
|
||||||
self.wait(2.0)
|
self.wait(2.5)
|
||||||
self.play(FadeOut(step_group))
|
self.play(FadeOut(step_group), FadeOut(full_eq))
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════
|
||||||
# 13. Attention score heatmap — concrete example
|
# 13. Attention score heatmap — concrete example
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue