video-promo/transformer.py

230 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""AstrAI promo: Transformer GQA attention animation.
Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow lines:
Input → Q/K/V Projections → Repeat KV → SDPA → O Projection → Output
"""
from manim import *
import numpy as np
class Transformer(Scene):
    """Animate a Grouped-Query Attention (GQA) block diagram and a specs card.

    The scene draws labelled boxes for the Q/K/V projections, Repeat KV,
    scaled dot-product attention and the O projection, wires them together
    with axis-aligned connector lines, plays three highlight callouts
    (RoPE, GQA ratio, Repeat KV) and finishes with a "Model Specifications"
    list.
    """

    def construct(self):
        """Build every mobject and play the animation from start to finish."""
        title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
        title.to_edge(UP, buff=0.35)
        self.play(Write(title))

        # ── Helper: box ──
        def mk(name, color, w=2.6, h=0.72, fs=10):
            """Return a VGroup holding a lightly filled rectangle and its label."""
            box = Rectangle(
                width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
            )
            lbl = Text(name, font_size=fs, color=color)
            return VGroup(box, lbl)

        # ── Helper: connector segment ──
        def wire(start, end):
            """Return a thin gray straight line from ``start`` to ``end``."""
            return Line(start, end, color=GRAY, stroke_width=1.5)

        # ── Helper: point in the z=0 plane ──
        def at(x, y):
            """Return the 3-component point ``(x, y, 0)``."""
            return np.array([x, y, 0])

        # ── Layout ──
        inp = Text("x (hidden states)", font_size=15, color=GRAY)
        inp.move_to(UP * 2.8)

        y1 = 1.5
        q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
        k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
        v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
        q_grp.move_to(LEFT * 3.0 + UP * y1)
        k_grp.move_to(UP * y1)
        v_grp.move_to(RIGHT * 3.0 + UP * y1)

        y2 = 0.0
        repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
        repeat_grp.move_to(UP * y2)

        y3 = -1.6
        sdpa_grp = mk(
            "Scaled Dot-Product\nAttention Q·Kᵀ/√d", BLUE, 2.8, 0.74, 10
        )
        sdpa_grp.move_to(UP * y3)

        y4 = -3.0
        o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
        o_grp.move_to(UP * y4)

        out = Text("x' (hidden states)", font_size=15, color=GRAY)
        out.next_to(o_grp, DOWN, buff=0.4)

        # ── Animate boxes ──
        self.play(Write(inp))
        all_boxes = [q_grp, k_grp, v_grp, repeat_grp, sdpa_grp, o_grp]
        for g in all_boxes:
            self.play(FadeIn(g, shift=UP * 0.1), run_time=0.2)

        # ── Input trunk → branch → Q/K/V (enter from directly above) ──
        q_top, k_top, v_top = q_grp.get_top(), k_grp.get_top(), v_grp.get_top()
        # Height of the horizontal bus that feeds the three projections.
        bus_y = q_top[1] + 0.35
        trunk = wire(inp.get_bottom(), at(0, bus_y))
        self.play(Create(trunk), run_time=0.15)

        branch_left = wire(at(q_top[0], bus_y), at(k_top[0], bus_y))
        branch_right = wire(at(k_top[0], bus_y), at(v_top[0], bus_y))
        self.play(Create(branch_left), Create(branch_right), run_time=0.2)

        # One vertical drop from the bus onto the top edge of each projection.
        drops = [wire(at(top[0], bus_y), top) for top in (q_top, k_top, v_top)]
        for ln in drops:
            self.play(Create(ln), run_time=0.12)
        input_lines = VGroup(trunk, branch_left, branch_right, *drops)

        # ── K/V → Repeat KV (trunk-branch, enter from above) ──
        kv_junc_y = repeat_grp.get_top()[1] + 0.3
        k_bot, v_bot = k_grp.get_bottom(), v_grp.get_bottom()
        drop_k2 = wire(k_bot, at(k_bot[0], kv_junc_y))
        drop_v2 = wire(v_bot, at(v_bot[0], kv_junc_y))
        kv_branch = wire(at(v_bot[0], kv_junc_y), at(k_bot[0], kv_junc_y))
        kv_trunk = wire(at(k_bot[0], kv_junc_y), repeat_grp.get_top())
        kv_lines = VGroup(drop_k2, drop_v2, kv_branch, kv_trunk)
        self.play(Create(kv_lines), run_time=0.3)

        # ── Q → SDPA (bypasses Repeat KV, from above) ──
        qs_junc_y = sdpa_grp.get_top()[1] + 0.3
        q_bot, sdpa_top = q_grp.get_bottom(), sdpa_grp.get_top()
        line_qs = VMobject(color=GRAY, stroke_width=1.5)
        line_qs.set_points_as_corners([
            q_bot,
            at(q_bot[0], qs_junc_y),
            at(sdpa_top[0], qs_junc_y),
            sdpa_top,
        ])
        self.play(Create(line_qs), run_time=0.15)

        # ── Repeat KV → SDPA → O Projection → Output ──
        # orth_line is the module-level helper defined below this class; it is
        # only looked up when construct() runs, so the later definition is fine.
        line_rs = orth_line(repeat_grp.get_bottom(), sdpa_grp.get_top(), GRAY)
        self.play(Create(line_rs), run_time=0.15)
        line_so = orth_line(sdpa_grp.get_bottom(), o_grp.get_top(), GRAY)
        self.play(Create(line_so), run_time=0.15)
        line_oo = orth_line(o_grp.get_bottom(), out.get_top(), GRAY)
        self.play(Create(line_oo), run_time=0.15)
        self.play(Write(out))
        self.wait(0.4)

        all_lines = VGroup(
            input_lines, kv_lines, line_qs,
            line_rs, line_so, line_oo,
        )

        # ── RoPE highlight ──
        rope_q = SurroundingRectangle(q_grp, color=TEAL, buff=0.12)
        rope_k = SurroundingRectangle(k_grp, color=TEAL, buff=0.12)
        rope_t = Text(
            "RoPE: rotary position encoding\napplied to Q and K",
            font_size=13, color=TEAL,
        )
        rope_t.next_to(VGroup(rope_q, rope_k), UP, buff=0.25)
        self.play(Create(rope_q), Create(rope_k), Write(rope_t))
        self.wait(1.5)
        self.play(FadeOut(rope_q), FadeOut(rope_k), FadeOut(rope_t))

        # ── GQA ratio highlight ──
        gqa_h = SurroundingRectangle(
            VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
        )
        gqa_t = Text(
            "GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
            font_size=13, color=YELLOW,
        )
        # NOTE(review): the callout starts right of a highlight that already
        # spans the three projection boxes; confirm in a render that it is not
        # clipped by the frame edge.
        gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
        self.play(Create(gqa_h), Write(gqa_t))
        self.wait(1.8)
        # Clear this callout before the next one: the Repeat KV text is placed
        # at almost the same spot and would otherwise be drawn on top of it.
        self.play(FadeOut(gqa_h), FadeOut(gqa_t))

        # ── Repeat KV highlight ──
        kv_h = SurroundingRectangle(
            VGroup(k_grp, v_grp), color=GREEN, buff=0.12
        )
        kv_t = Text(
            "repeat_kv(): broadcast\n4 heads → 24 heads",
            font_size=12, color=GREEN,
        )
        kv_t.next_to(kv_h, RIGHT, buff=0.5)
        self.play(Create(kv_h), Write(kv_t))
        self.wait(1.5)

        # ── Fade all ──
        self.play(
            *[FadeOut(g) for g in all_boxes],
            FadeOut(all_lines),
            FadeOut(kv_h), FadeOut(kv_t),
            FadeOut(inp), FadeOut(out), FadeOut(title),
        )

        # ── Specs card ──
        st = Text("Model Specifications", font_size=36, color=BLUE)
        st.to_edge(UP, buff=0.5)
        rows_data = [
            ("Parameters", "~1.0B"),
            ("Layers", "24 × DecoderBlock"),
            ("Hidden Dim", "1536"),
            ("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"),
            ("Head Dim", "64"),
            ("FFN Dim", "4608 (SwiGLU)"),
            ("Max Length", "2048"),
            ("Precision", "bfloat16"),
        ]
        table = VGroup()
        for label, value in rows_data:
            # No aligned_edge here: combining aligned_edge=LEFT with a RIGHT
            # arrangement anchors the value at the label's horizontal centre,
            # which makes the value overlap any label wider than 2 * buff.
            row = VGroup(
                Text(label + ":", font_size=15, color=GRAY),
                Text(value, font_size=15, color=WHITE),
            ).arrange(RIGHT, buff=0.4)
            table.add(row)
        table.arrange(DOWN, buff=0.1, aligned_edge=LEFT)
        table.next_to(st, DOWN, buff=0.4)
        self.play(Write(st), Write(table))
        self.wait(2)
        self.play(FadeOut(st), FadeOut(table))
def orth_line(start, end, color=GRAY):
    """Build a two-segment, right-angled connector between two points.

    The path runs from ``start`` to a corner that takes its x from ``start``
    and its y from ``end`` (z fixed at 0), then on to ``end``. It is returned
    as a ``VMobject`` drawn in ``color`` with a stroke width of 1.5.
    """
    corner = np.array([start[0], end[1], 0])
    elbow = VMobject(color=color, stroke_width=1.5)
    elbow.set_points_as_corners([start, corner, end])
    return elbow