230 lines
8.2 KiB
Python
230 lines
8.2 KiB
Python
"""AstrAI promo: Transformer GQA attention animation.

Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow lines:

    Input → Q/K/V Projections → Repeat KV (K/V only) → SDPA → O Projection → Output
"""
|
||
|
||
from manim import *
|
||
import numpy as np
|
||
|
||
|
||
class Transformer(Scene):
    """Animates the GQA attention mechanism with orthogonal connection lines.

    The scene draws the Q/K/V projections, the Repeat-KV broadcast, scaled
    dot-product attention and the output projection, wires them together,
    highlights RoPE, the GQA head ratio and ``repeat_kv``, and finishes with a
    model-specification card.
    """

    def construct(self):
        title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
        title.to_edge(UP, buff=0.35)
        self.play(Write(title))

        inp, boxes, out = self._build_layout()
        q_grp, k_grp, v_grp, repeat_grp, sdpa_grp, o_grp = boxes

        # ── Animate boxes ──
        self.play(Write(inp))
        for grp in boxes:
            self.play(FadeIn(grp, shift=UP * 0.1), run_time=0.2)

        # ── Connections ──
        input_lines = self._draw_input_fanout(inp, q_grp, k_grp, v_grp)
        kv_lines = self._draw_kv_merge(k_grp, v_grp, repeat_grp)
        line_qs = self._draw_query_bypass(q_grp, sdpa_grp)
        main_lines = self._draw_main_path(repeat_grp, sdpa_grp, o_grp, out)
        self.play(Write(out))
        self.wait(0.4)

        all_lines = VGroup(input_lines, kv_lines, line_qs, main_lines)

        # ── Highlights ──
        self._highlight_rope(q_grp, k_grp)
        gqa_note = self._highlight_gqa(q_grp, k_grp, v_grp)
        # The GQA note is removed while the Repeat-KV note appears so the two
        # labels never share the screen.
        kv_note = self._highlight_repeat_kv(repeat_grp, gqa_note)

        # ── Fade all ──
        self.play(
            *[FadeOut(grp) for grp in boxes],
            FadeOut(all_lines),
            FadeOut(kv_note),
            FadeOut(inp), FadeOut(out), FadeOut(title),
        )

        self._show_specs()

    # ── Building blocks ──

    @staticmethod
    def _box(name, color, w=2.6, h=0.72, fs=10):
        """Return a translucent rectangle with a centred text label as a VGroup."""
        box = Rectangle(
            width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
        )
        lbl = Text(name, font_size=fs, color=color)
        return VGroup(box, lbl)

    @staticmethod
    def _wire(start, end):
        """Return a thin gray straight connector from ``start`` to ``end``."""
        return Line(start, end, color=GRAY, stroke_width=1.5)

    @staticmethod
    def _place_right_in_frame(label, target, buff=0.5, margin=0.15):
        """Place ``label`` right of ``target``, shrinking it to stay inside the frame.

        If the space between ``target`` (plus ``buff``) and the right frame edge
        (minus ``margin``) is narrower than the label, the label is scaled down
        to that width before being positioned.
        """
        room = config.frame_width / 2 - margin - (target.get_right()[0] + buff)
        if 0 < room < label.width:
            label.scale_to_fit_width(room)
        label.next_to(target, RIGHT, buff=buff)
        return label

    def _build_layout(self):
        """Create and position the input label, the six boxes and the output label.

        Returns:
            ``(inp, boxes, out)`` where ``boxes`` is
            ``[q, k, v, repeat_kv, sdpa, o_proj]`` in animation order.
        """
        inp = Text("x (hidden states)", font_size=15, color=GRAY)
        inp.move_to(UP * 2.8)

        # Row heights, top to bottom.
        y_proj, y_repeat, y_sdpa, y_out = 1.5, 0.0, -1.6, -3.0

        q_grp = self._box("Q Projection\n1536 → 24×64", YELLOW)
        k_grp = self._box("K Projection\n1536 → 4×64", YELLOW)
        v_grp = self._box("V Projection\n1536 → 4×64", YELLOW)
        q_grp.move_to(LEFT * 3.0 + UP * y_proj)
        k_grp.move_to(UP * y_proj)
        v_grp.move_to(RIGHT * 3.0 + UP * y_proj)

        repeat_grp = self._box("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
        repeat_grp.move_to(UP * y_repeat)

        sdpa_grp = self._box(
            "Scaled Dot-Product\nAttention Q·Kᵀ/√d", BLUE, 2.8, 0.74, 10
        )
        sdpa_grp.move_to(UP * y_sdpa)

        o_grp = self._box("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
        o_grp.move_to(UP * y_out)

        out = Text("x' (hidden states)", font_size=15, color=GRAY)
        out.next_to(o_grp, DOWN, buff=0.4)

        return inp, [q_grp, k_grp, v_grp, repeat_grp, sdpa_grp, o_grp], out

    def _draw_input_fanout(self, inp, q_grp, k_grp, v_grp):
        """Draw the input trunk, the horizontal bus and the drops into Q/K/V.

        Every projection is entered from directly above. Returns the wires as
        one VGroup.
        """
        bus_y = q_grp.get_top()[1] + 0.35
        q_x, k_x, v_x = (grp.get_top()[0] for grp in (q_grp, k_grp, v_grp))

        trunk = self._wire(inp.get_bottom(), np.array([0, bus_y, 0]))
        self.play(Create(trunk), run_time=0.15)

        branch_left = self._wire(np.array([q_x, bus_y, 0]), np.array([k_x, bus_y, 0]))
        branch_right = self._wire(np.array([k_x, bus_y, 0]), np.array([v_x, bus_y, 0]))
        self.play(Create(branch_left), Create(branch_right), run_time=0.2)

        drops = [
            self._wire(np.array([grp.get_top()[0], bus_y, 0]), grp.get_top())
            for grp in (q_grp, k_grp, v_grp)
        ]
        for drop in drops:
            self.play(Create(drop), run_time=0.12)

        return VGroup(trunk, branch_left, branch_right, *drops)

    def _draw_kv_merge(self, k_grp, v_grp, repeat_grp):
        """Join the K and V outputs on a junction above Repeat KV and feed it from above."""
        junc_y = repeat_grp.get_top()[1] + 0.3
        k_x = k_grp.get_bottom()[0]
        v_x = v_grp.get_bottom()[0]

        kv_lines = VGroup(
            self._wire(k_grp.get_bottom(), np.array([k_x, junc_y, 0])),
            self._wire(v_grp.get_bottom(), np.array([v_x, junc_y, 0])),
            self._wire(np.array([v_x, junc_y, 0]), np.array([k_x, junc_y, 0])),
            self._wire(np.array([k_x, junc_y, 0]), repeat_grp.get_top()),
        )
        self.play(Create(kv_lines), run_time=0.3)
        return kv_lines

    def _draw_query_bypass(self, q_grp, sdpa_grp):
        """Route Q around Repeat KV into the top of SDPA (down, across, down)."""
        junc_y = sdpa_grp.get_top()[1] + 0.3
        start = q_grp.get_bottom()
        end = sdpa_grp.get_top()

        line_qs = VMobject(color=GRAY, stroke_width=1.5)
        line_qs.set_points_as_corners([
            start,
            np.array([start[0], junc_y, 0]),
            np.array([end[0], junc_y, 0]),
            end,
        ])
        self.play(Create(line_qs), run_time=0.15)
        return line_qs

    def _draw_main_path(self, repeat_grp, sdpa_grp, o_grp, out):
        """Connect Repeat KV → SDPA → O Projection → output label, one wire at a time."""
        legs = [
            (repeat_grp.get_bottom(), sdpa_grp.get_top()),
            (sdpa_grp.get_bottom(), o_grp.get_top()),
            (o_grp.get_bottom(), out.get_top()),
        ]
        wires = VGroup()
        for start, end in legs:
            wire = orth_line(start, end, GRAY)
            self.play(Create(wire), run_time=0.15)
            wires.add(wire)
        return wires

    # ── Highlights ──

    def _highlight_rope(self, q_grp, k_grp):
        """Briefly outline Q and K and label them as receiving RoPE."""
        rope_q = SurroundingRectangle(q_grp, color=TEAL, buff=0.12)
        rope_k = SurroundingRectangle(k_grp, color=TEAL, buff=0.12)
        rope_t = Text(
            "RoPE: rotary position encoding\napplied to Q and K",
            font_size=13, color=TEAL,
        )
        rope_t.next_to(VGroup(rope_q, rope_k), UP, buff=0.25)
        self.play(Create(rope_q), Create(rope_k), Write(rope_t))
        self.wait(1.5)
        self.play(FadeOut(rope_q), FadeOut(rope_k), FadeOut(rope_t))

    def _highlight_gqa(self, q_grp, k_grp, v_grp):
        """Outline the Q/K/V row with the GQA ratio note; return it still on screen."""
        gqa_h = SurroundingRectangle(
            VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
        )
        gqa_t = Text(
            "GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
            font_size=13, color=YELLOW,
        )
        # The row spans most of the frame width, so the note may need shrinking.
        self._place_right_in_frame(gqa_t, gqa_h, buff=0.5)
        self.play(Create(gqa_h), Write(gqa_t))
        self.wait(1.8)
        return VGroup(gqa_h, gqa_t)

    def _highlight_repeat_kv(self, repeat_grp, previous):
        """Swap ``previous`` out for an outline and note on the Repeat KV box.

        Returns the new highlight (still on screen).
        """
        kv_h = SurroundingRectangle(repeat_grp, color=GREEN, buff=0.12)
        kv_t = Text(
            "repeat_kv(): broadcast\n4 heads → 24 heads",
            font_size=12, color=GREEN,
        )
        self._place_right_in_frame(kv_t, kv_h, buff=0.5)
        self.play(FadeOut(previous), Create(kv_h), Write(kv_t))
        self.wait(1.5)
        return VGroup(kv_h, kv_t)

    # ── Specs card ──

    def _show_specs(self):
        """Show the model-specification table, hold it, then fade it out."""
        st = Text("Model Specifications", font_size=36, color=BLUE)
        st.to_edge(UP, buff=0.5)

        rows_data = [
            ("Parameters", "~1.0B"),
            ("Layers", "24 × DecoderBlock"),
            ("Hidden Dim", "1536"),
            ("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"),
            ("Head Dim", "64"),
            ("FFN Dim", "4608 (SwiGLU)"),
            ("Max Length", "2048"),
            ("Precision", "bfloat16"),
        ]
        table = VGroup(*[
            VGroup(
                Text(label + ":", font_size=15, color=GRAY),
                Text(value, font_size=15, color=WHITE),
            ).arrange(RIGHT, buff=0.4, aligned_edge=LEFT)
            for label, value in rows_data
        ])
        table.arrange(DOWN, buff=0.1, aligned_edge=LEFT)
        table.next_to(st, DOWN, buff=0.4)

        self.play(Write(st), Write(table))
        self.wait(2)
        self.play(FadeOut(st), FadeOut(table))
|
||
|
||
|
||
def orth_line(start, end, color=GRAY, stroke_width=1.5):
    """Create an L-shaped orthogonal line from start to end.

    The path runs vertically from ``start`` down/up to the height of ``end``,
    then horizontally to ``end``. When the two points already share an x or a
    y coordinate, a single straight segment is emitted instead of a path with a
    zero-length leg. Manim's partial drawing (used by ``Create``) advances per
    curve segment, so a degenerate leg would otherwise consume part of the
    animation without drawing anything.

    Args:
        start: Start point as an (x, y, z) array-like.
        end: End point as an (x, y, z) array-like.
        color: Stroke colour.
        stroke_width: Stroke width (default 1.5, matching the original).

    Returns:
        A ``VMobject`` tracing the path.
    """
    start = np.asarray(start, dtype=float)
    end = np.asarray(end, dtype=float)

    if np.isclose(start[0], end[0]) or np.isclose(start[1], end[1]):
        corners = [start, end]
    else:
        corner = np.array([start[0], end[1], start[2]])
        corners = [start, corner, end]

    path = VMobject(color=color, stroke_width=stroke_width)
    path.set_points_as_corners(corners)
    return path
|