video-promo/transformer.py

405 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""AstrAI promo: Transformer GQA attention animation.
Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow lines:
Input → Q/K/V Projections → Repeat KV → SDPA → O Projection → Output
"""
from manim import *
import numpy as np
class Transformer(Scene):
    """Promo scene built around Grouped-Query Attention (GQA).

    Four segments are played back to back:

    1. a block diagram of the GQA data flow drawn with orthogonal connector
       lines, followed by RoPE, GQA-ratio and repeat-KV call-outs;
    2. a "Model Specifications" card;
    3. the meaning of Q/K/V and the scaled dot-product attention formula,
       broken into four steps;
    4. a causal attention-weight heatmap for a seven-token example.
    """

    def construct(self):
        """Play the four segments in order."""
        self._show_gqa_diagram()
        self._show_specs_card()
        sdpa_title = self._show_sdpa_formula()
        # The SDPA title stays on screen and is cross-faded by the heatmap.
        self._show_attention_heatmap(sdpa_title)

    # ── Segment 1: GQA block diagram ─────────────────────────────────────
    def _show_gqa_diagram(self):
        """Draw the GQA block diagram, annotate it, then clear the screen."""
        title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
        title.to_edge(UP, buff=0.35)
        self.play(Write(title))

        def mk(name, color, w=2.6, h=0.72, fs=10):
            """Return a tinted rectangle and its label grouped as one VGroup."""
            box = Rectangle(
                width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
            )
            lbl = Text(name, font_size=fs, color=color)
            return VGroup(box, lbl)

        def wire(a, b):
            """Return a straight connector segment in the diagram's line style."""
            return Line(a, b, color=GRAY, stroke_width=1.5)

        def at(x, y):
            """Return the point (x, y) on the z = 0 plane."""
            return np.array([x, y, 0])

        # ── Layout: one row per stage, top to bottom ──
        inp = Text("x (hidden states)", font_size=15, color=GRAY)
        inp.move_to(UP * 2.8)

        y1 = 1.5
        q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
        k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
        v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
        q_grp.move_to(LEFT * 3.0 + UP * y1)
        k_grp.move_to(UP * y1)
        v_grp.move_to(RIGHT * 3.0 + UP * y1)

        y2 = 0.0
        repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
        repeat_grp.move_to(UP * y2)

        y3 = -1.6
        sdpa_grp = mk(
            "Scaled Dot-Product\nAttention Q·Kᵀ/√d", BLUE, 2.8, 0.74, 10
        )
        sdpa_grp.move_to(UP * y3)

        y4 = -3.0
        o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
        o_grp.move_to(UP * y4)

        out = Text("x' (hidden states)", font_size=15, color=GRAY)
        out.next_to(o_grp, DOWN, buff=0.4)

        # ── Animate boxes ──
        self.play(Write(inp))
        all_boxes = [q_grp, k_grp, v_grp, repeat_grp, sdpa_grp, o_grp]
        for g in all_boxes:
            self.play(FadeIn(g, shift=UP * 0.1), run_time=0.2)

        # ── Input trunk → horizontal bus → drops into Q/K/V from above ──
        projections = (q_grp, k_grp, v_grp)
        bus_y = q_grp.get_top()[1] + 0.35
        trunk = wire(inp.get_bottom(), at(0, bus_y))
        self.play(Create(trunk), run_time=0.15)

        q_x, k_x, v_x = (g.get_top()[0] for g in projections)
        branch_left = wire(at(q_x, bus_y), at(k_x, bus_y))
        branch_right = wire(at(k_x, bus_y), at(v_x, bus_y))
        self.play(Create(branch_left), Create(branch_right), run_time=0.2)

        drops = [wire(at(g.get_top()[0], bus_y), g.get_top()) for g in projections]
        for drop in drops:
            self.play(Create(drop), run_time=0.12)
        input_lines = VGroup(trunk, branch_left, branch_right, *drops)

        # ── K/V → Repeat KV: both drop to a junction row, merge, enter from above ──
        kv_junc_y = repeat_grp.get_top()[1] + 0.3
        k_bx = k_grp.get_bottom()[0]
        v_bx = v_grp.get_bottom()[0]
        kv_lines = VGroup(
            wire(k_grp.get_bottom(), at(k_bx, kv_junc_y)),
            wire(v_grp.get_bottom(), at(v_bx, kv_junc_y)),
            wire(at(v_bx, kv_junc_y), at(k_bx, kv_junc_y)),
            wire(at(k_bx, kv_junc_y), repeat_grp.get_top()),
        )
        self.play(Create(kv_lines), run_time=0.3)

        # ── Q → SDPA: bypasses Repeat KV and enters SDPA from above ──
        qs_junc_y = sdpa_grp.get_top()[1] + 0.3
        line_qs = VMobject(color=GRAY, stroke_width=1.5)
        line_qs.set_points_as_corners([
            q_grp.get_bottom(),
            at(q_grp.get_bottom()[0], qs_junc_y),
            at(sdpa_grp.get_top()[0], qs_junc_y),
            sdpa_grp.get_top(),
        ])

        # ── Remaining stage-to-stage connectors ──
        line_rs = orth_line(repeat_grp.get_bottom(), sdpa_grp.get_top(), GRAY)
        line_so = orth_line(sdpa_grp.get_bottom(), o_grp.get_top(), GRAY)
        line_oo = orth_line(o_grp.get_bottom(), out.get_top(), GRAY)
        for connector in (line_qs, line_rs, line_so, line_oo):
            self.play(Create(connector), run_time=0.15)

        self.play(Write(out))
        self.wait(0.4)

        all_lines = VGroup(
            input_lines, kv_lines, line_qs,
            line_rs, line_so, line_oo,
        )

        # ── RoPE highlight ──
        rope_q = SurroundingRectangle(q_grp, color=TEAL, buff=0.12)
        rope_k = SurroundingRectangle(k_grp, color=TEAL, buff=0.12)
        rope_t = Text(
            "RoPE: rotary position encoding\napplied to Q and K",
            font_size=13, color=TEAL,
        )
        rope_t.next_to(VGroup(rope_q, rope_k), UP, buff=0.25)
        self.play(Create(rope_q), Create(rope_k), Write(rope_t))
        self.wait(1.5)
        self.play(FadeOut(rope_q), FadeOut(rope_k), FadeOut(rope_t))

        # ── GQA ratio highlight ──
        gqa_h = SurroundingRectangle(
            VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
        )
        gqa_t = Text(
            "GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
            font_size=13, color=YELLOW,
        )
        gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
        self.play(Create(gqa_h), Write(gqa_t))
        self.wait(1.8)
        # Clear this call-out before the next one: the repeat-KV caption is
        # anchored to the right of the same row and would otherwise be drawn
        # on top of the GQA caption.
        self.play(FadeOut(gqa_h), FadeOut(gqa_t))

        # ── Repeat KV highlight ──
        kv_h = SurroundingRectangle(
            VGroup(k_grp, v_grp), color=GREEN, buff=0.12
        )
        kv_t = Text(
            "repeat_kv(): broadcast\n4 heads → 24 heads",
            font_size=12, color=GREEN,
        )
        kv_t.next_to(kv_h, RIGHT, buff=0.5)
        self.play(Create(kv_h), Write(kv_t))
        self.wait(1.5)

        # ── Fade everything still on screen ──
        self.play(
            *[FadeOut(g) for g in all_boxes],
            FadeOut(all_lines),
            FadeOut(kv_h), FadeOut(kv_t),
            FadeOut(inp), FadeOut(out), FadeOut(title),
        )

    # ── Segment 2: specification card ────────────────────────────────────
    def _show_specs_card(self):
        """Show the "Model Specifications" label/value list, then remove it."""
        st = Text("Model Specifications", font_size=36, color=BLUE)
        st.to_edge(UP, buff=0.5)

        rows_data = [
            ("Parameters", "~1.0B"),
            ("Layers", "24 × DecoderBlock"),
            ("Hidden Dim", "1536"),
            ("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"),
            ("Head Dim", "64"),
            ("FFN Dim", "4608 (SwiGLU)"),
            ("Max Length", "2048"),
            ("Precision", "bfloat16"),
        ]

        table = VGroup()
        for label, value in rows_data:
            # No aligned_edge here: an edge on the same axis as the arrange
            # direction (LEFT with RIGHT) anchors the value to the label's
            # centre instead of its right side, so values overlapped labels.
            row = VGroup(
                Text(label + ":", font_size=15, color=GRAY),
                Text(value, font_size=15, color=WHITE),
            ).arrange(RIGHT, buff=0.4)
            table.add(row)
        table.arrange(DOWN, buff=0.1, aligned_edge=LEFT)
        table.next_to(st, DOWN, buff=0.4)

        self.play(Write(st), Write(table))
        self.wait(2)
        self.play(FadeOut(st), FadeOut(table))

    # ── Segment 3: Q / K / V and the SDPA formula ────────────────────────
    def _show_sdpa_formula(self):
        """Explain Q/K/V, show the attention formula and its four steps.

        Returns:
            The section title mobject, which is still on screen when this
            method returns so the caller can fade it with the next segment.
        """
        qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
        qkv_title.to_edge(UP, buff=0.35)
        self.play(Write(qkv_title))
        self.wait(0.2)

        q_txt = Text("Q = Query", font_size=24, color=YELLOW)
        k_txt = Text("K = Key", font_size=24, color=ORANGE)
        v_txt = Text("V = Value", font_size=24, color=GREEN)
        qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
        qkv_labels.next_to(qkv_title, DOWN, buff=0.6)
        self.play(Write(qkv_labels))

        q_desc = Text('"what am I looking for?"', font_size=12, color=YELLOW)
        q_desc.next_to(q_txt, DOWN, buff=0.12)
        k_desc = Text('"what do I have?"', font_size=12, color=ORANGE)
        k_desc.next_to(k_txt, DOWN, buff=0.12)
        v_desc = Text('"what do I contribute?"', font_size=12, color=GREEN)
        v_desc.next_to(v_txt, DOWN, buff=0.12)
        self.play(Write(q_desc), Write(k_desc), Write(v_desc))
        self.wait(2.0)
        self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))

        # ── Full formula ──
        full_eq = MathTex(
            r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
            font_size=36, color=WHITE,
        )
        full_eq.next_to(qkv_title, DOWN, buff=0.6)
        self.play(Write(full_eq))
        self.wait(1.2)

        # ── Step-by-step decomposition: (equation, one-line explanation) ──
        steps = [
            (MathTex(r"\text{(1) } S = QK^\top", font_size=28, color=YELLOW),
             Text("score matrix — pairwise token similarity", font_size=13, color=GRAY)),
            # Dividing by sqrt(d_k) keeps the logits small enough that softmax
            # does not saturate; it is not a guard against exploding gradients.
            (MathTex(r"\text{(2) } S / \sqrt{d_k}", font_size=28, color=ORANGE),
             Text("scale — keeps softmax from saturating", font_size=13, color=GRAY)),
            (MathTex(r"\text{(3) } \operatorname{softmax}(\cdots)", font_size=28, color=GREEN),
             Text("normalize — each row sums to 1 (probability)", font_size=13, color=GRAY)),
            (MathTex(r"\text{(4) } \cdots \cdot V", font_size=28, color=BLUE),
             Text("weighted sum — aggregate values by attention", font_size=13, color=GRAY)),
        ]
        step_group = VGroup(
            *[VGroup(eq, desc).arrange(DOWN, buff=0.08) for eq, desc in steps]
        )
        step_group.arrange(DOWN, buff=0.25, aligned_edge=LEFT)
        # Positioned relative to the formula before the formula is faded out.
        step_group.next_to(full_eq, DOWN, buff=0.7)

        self.play(FadeOut(full_eq))
        for sg in step_group:
            self.play(Write(sg), run_time=0.3)
        self.wait(2.0)
        self.play(FadeOut(step_group))

        return qkv_title

    # ── Segment 4: attention heatmap ─────────────────────────────────────
    def _show_attention_heatmap(self, previous_title):
        """Build and annotate a causal attention heatmap for a 7-token example.

        Args:
            previous_title: mobject left on screen by the previous segment;
                it is faded out while this segment's title is written.
        """
        hm_title = Text("Attention Score Heatmap", font_size=34, color=BLUE)
        hm_title.to_edge(UP, buff=0.35)
        hm_sub = Text(
            '"The cat sat on the mat" — causal, per-token attention weights',
            font_size=14, color=GRAY,
        ).next_to(hm_title, DOWN, buff=0.12)
        self.play(FadeOut(previous_title), Write(hm_title), Write(hm_sub))

        tokens = ["<s>", "The", "cat", "sat", "on", "the", "mat"]
        n = len(tokens)
        cell_size = 0.52
        gap = 0.04
        pitch = cell_size + gap  # centre-to-centre distance of adjacent cells
        grid_span = n * cell_size + (n - 1) * gap
        grid_left = -grid_span / 2  # grid is centred horizontally
        grid_top = 1.4

        def col_x(col):
            """Return the x coordinate of the centre of grid column ``col``."""
            return grid_left + col * pitch + cell_size / 2

        def row_y(row):
            """Return the y coordinate of the centre of grid row ``row``."""
            return grid_top - row * pitch - cell_size / 2

        # Attention weights (after softmax + causal mask); row = query token,
        # column = key token.
        weights = [
            [1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
            [0.05, 0.95, 0.00, 0.00, 0.00, 0.00, 0.00],
            [0.02, 0.20, 0.78, 0.00, 0.00, 0.00, 0.00],
            [0.01, 0.05, 0.40, 0.54, 0.00, 0.00, 0.00],
            [0.00, 0.02, 0.07, 0.35, 0.56, 0.00, 0.00],
            [0.00, 0.01, 0.03, 0.10, 0.30, 0.56, 0.00],
            [0.00, 0.00, 0.01, 0.05, 0.12, 0.35, 0.47],
        ]

        # Cells are stored row-major, so cell (i, j) is cells[i * n + j].
        cells = VGroup()
        for i, weight_row in enumerate(weights):
            for j, w in enumerate(weight_row):
                if j > i:
                    # Above the diagonal: masked (future) position.
                    fill, fill_op = DARK_GRAY, 0.15
                elif w < 0.001:
                    # Visible position with effectively zero weight.
                    fill, fill_op = DARKER_GRAY, 0.2
                else:
                    fill, fill_op = interpolate_color(BLUE, RED, w), 0.75
                sq = Square(
                    side_length=cell_size, fill_color=fill,
                    fill_opacity=fill_op, stroke_width=0.5,
                    stroke_color=GRAY,
                )
                sq.move_to([col_x(j), row_y(i), 0])
                cells.add(sq)
                self.play(FadeIn(sq, scale=0.6), run_time=0.015)

        # Row labels (query tokens) on the left.
        row_lbls = VGroup()
        for i, tok in enumerate(tokens):
            lbl = Text(tok, font_size=12, color=GRAY)
            lbl.next_to([grid_left - 0.15, row_y(i), 0], LEFT, buff=0.08)
            row_lbls.add(lbl)
        q_label = Text("Q", font_size=11, color=WHITE, weight=BOLD)
        q_label.move_to(row_lbls[0].get_left() + LEFT * 0.3).shift(UP * 0.15)
        self.play(*[Write(lbl) for lbl in row_lbls], Write(q_label))

        # Column labels (key tokens) on top, tilted by 30 degrees.
        col_lbls = VGroup()
        for j, tok in enumerate(tokens):
            lbl = Text(tok, font_size=9, color=GRAY).rotate(PI / 6)
            lbl.next_to([col_x(j), grid_top + 0.06, 0], UP, buff=0.04)
            col_lbls.add(lbl)
        k_label = Text("K", font_size=11, color=WHITE, weight=BOLD)
        k_label.next_to(col_lbls[0], UP, buff=0.06)
        self.play(*[Write(lbl) for lbl in col_lbls], Write(k_label))
        self.wait(1.0)

        # Causal mask: a faint red overlay on every cell above the diagonal.
        mask_overlays = VGroup()
        for i in range(n):
            for j in range(i + 1, n):
                overlay = Square(
                    side_length=cell_size, fill_color=RED,
                    fill_opacity=0.10, stroke_width=0.5,
                    stroke_color=RED, stroke_opacity=0.3,
                )
                overlay.move_to([col_x(j), row_y(i), 0])
                mask_overlays.add(overlay)
        top_right = cells[n - 1]  # row 0, last column
        causal_txt = Text("causal mask\n(future tokens hidden)", font_size=11, color=RED)
        causal_txt.next_to(top_right, UP, buff=0.25).align_to(top_right, RIGHT)
        self.play(FadeIn(mask_overlays), Write(causal_txt))
        self.wait(1.5)
        self.play(FadeOut(mask_overlays), FadeOut(causal_txt))

        # Highlight the sub-diagonal from row 2 down: query row i, key column
        # i - 1, i.e. each token's weight on the token right before it.
        marks = [
            SurroundingRectangle(
                cells[i * n + (i - 1)], color=ORANGE, stroke_width=2, buff=0.04
            )
            for i in range(2, n)
        ]
        hl_text = Text(
            "each token attends to its predecessor\n(causal sequence learning)",
            font_size=11, color=ORANGE,
        ).next_to(cells[n * n - 1], RIGHT, buff=0.8)
        self.play(*[Create(m) for m in marks], Write(hl_text))
        self.wait(2.0)
        self.play(*[FadeOut(m) for m in marks], FadeOut(hl_text))

        # Fade the whole heatmap.
        self.play(
            FadeOut(hm_title), FadeOut(hm_sub),
            FadeOut(cells), FadeOut(row_lbls), FadeOut(col_lbls),
            FadeOut(q_label), FadeOut(k_label),
        )
def orth_line(start, end, color=GRAY):
    """Create an orthogonal (axis-aligned) connector from ``start`` to ``end``.

    The path runs vertically from ``start`` to the height of ``end`` and then
    horizontally into ``end``. When the two points already share an x or a y
    coordinate the elbow coincides with an endpoint; that duplicate corner is
    dropped so the path has no zero-length segment.

    Args:
        start: 3-component point where the connector begins.
        end: 3-component point where the connector ends.
        color: stroke colour of the connector.

    Returns:
        A ``VMobject`` with stroke width 1.5 whose corners are ``start``, the
        elbow (if distinct from both endpoints) and ``end``.
    """
    start = np.asarray(start, dtype=float)
    end = np.asarray(end, dtype=float)
    # Elbow sits directly below/above start, level with end, on the z = 0 plane.
    elbow = np.array([start[0], end[1], 0.0])

    corners = [start]
    for point in (elbow, end):
        if not np.allclose(point, corners[-1]):
            corners.append(point)
    if len(corners) == 1:
        # start and end coincide; a path still needs two corner points.
        corners.append(end)

    path = VMobject(color=color, stroke_width=1.5)
    path.set_points_as_corners(corners)
    return path