"""AstrAI promo: Transformer GQA attention animation.

Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow
lines:

    Input → Q/K/V Projections → Repeat KV → SDPA → O Projection → Output
"""

from manim import *
import numpy as np


def _wire(start, end, color=GRAY):
    """Return a straight connector segment in the diagram's line style."""
    return Line(start, end, color=color, stroke_width=1.5)


def _polyline(points, color=GRAY):
    """Return a sharp-cornered connector through ``points``.

    Consecutive duplicate vertices are dropped first: ``Create`` spreads its
    run time over the path's curves, so a zero-length segment would spend
    part of the animation drawing nothing.
    """
    verts = [np.asarray(p, dtype=float) for p in points]
    kept = verts[:1]
    for v in verts[1:]:
        if not np.allclose(v, kept[-1]):
            kept.append(v)
    if len(kept) == 1:
        # Every vertex coincides; keep a (degenerate) two-point path so the
        # result is still a valid stroked VMobject.
        kept.append(kept[0])
    path = VMobject(color=color, stroke_width=1.5)
    path.set_points_as_corners(kept)
    return path


def _labeled_box(name, color, w=2.6, h=0.72, fs=10):
    """Return a translucent ``w`` x ``h`` rectangle with ``name`` centred in it."""
    box = Rectangle(
        width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
    )
    label = Text(name, font_size=fs, color=color)
    return VGroup(box, label)


def _label_right_of(label, target, buff=0.5, margin=0.2):
    """Place ``label`` to the right of ``target``, shrinking it to fit the frame.

    If the label would extend past the right frame edge (minus ``margin``),
    it is scaled down about its left edge, so it stays beside ``target``.
    Returns ``label`` for convenience.
    """
    label.next_to(target, RIGHT, buff=buff)
    available = config.frame_width / 2 - margin - label.get_left()[0]
    if 0 < available < label.width:
        label.scale(available / label.width, about_edge=LEFT)
    return label


class Transformer(Scene):
    """Animates the GQA attention mechanism with orthogonal connection lines."""

    def construct(self):
        """Play the GQA diagram and its annotations, then the model-spec card."""
        title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
        title.to_edge(UP, buff=0.35)
        self.play(Write(title))

        # ── Layout: columns at x = -3 / 0 / 3, rows from top to bottom ──
        inp = Text("x (hidden states)", font_size=15, color=GRAY)
        inp.move_to(UP * 2.8)

        y1 = 1.5
        q_grp = _labeled_box("Q Projection\n1536 → 24×64", YELLOW)
        k_grp = _labeled_box("K Projection\n1536 → 4×64", YELLOW)
        v_grp = _labeled_box("V Projection\n1536 → 4×64", YELLOW)
        q_grp.move_to(LEFT * 3.0 + UP * y1)
        k_grp.move_to(UP * y1)
        v_grp.move_to(RIGHT * 3.0 + UP * y1)

        y2 = 0.0
        repeat_grp = _labeled_box(
            "Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10
        )
        repeat_grp.move_to(UP * y2)

        y3 = -1.6
        sdpa_grp = _labeled_box(
            "Scaled Dot-Product\nAttention Q·Kᵀ/√d", BLUE, 2.8, 0.74, 10
        )
        sdpa_grp.move_to(UP * y3)

        y4 = -3.0
        o_grp = _labeled_box("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
        o_grp.move_to(UP * y4)

        out = Text("x' (hidden states)", font_size=15, color=GRAY)
        out.next_to(o_grp, DOWN, buff=0.4)

        # ── Animate boxes ──
        self.play(Write(inp))
        all_boxes = [q_grp, k_grp, v_grp, repeat_grp, sdpa_grp, o_grp]
        for g in all_boxes:
            self.play(FadeIn(g, shift=UP * 0.1), run_time=0.2)

        # ── Input trunk → branch → Q/K/V (enter from directly above) ──
        branch_y = q_grp.get_top()[1] + 0.35

        def on_branch(x):
            """Point on the horizontal input branch at abscissa ``x``."""
            return np.array([x, branch_y, 0])

        q_x, k_x, v_x = (g.get_top()[0] for g in (q_grp, k_grp, v_grp))

        trunk = _wire(inp.get_bottom(), on_branch(0))
        self.play(Create(trunk), run_time=0.15)

        branch_left = _wire(on_branch(q_x), on_branch(k_x))
        branch_right = _wire(on_branch(k_x), on_branch(v_x))
        self.play(Create(branch_left), Create(branch_right), run_time=0.2)

        drops = [
            _wire(on_branch(g.get_top()[0]), g.get_top())
            for g in (q_grp, k_grp, v_grp)
        ]
        for ln in drops:
            self.play(Create(ln), run_time=0.12)
        input_lines = VGroup(trunk, branch_left, branch_right, *drops)

        # ── K/V → Repeat KV (merge into one trunk, enter from above) ──
        kv_junc_y = repeat_grp.get_top()[1] + 0.3
        k_bot, v_bot = k_grp.get_bottom(), v_grp.get_bottom()
        k_junc = np.array([k_bot[0], kv_junc_y, 0])
        v_junc = np.array([v_bot[0], kv_junc_y, 0])
        kv_lines = VGroup(
            _wire(k_bot, k_junc),
            _wire(v_bot, v_junc),
            _wire(v_junc, k_junc),
            _wire(k_junc, repeat_grp.get_top()),
        )
        self.play(Create(kv_lines), run_time=0.3)

        # ── Q → SDPA (bypasses Repeat KV, enters from above) ──
        qs_junc_y = sdpa_grp.get_top()[1] + 0.3
        q_bot, sdpa_top = q_grp.get_bottom(), sdpa_grp.get_top()
        line_qs = _polyline([
            q_bot,
            [q_bot[0], qs_junc_y, 0],
            [sdpa_top[0], qs_junc_y, 0],
            sdpa_top,
        ])
        self.play(Create(line_qs), run_time=0.15)

        # ── Repeat KV → SDPA → O Projection → output ──
        spine = [
            orth_line(repeat_grp.get_bottom(), sdpa_grp.get_top(), GRAY),
            orth_line(sdpa_grp.get_bottom(), o_grp.get_top(), GRAY),
            orth_line(o_grp.get_bottom(), out.get_top(), GRAY),
        ]
        for ln in spine:
            self.play(Create(ln), run_time=0.15)
        self.play(Write(out))
        self.wait(0.4)

        all_lines = VGroup(input_lines, kv_lines, line_qs, *spine)

        # ── RoPE highlight ──
        rope_q = SurroundingRectangle(q_grp, color=TEAL, buff=0.12)
        rope_k = SurroundingRectangle(k_grp, color=TEAL, buff=0.12)
        rope_t = Text(
            "RoPE: rotary position encoding\napplied to Q and K",
            font_size=13,
            color=TEAL,
        )
        rope_t.next_to(VGroup(rope_q, rope_k), UP, buff=0.25)
        self.play(Create(rope_q), Create(rope_k), Write(rope_t))
        self.wait(1.5)
        self.play(FadeOut(rope_q), FadeOut(rope_k), FadeOut(rope_t))

        # ── GQA ratio highlight ──
        gqa_h = SurroundingRectangle(
            VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
        )
        gqa_t = Text(
            "GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
            font_size=13,
            color=YELLOW,
        )
        _label_right_of(gqa_t, gqa_h)
        self.play(Create(gqa_h), Write(gqa_t))
        self.wait(1.8)
        # The Repeat-KV annotation is anchored to the same row, so clear this
        # one first; otherwise the two labels are drawn on top of each other.
        self.play(FadeOut(gqa_h), FadeOut(gqa_t))

        # ── Repeat KV highlight ──
        kv_h = SurroundingRectangle(VGroup(k_grp, v_grp), color=GREEN, buff=0.12)
        kv_t = Text(
            "repeat_kv(): broadcast\n4 heads → 24 heads",
            font_size=12,
            color=GREEN,
        )
        _label_right_of(kv_t, kv_h)
        self.play(Create(kv_h), Write(kv_t))
        self.wait(1.5)

        # ── Fade all ──
        self.play(
            *[FadeOut(g) for g in all_boxes],
            FadeOut(all_lines),
            FadeOut(kv_h),
            FadeOut(kv_t),
            FadeOut(inp),
            FadeOut(out),
            FadeOut(title),
        )

        self._show_specs_card()

    def _show_specs_card(self):
        """Show the model-specification table, hold it, then fade it out."""
        st = Text("Model Specifications", font_size=36, color=BLUE)
        st.to_edge(UP, buff=0.5)

        rows_data = [
            ("Parameters", "~1.0B"),
            ("Layers", "24 × DecoderBlock"),
            ("Hidden Dim", "1536"),
            ("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"),
            ("Head Dim", "64"),
            ("FFN Dim", "4608 (SwiGLU)"),
            ("Max Length", "2048"),
            ("Precision", "bfloat16"),
        ]
        table = VGroup(*(
            VGroup(
                Text(label + ":", font_size=15, color=GRAY),
                Text(value, font_size=15, color=WHITE),
            ).arrange(RIGHT, buff=0.4, aligned_edge=LEFT)
            for label, value in rows_data
        ))
        table.arrange(DOWN, buff=0.1, aligned_edge=LEFT)
        table.next_to(st, DOWN, buff=0.4)

        self.play(Write(st), Write(table))
        self.wait(2)
        self.play(FadeOut(st), FadeOut(table))


def orth_line(start, end, color=GRAY):
    """Create an L-shaped orthogonal line from ``start`` to ``end``.

    The path runs vertically from ``start`` to the height of ``end`` and then
    horizontally into ``end``. When the x coordinates differ, it therefore
    enters ``end`` from the side rather than from above. If the two points
    are already vertically or horizontally aligned, the redundant corner is
    dropped and the result is a single straight segment.

    Args:
        start: Start point, array-like ``[x, y, z]``.
        end: End point, array-like ``[x, y, z]``.
        color: Stroke colour of the line.

    Returns:
        A ``VMobject`` polyline with stroke width 1.5.
    """
    corner = np.array([start[0], end[1], 0])
    return _polyline([start, corner, end], color)