diff --git a/transformer.py b/transformer.py
index 50248f5..8e8289c 100644
--- a/transformer.py
+++ b/transformer.py
@@ -220,6 +220,188 @@ class Transformer(Scene):
         self.wait(2)
         self.play(FadeOut(st), FadeOut(table))
 
+        # ═══════════════════════════════════════════════════
+        # 12. Q / K / V — what do they mean?
+        # ═══════════════════════════════════════════════════
+        qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
+        qkv_title.to_edge(UP, buff=0.35)
+        self.play(Write(qkv_title))
+        self.wait(0.2)
+
+        # One colour-coded label per projection, in a row under the title.
+        q_txt = Text("Q = Query", font_size=24, color=YELLOW)
+        k_txt = Text("K = Key", font_size=24, color=ORANGE)
+        v_txt = Text("V = Value", font_size=24, color=GREEN)
+        qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
+        qkv_labels.next_to(qkv_title, DOWN, buff=0.6)
+        self.play(Write(qkv_labels))
+
+        # Intuition captions sit directly beneath their label, same colour.
+        q_desc = Text("\"what am I looking for?\"", font_size=12, color=YELLOW) \
+            .next_to(q_txt, DOWN, buff=0.12)
+        k_desc = Text("\"what do I have?\"", font_size=12, color=ORANGE) \
+            .next_to(k_txt, DOWN, buff=0.12)
+        v_desc = Text("\"what do I contribute?\"", font_size=12, color=GREEN) \
+            .next_to(v_txt, DOWN, buff=0.12)
+        self.play(Write(q_desc), Write(k_desc), Write(v_desc))
+        self.wait(2.0)
+        self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))
+
+        # ── Full formula ──
+        full_eq = MathTex(
+            r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
+            font_size=36, color=WHITE,
+        )
+        full_eq.next_to(qkv_title, DOWN, buff=0.6)
+        self.play(Write(full_eq))
+        self.wait(1.2)
+
+        # ── Step-by-step decomposition ──
+        # Each entry pairs one stage of the formula with a one-line caption.
+        steps = [
+            (MathTex(r"\text{(1) } S = QK^\top", font_size=28, color=YELLOW),
+             Text("score matrix — pairwise token similarity", font_size=13, color=GRAY)),
+            (MathTex(r"\text{(2) } S / \sqrt{d_k}", font_size=28, color=ORANGE),
+             # Dividing by sqrt(d_k) keeps the logits small so softmax does not
+             # saturate; the hazard is vanishing gradients, not exploding ones.
+             Text("scale — avoids softmax saturation (vanishing gradients)", font_size=13, color=GRAY)),
+            (MathTex(r"\text{(3) } \operatorname{softmax}(\cdots)", font_size=28, color=GREEN),
+             Text("normalize — each row sums to 1 (probability)", font_size=13, color=GRAY)),
+            (MathTex(r"\text{(4) } \cdots \cdot V", font_size=28, color=BLUE),
+             Text("weighted sum — aggregate values by attention", font_size=13, color=GRAY)),
+        ]
+
+        # Stack each formula over its caption, then stack the pairs left-aligned.
+        step_group = VGroup(
+            *[VGroup(eq, desc).arrange(DOWN, buff=0.08) for eq, desc in steps]
+        )
+        step_group.arrange(DOWN, buff=0.25, aligned_edge=LEFT)
+        step_group.next_to(full_eq, DOWN, buff=0.7)
+
+        self.play(FadeOut(full_eq))
+        for sg in step_group:
+            self.play(Write(sg), run_time=0.3)
+        self.wait(2.0)
+        self.play(FadeOut(step_group))
+
+        # ═══════════════════════════════════════════════════
+        # 13. Attention score heatmap — concrete example
+        # ═══════════════════════════════════════════════════
+        hm_title = Text("Attention Score Heatmap", font_size=34, color=BLUE)
+        hm_title.to_edge(UP, buff=0.35)
+        hm_sub = Text("\"The cat sat on the mat\" — causal, per-token attention weights",
+                      font_size=14, color=GRAY).next_to(hm_title, DOWN, buff=0.12)
+        self.play(FadeOut(qkv_title), Write(hm_title), Write(hm_sub))
+
+        # NOTE(review): the first token used to be "", which leaves row/column 0
+        # with an empty label (and the Q/K headers are anchored to that label).
+        # Presumably a start-of-sequence marker was intended, since row 0 attends
+        # only to itself — confirm the exact spelling.
+        tokens = ["[BOS]", "The", "cat", "sat", "on", "the", "mat"]
+        n = len(tokens)
+        cell_size = 0.52
+        gap = 0.04
+        grid_high = n * cell_size + (n - 1) * gap  # side length of the square grid
+        grid_left = -grid_high / 2  # centres the grid horizontally
+        grid_top = 1.4
+
+        def hm_cell_center(row, col):
+            """Return the [x, y, 0] centre of heatmap cell (row, col)."""
+            pitch = cell_size + gap
+            return [
+                grid_left + col * pitch + cell_size / 2,
+                grid_top - row * pitch - cell_size / 2,
+                0,
+            ]
+
+        # attention weights (after softmax + causal mask): row = query, column = key
+        weights = [
+            [1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
+            [0.05, 0.95, 0.00, 0.00, 0.00, 0.00, 0.00],
+            [0.02, 0.20, 0.78, 0.00, 0.00, 0.00, 0.00],
+            [0.01, 0.05, 0.40, 0.54, 0.00, 0.00, 0.00],
+            [0.00, 0.02, 0.07, 0.35, 0.56, 0.00, 0.00],
+            [0.00, 0.01, 0.03, 0.10, 0.30, 0.56, 0.00],
+            [0.00, 0.00, 0.01, 0.05, 0.12, 0.35, 0.47],
+        ]
+
+        # Draw the grid cell by cell: masked (future) cells and zero weights are
+        # greyed out, everything else is shaded from BLUE (low) to RED (high).
+        cells = VGroup()
+        for i, row in enumerate(weights):
+            for j, w in enumerate(row):
+                if j > i:
+                    color, fill_op = DARK_GRAY, 0.15
+                elif w < 0.001:
+                    color, fill_op = DARKER_GRAY, 0.2
+                else:
+                    color, fill_op = interpolate_color(BLUE, RED, w), 0.75
+                sq = Square(
+                    side_length=cell_size, fill_color=color,
+                    fill_opacity=fill_op, stroke_width=0.5,
+                    stroke_color=GRAY,
+                )
+                sq.move_to(hm_cell_center(i, j))
+                cells.add(sq)
+                self.play(FadeIn(sq, scale=0.6), run_time=0.015)
+
+        # row labels (query) on the left
+        row_lbls = VGroup()
+        for i, tok in enumerate(tokens):
+            lbl = Text(tok, font_size=12, color=GRAY)
+            lbl.next_to([grid_left - 0.15, hm_cell_center(i, 0)[1], 0], LEFT, buff=0.08)
+            row_lbls.add(lbl)
+        q_label = Text("Q", font_size=11, color=WHITE, weight=BOLD)
+        q_label.move_to(row_lbls[0].get_left() + LEFT * 0.3).shift(UP * 0.15)
+        self.play(*[Write(lbl) for lbl in row_lbls], Write(q_label))
+
+        # column labels (key) on top
+        col_lbls = VGroup()
+        for j, tok in enumerate(tokens):
+            lbl = Text(tok, font_size=9, color=GRAY).rotate(PI / 6)
+            lbl.next_to([hm_cell_center(0, j)[0], grid_top + 0.06, 0], UP, buff=0.04)
+            col_lbls.add(lbl)
+        k_label = Text("K", font_size=11, color=WHITE, weight=BOLD)
+        k_label.next_to(col_lbls[0], UP, buff=0.06)
+        self.play(*[Write(lbl) for lbl in col_lbls], Write(k_label))
+        self.wait(1.0)
+
+        # causal mask — per-cell red overlay aligned to grid
+        mask_overlays = VGroup()
+        for i in range(n):
+            for j in range(i + 1, n):  # strictly-upper triangle = future keys
+                sq = Square(
+                    side_length=cell_size, fill_color=RED,
+                    fill_opacity=0.10, stroke_width=0.5,
+                    stroke_color=RED, stroke_opacity=0.3,
+                )
+                sq.move_to(hm_cell_center(i, j))
+                mask_overlays.add(sq)
+        causal_txt = Text("causal mask\n(future tokens hidden)", font_size=11, color=RED) \
+            .next_to(cells[n - 1], UP, buff=0.25).align_to(cells[n - 1], RIGHT)
+        self.play(FadeIn(mask_overlays), Write(causal_txt))
+        self.wait(1.5)
+        self.play(FadeOut(mask_overlays), FadeOut(causal_txt))
+
+        # highlight key patterns: the first sub-diagonal, i.e. cell (r, r - 1),
+        # where query r attends to the key immediately before it
+        highlights = [
+            SurroundingRectangle(cells[r * n + r - 1], color=ORANGE, stroke_width=2, buff=0.04)
+            for r in range(2, n)
+        ]
+        hl_text = Text("each token attends to its predecessor\n(causal sequence learning)", font_size=11, color=ORANGE) \
+            .next_to(cells[(n - 1) * n + (n - 1)], RIGHT, buff=0.8)
+        self.play(*[Create(h) for h in highlights], Write(hl_text))
+        self.wait(2.0)
+        self.play(*[FadeOut(h) for h in highlights], FadeOut(hl_text))
+
+        # fade all heatmap
+        self.play(
+            FadeOut(hm_title), FadeOut(hm_sub),
+            FadeOut(cells), FadeOut(row_lbls), FadeOut(col_lbls),
+            FadeOut(q_label), FadeOut(k_label),
+        )
+
         def orth_line(start, end, color=GRAY):
             """Create an L-shaped orthogonal line from start to end."""
 