feat: SDPA formula breakdown + attention score heatmap with per-cell causal mask
This commit is contained in:
parent
ba100c19d1
commit
eeaf0a5a16
175
transformer.py
175
transformer.py
|
|
@ -220,6 +220,181 @@ class Transformer(Scene):
|
||||||
self.wait(2)
|
self.wait(2)
|
||||||
self.play(FadeOut(st), FadeOut(table))
|
self.play(FadeOut(st), FadeOut(table))
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
# 12. Q / K / V — what do they mean?
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
|
||||||
|
qkv_title.to_edge(UP, buff=0.35)
|
||||||
|
self.play(Write(qkv_title))
|
||||||
|
self.wait(0.2)
|
||||||
|
|
||||||
|
q_txt = Text("Q = Query", font_size=24, color=YELLOW)
|
||||||
|
k_txt = Text("K = Key", font_size=24, color=ORANGE)
|
||||||
|
v_txt = Text("V = Value", font_size=24, color=GREEN)
|
||||||
|
qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
|
||||||
|
qkv_labels.next_to(qkv_title, DOWN, buff=0.6)
|
||||||
|
self.play(Write(qkv_labels))
|
||||||
|
|
||||||
|
q_desc = Text("\"what am I looking for?\"", font_size=12, color=YELLOW) \
|
||||||
|
.next_to(q_txt, DOWN, buff=0.12)
|
||||||
|
k_desc = Text("\"what do I have?\"", font_size=12, color=ORANGE) \
|
||||||
|
.next_to(k_txt, DOWN, buff=0.12)
|
||||||
|
v_desc = Text("\"what do I contribute?\"", font_size=12, color=GREEN) \
|
||||||
|
.next_to(v_txt, DOWN, buff=0.12)
|
||||||
|
self.play(Write(q_desc), Write(k_desc), Write(v_desc))
|
||||||
|
self.wait(2.0)
|
||||||
|
self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))
|
||||||
|
|
||||||
|
# ── Full formula ──
|
||||||
|
full_eq = MathTex(
|
||||||
|
r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
|
||||||
|
font_size=36, color=WHITE,
|
||||||
|
)
|
||||||
|
full_eq.next_to(qkv_title, DOWN, buff=0.6)
|
||||||
|
self.play(Write(full_eq))
|
||||||
|
self.wait(1.2)
|
||||||
|
|
||||||
|
# ── Step-by-step decomposition ──
|
||||||
|
steps = [
|
||||||
|
(MathTex(r"\text{(1) } S = QK^\top", font_size=28, color=YELLOW),
|
||||||
|
Text("score matrix — pairwise token similarity", font_size=13, color=GRAY)),
|
||||||
|
(MathTex(r"\text{(2) } S / \sqrt{d_k}", font_size=28, color=ORANGE),
|
||||||
|
Text("scale — prevents gradient explosion", font_size=13, color=GRAY)),
|
||||||
|
(MathTex(r"\text{(3) } \operatorname{softmax}(\cdots)", font_size=28, color=GREEN),
|
||||||
|
Text("normalize — each row sums to 1 (probability)", font_size=13, color=GRAY)),
|
||||||
|
(MathTex(r"\text{(4) } \cdots \cdot V", font_size=28, color=BLUE),
|
||||||
|
Text("weighted sum — aggregate values by attention", font_size=13, color=GRAY)),
|
||||||
|
]
|
||||||
|
|
||||||
|
step_group = VGroup()
|
||||||
|
step_descs = VGroup()
|
||||||
|
for eq, desc in steps:
|
||||||
|
sg = VGroup(eq, desc).arrange(DOWN, buff=0.08)
|
||||||
|
step_group.add(sg)
|
||||||
|
step_descs.add(desc)
|
||||||
|
step_group.arrange(DOWN, buff=0.25, aligned_edge=LEFT)
|
||||||
|
step_group.next_to(full_eq, DOWN, buff=0.7)
|
||||||
|
|
||||||
|
self.play(FadeOut(full_eq))
|
||||||
|
for sg in step_group:
|
||||||
|
self.play(Write(sg), run_time=0.3)
|
||||||
|
self.wait(2.0)
|
||||||
|
self.play(FadeOut(step_group))
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
# 13. Attention score heatmap — concrete example
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
hm_title = Text("Attention Score Heatmap", font_size=34, color=BLUE)
|
||||||
|
hm_title.to_edge(UP, buff=0.35)
|
||||||
|
hm_sub = Text("\"The cat sat on the mat\" — causal, per-token attention weights",
|
||||||
|
font_size=14, color=GRAY).next_to(hm_title, DOWN, buff=0.12)
|
||||||
|
self.play(FadeOut(qkv_title), Write(hm_title), Write(hm_sub))
|
||||||
|
|
||||||
|
tokens = ["<s>", "The", "cat", "sat", "on", "the", "mat"]
|
||||||
|
n = len(tokens)
|
||||||
|
cell_size = 0.52
|
||||||
|
gap = 0.04
|
||||||
|
grid_high = n * cell_size + (n - 1) * gap
|
||||||
|
grid_left = -grid_high / 2
|
||||||
|
grid_top = 1.4
|
||||||
|
|
||||||
|
# attention weights (after softmax + causal mask)
|
||||||
|
weights = [
|
||||||
|
[1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
|
||||||
|
[0.05, 0.95, 0.00, 0.00, 0.00, 0.00, 0.00],
|
||||||
|
[0.02, 0.20, 0.78, 0.00, 0.00, 0.00, 0.00],
|
||||||
|
[0.01, 0.05, 0.40, 0.54, 0.00, 0.00, 0.00],
|
||||||
|
[0.00, 0.02, 0.07, 0.35, 0.56, 0.00, 0.00],
|
||||||
|
[0.00, 0.01, 0.03, 0.10, 0.30, 0.56, 0.00],
|
||||||
|
[0.00, 0.00, 0.01, 0.05, 0.12, 0.35, 0.47],
|
||||||
|
]
|
||||||
|
|
||||||
|
cells = VGroup()
|
||||||
|
for i in range(n):
|
||||||
|
for j in range(n):
|
||||||
|
w = weights[i][j]
|
||||||
|
if j > i:
|
||||||
|
color = DARK_GRAY
|
||||||
|
fill_op = 0.15
|
||||||
|
elif w < 0.001:
|
||||||
|
color = DARKER_GRAY
|
||||||
|
fill_op = 0.2
|
||||||
|
else:
|
||||||
|
color = interpolate_color(BLUE, RED, w)
|
||||||
|
fill_op = 0.75
|
||||||
|
sq = Square(
|
||||||
|
side_length=cell_size, fill_color=color,
|
||||||
|
fill_opacity=fill_op, stroke_width=0.5,
|
||||||
|
stroke_color=GRAY,
|
||||||
|
)
|
||||||
|
x = grid_left + j * (cell_size + gap) + cell_size / 2
|
||||||
|
y = grid_top - i * (cell_size + gap) - cell_size / 2
|
||||||
|
sq.move_to([x, y, 0])
|
||||||
|
cells.add(sq)
|
||||||
|
self.play(FadeIn(sq, scale=0.6), run_time=0.015)
|
||||||
|
|
||||||
|
# row labels (query) on the left
|
||||||
|
row_lbls = VGroup()
|
||||||
|
for i, tok in enumerate(tokens):
|
||||||
|
lbl = Text(tok, font_size=12, color=GRAY)
|
||||||
|
y = grid_top - i * (cell_size + gap) - cell_size / 2
|
||||||
|
lbl.next_to([grid_left - 0.15, y, 0], LEFT, buff=0.08)
|
||||||
|
row_lbls.add(lbl)
|
||||||
|
q_label = Text("Q", font_size=11, color=WHITE, weight=BOLD)
|
||||||
|
q_label.move_to(row_lbls[0].get_left() + LEFT * 0.3).shift(UP * 0.15)
|
||||||
|
self.play(*[Write(l) for l in row_lbls], Write(q_label))
|
||||||
|
|
||||||
|
# column labels (key) on top
|
||||||
|
col_lbls = VGroup()
|
||||||
|
for j, tok in enumerate(tokens):
|
||||||
|
lbl = Text(tok, font_size=9, color=GRAY).rotate(PI / 6)
|
||||||
|
x = grid_left + j * (cell_size + gap) + cell_size / 2
|
||||||
|
lbl.next_to([x, grid_top + 0.06, 0], UP, buff=0.04)
|
||||||
|
col_lbls.add(lbl)
|
||||||
|
k_label = Text("K", font_size=11, color=WHITE, weight=BOLD)
|
||||||
|
k_label.next_to(col_lbls[0], UP, buff=0.06)
|
||||||
|
self.play(*[Write(l) for l in col_lbls], Write(k_label))
|
||||||
|
self.wait(1.0)
|
||||||
|
|
||||||
|
# causal mask — per-cell red overlay aligned to grid
|
||||||
|
mask_overlays = VGroup()
|
||||||
|
for i in range(n):
|
||||||
|
for j in range(n):
|
||||||
|
if j > i:
|
||||||
|
x = grid_left + j * (cell_size + gap) + cell_size / 2
|
||||||
|
y = grid_top - i * (cell_size + gap) - cell_size / 2
|
||||||
|
sq = Square(
|
||||||
|
side_length=cell_size, fill_color=RED,
|
||||||
|
fill_opacity=0.10, stroke_width=0.5,
|
||||||
|
stroke_color=RED, stroke_opacity=0.3,
|
||||||
|
)
|
||||||
|
sq.move_to([x, y, 0])
|
||||||
|
mask_overlays.add(sq)
|
||||||
|
causal_txt = Text("causal mask\n(future tokens hidden)", font_size=11, color=RED) \
|
||||||
|
.next_to(cells[6], UP, buff=0.25).align_to(cells[6], RIGHT)
|
||||||
|
self.play(FadeIn(mask_overlays), Write(causal_txt))
|
||||||
|
self.wait(1.5)
|
||||||
|
self.play(FadeOut(mask_overlays), FadeOut(causal_txt))
|
||||||
|
|
||||||
|
# highlight key patterns
|
||||||
|
h1 = SurroundingRectangle(cells[2 * n + 1], color=ORANGE, stroke_width=2, buff=0.04)
|
||||||
|
h2 = SurroundingRectangle(cells[3 * n + 2], color=ORANGE, stroke_width=2, buff=0.04)
|
||||||
|
h3 = SurroundingRectangle(cells[4 * n + 3], color=ORANGE, stroke_width=2, buff=0.04)
|
||||||
|
h4 = SurroundingRectangle(cells[5 * n + 4], color=ORANGE, stroke_width=2, buff=0.04)
|
||||||
|
h5 = SurroundingRectangle(cells[6 * n + 5], color=ORANGE, stroke_width=2, buff=0.04)
|
||||||
|
hl_text = Text("previous token attends to next\n(causal sequence learning)", font_size=11, color=ORANGE) \
|
||||||
|
.next_to(cells[(n - 1) * n + (n - 1)], RIGHT, buff=0.8)
|
||||||
|
self.play(Create(h1), Create(h2), Create(h3), Create(h4), Create(h5), Write(hl_text))
|
||||||
|
self.wait(2.0)
|
||||||
|
self.play(FadeOut(h1), FadeOut(h2), FadeOut(h3), FadeOut(h4), FadeOut(h5), FadeOut(hl_text))
|
||||||
|
|
||||||
|
# fade all heatmap
|
||||||
|
self.play(
|
||||||
|
FadeOut(hm_title), FadeOut(hm_sub),
|
||||||
|
FadeOut(cells), FadeOut(row_lbls), FadeOut(col_lbls),
|
||||||
|
FadeOut(q_label), FadeOut(k_label),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def orth_line(start, end, color=GRAY):
|
def orth_line(start, end, color=GRAY):
|
||||||
"""Create an L-shaped orthogonal line from start to end."""
|
"""Create an L-shaped orthogonal line from start to end."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue