diff --git a/transformer.py b/transformer.py index 8e8289c..c1df84c 100644 --- a/transformer.py +++ b/transformer.py @@ -26,9 +26,9 @@ class Transformer(Scene): # ── Layout ── inp = Text("x (hidden states)", font_size=15, color=GRAY) - inp.move_to(UP * 2.8) + inp.move_to(UP * 3.2) - y1 = 1.5 + y1 = 1.9 q_grp = mk("Q Projection\n1536 → 24×64", YELLOW) k_grp = mk("K Projection\n1536 → 4×64", YELLOW) v_grp = mk("V Projection\n1536 → 4×64", YELLOW) @@ -36,17 +36,17 @@ class Transformer(Scene): k_grp.move_to(UP * y1) v_grp.move_to(RIGHT * 3.0 + UP * y1) - y2 = 0.0 + y2 = 0.4 repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10) repeat_grp.move_to(UP * y2) - y3 = -1.6 + y3 = -1.2 sdpa_grp = mk( - "Scaled Dot-Product\nAttention Q·Kᵀ/√d", BLUE, 2.8, 0.74, 10 + "Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10 ) sdpa_grp.move_to(UP * y3) - y4 = -3.0 + y4 = -2.6 o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10) o_grp.move_to(UP * y4) @@ -172,6 +172,7 @@ class Transformer(Scene): gqa_t.next_to(gqa_h, RIGHT, buff=0.5) self.play(Create(gqa_h), Write(gqa_t)) self.wait(1.8) + self.play(FadeOut(gqa_h), FadeOut(gqa_t)) # ── Repeat KV highlight ── kv_h = SurroundingRectangle( @@ -190,96 +191,68 @@ class Transformer(Scene): *[FadeOut(g) for g in all_boxes], FadeOut(all_lines), FadeOut(kv_h), FadeOut(kv_t), - FadeOut(gqa_h), FadeOut(gqa_t), FadeOut(inp), FadeOut(out), FadeOut(title), ) - # ── Specs card ── - st = Text("Model Specifications", font_size=36, color=BLUE) - st.to_edge(UP, buff=0.5) - rows_data = [ - ("Parameters", "~1.0B"), - ("Layers", "24 × DecoderBlock"), - ("Hidden Dim", "1536"), - ("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"), - ("Head Dim", "64"), - ("FFN Dim", "4608 (SwiGLU)"), - ("Max Length", "2048"), - ("Precision", "bfloat16"), - ] - table = VGroup() - for label, value in rows_data: - row = VGroup( - Text(label + ":", font_size=15, color=GRAY), - Text(value, font_size=15, color=WHITE), - ).arrange(RIGHT, buff=0.4, aligned_edge=LEFT) - table.add(row) - table.arrange(DOWN, buff=0.1, aligned_edge=LEFT) - table.next_to(st, DOWN, buff=0.4) - self.play(Write(st), Write(table)) - self.wait(2) - self.play(FadeOut(st), FadeOut(table)) - # ═══════════════════════════════════════════════════ - # 12. Q / K / V — what do they mean? + # 12. Scaled Dot-Product Attention — full formula + breakdown # ═══════════════════════════════════════════════════ qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE) qkv_title.to_edge(UP, buff=0.35) self.play(Write(qkv_title)) - self.wait(0.2) - q_txt = Text("Q = Query", font_size=24, color=YELLOW) - k_txt = Text("K = Key", font_size=24, color=ORANGE) - v_txt = Text("V = Value", font_size=24, color=GREEN) - qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5) - qkv_labels.next_to(qkv_title, DOWN, buff=0.6) - self.play(Write(qkv_labels)) - - q_desc = Text("\"what am I looking for?\"", font_size=12, color=YELLOW) \ - .next_to(q_txt, DOWN, buff=0.12) - k_desc = Text("\"what do I have?\"", font_size=12, color=ORANGE) \ - .next_to(k_txt, DOWN, buff=0.12) - v_desc = Text("\"what do I contribute?\"", font_size=12, color=GREEN) \ - .next_to(v_txt, DOWN, buff=0.12) - self.play(Write(q_desc), Write(k_desc), Write(v_desc)) - self.wait(2.0) - self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc)) - - # ── Full formula ── + # Full formula first, stays on screen full_eq = MathTex( r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V", font_size=36, color=WHITE, ) - full_eq.next_to(qkv_title, DOWN, buff=0.6) + full_eq.next_to(qkv_title, DOWN, buff=0.5) self.play(Write(full_eq)) - self.wait(1.2) + self.wait(0.6) - # ── Step-by-step decomposition ── + # Q / K / V — brief meanings + q_txt = Text("Q = Query", font_size=18, color=YELLOW) + k_txt = Text("K = Key", font_size=18, color=ORANGE) + v_txt = Text("V = Value", font_size=18, color=GREEN) + qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5) + qkv_labels.next_to(full_eq, DOWN, buff=0.5) + self.play(Write(qkv_labels)) + + q_desc = Text("\"what am I looking for?\"", font_size=11, color=YELLOW) \ + .next_to(q_txt, DOWN, buff=0.10) + k_desc = Text("\"what do I have?\"", font_size=11, color=ORANGE) \ + .next_to(k_txt, DOWN, buff=0.10) + v_desc = Text("\"what do I contribute?\"", font_size=11, color=GREEN) \ + .next_to(v_txt, DOWN, buff=0.10) + self.play(Write(q_desc), Write(k_desc), Write(v_desc)) + self.wait(2.0) + self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc)) + + # Step-by-step decomposition — full formula stays visible above steps = [ - (MathTex(r"\text{(1) } S = QK^\top", font_size=28, color=YELLOW), - Text("score matrix — pairwise token similarity", font_size=13, color=GRAY)), - (MathTex(r"\text{(2) } S / \sqrt{d_k}", font_size=28, color=ORANGE), - Text("scale — prevents gradient explosion", font_size=13, color=GRAY)), - (MathTex(r"\text{(3) } \operatorname{softmax}(\cdots)", font_size=28, color=GREEN), - Text("normalize — each row sums to 1 (probability)", font_size=13, color=GRAY)), - (MathTex(r"\text{(4) } \cdots \cdot V", font_size=28, color=BLUE), - Text("weighted sum — aggregate values by attention", font_size=13, color=GRAY)), + (MathTex(r"\text{(1) } S = QK^\top", font_size=26, color=YELLOW), + Text("score matrix — pairwise token similarity", font_size=12, color=GRAY)), + (MathTex(r"\text{(2) } S \mathbin{/} \sqrt{d_k}", font_size=26, color=ORANGE), + Text("scale — prevents gradient explosion", font_size=12, color=GRAY)), + (MathTex(r"\text{(3) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right)", font_size=26, color=GREEN), + Text("normalize — each row sums to 1 (probability)", font_size=12, color=GRAY)), + (MathTex(r"\text{(4) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right) \cdot V", font_size=26, color=BLUE), + Text("weighted sum — aggregate values by attention", font_size=12, color=GRAY)), ] step_group = VGroup() - step_descs = VGroup() + steps_mobj = VGroup() for eq, desc in steps: - sg = VGroup(eq, desc).arrange(DOWN, buff=0.08) + sg = VGroup(eq, desc).arrange(DOWN, buff=0.06) step_group.add(sg) - step_descs.add(desc) - step_group.arrange(DOWN, buff=0.25, aligned_edge=LEFT) - step_group.next_to(full_eq, DOWN, buff=0.7) + steps_mobj.add(eq) + step_group.arrange(DOWN, buff=0.22, aligned_edge=LEFT) + step_group.next_to(full_eq, DOWN, buff=0.6) - self.play(FadeOut(full_eq)) for sg in step_group: self.play(Write(sg), run_time=0.3) - self.wait(2.0) - self.play(FadeOut(step_group)) + self.wait(2.5) + self.play(FadeOut(step_group), FadeOut(full_eq)) # ═══════════════════════════════════════════════════ # 13. Attention score heatmap — concrete example