diff --git a/transformer.py b/transformer.py
index c1df84c..1081008 100644
--- a/transformer.py
+++ b/transformer.py
@@ -6,6 +6,7 @@ Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow line
 
 from manim import *
 import numpy as np
+import math
 
 
 class Transformer(Scene):
@@ -271,39 +272,48 @@ class Transformer(Scene):
         grid_left = -grid_high / 2
         grid_top = 1.4
 
-        # attention weights (after softmax + causal mask)
-        weights = [
-            [1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
-            [0.05, 0.95, 0.00, 0.00, 0.00, 0.00, 0.00],
-            [0.02, 0.20, 0.78, 0.00, 0.00, 0.00, 0.00],
-            [0.01, 0.05, 0.40, 0.54, 0.00, 0.00, 0.00],
-            [0.00, 0.02, 0.07, 0.35, 0.56, 0.00, 0.00],
-            [0.00, 0.01, 0.03, 0.10, 0.30, 0.56, 0.00],
-            [0.00, 0.00, 0.01, 0.05, 0.12, 0.35, 0.47],
+        # pre-mask raw scores (QK^T / sqrt(d_k)) — random-varied, distance-biased
+        pre_scores = [
+            [2.8, 1.5, 0.3, 0.1, 0.0, 0.0, 0.0],
+            [1.2, 3.5, 1.8, 0.5, 0.2, 0.1, 0.0],
+            [0.4, 2.0, 3.0, 1.5, 0.6, 0.2, 0.1],
+            [0.1, 0.6, 2.5, 2.8, 1.2, 0.4, 0.1],
+            [0.0, 0.2, 0.8, 2.0, 2.5, 1.5, 0.3],
+            [0.0, 0.1, 0.3, 0.9, 1.8, 2.5, 1.2],
+            [0.0, 0.0, 0.1, 0.4, 0.8, 1.5, 3.0],
         ]
+        # compute post-softmax weights with causal mask (j > i → -inf)
+        post_weights = []
+        for i in range(n):
+            row = pre_scores[i]
+            masked = [-float('inf') if j > i else row[j] for j in range(n)]
+            exps = [math.exp(v) for v in masked]
+            exp_sum = sum(exps)
+            post_weights.append([e / exp_sum for e in exps])
+
+        flat_pre = [w for row in pre_scores for w in row]
+        pre_min, pre_max = min(flat_pre), max(flat_pre)
+        flat_post = [w for row in post_weights for w in row]
+        post_min, post_max = min(flat_post), max(flat_post)
 
         cells = VGroup()
+        masked_cells = VGroup()
         for i in range(n):
             for j in range(n):
-                w = weights[i][j]
-                if j > i:
-                    color = DARK_GRAY
-                    fill_op = 0.15
-                elif w < 0.001:
-                    color = DARKER_GRAY
-                    fill_op = 0.2
-                else:
-                    color = interpolate_color(BLUE, RED, w)
-                    fill_op = 0.75
+                pw = pre_scores[i][j]
+                pw_normed = (pw - pre_min) / (pre_max - pre_min)
+                pw_color = interpolate_color(BLUE, RED, pw_normed)
                 sq = Square(
-                    side_length=cell_size, fill_color=color,
-                    fill_opacity=fill_op, stroke_width=0.5,
+                    side_length=cell_size, fill_color=pw_color,
+                    fill_opacity=0.75, stroke_width=0.5,
                     stroke_color=GRAY,
                 )
                 x = grid_left + j * (cell_size + gap) + cell_size / 2
                 y = grid_top - i * (cell_size + gap) - cell_size / 2
                 sq.move_to([x, y, 0])
                 cells.add(sq)
+                if j > i:
+                    masked_cells.add(sq)
                 self.play(FadeIn(sq, scale=0.6), run_time=0.015)
 
         # row labels (query) on the left
@@ -329,25 +339,21 @@ class Transformer(Scene):
         self.play(*[Write(l) for l in col_lbls], Write(k_label))
         self.wait(1.0)
 
-        # causal mask — per-cell red overlay aligned to grid
-        mask_overlays = VGroup()
+        # causal mask + softmax — zero out future tokens, recompute weights
+        causal_txt = Text("causal mask + softmax\n(future tokens → 0)", font_size=11, color=RED) \
+            .next_to(cells[n - 1], UP, buff=0.25).align_to(cells[n - 1], RIGHT)
+        anims = [sq.animate.set_fill(DARK_GRAY, 0.15) for sq in masked_cells]
         for i in range(n):
             for j in range(n):
-                if j > i:
-                    x = grid_left + j * (cell_size + gap) + cell_size / 2
-                    y = grid_top - i * (cell_size + gap) - cell_size / 2
-                    sq = Square(
-                        side_length=cell_size, fill_color=RED,
-                        fill_opacity=0.10, stroke_width=0.5,
-                        stroke_color=RED, stroke_opacity=0.3,
-                    )
-                    sq.move_to([x, y, 0])
-                    mask_overlays.add(sq)
-        causal_txt = Text("causal mask\n(future tokens hidden)", font_size=11, color=RED) \
-            .next_to(cells[6], UP, buff=0.25).align_to(cells[6], RIGHT)
-        self.play(FadeIn(mask_overlays), Write(causal_txt))
-        self.wait(1.5)
-        self.play(FadeOut(mask_overlays), FadeOut(causal_txt))
+                if j <= i:
+                    idx = i * n + j
+                    aw = post_weights[i][j]
+                    aw_normed = (aw - post_min) / (post_max - post_min)
+                    aw_color = interpolate_color(BLUE, RED, aw_normed)
+                    anims.append(cells[idx].animate.set_fill(aw_color, 0.75))
+        self.play(*anims, Write(causal_txt))
+        self.wait(1.2)
+        self.play(FadeOut(causal_txt))
 
         # highlight key patterns
         h1 = SurroundingRectangle(cells[2 * n + 1], color=ORANGE, stroke_width=2, buff=0.04)
@@ -368,6 +374,199 @@ class Transformer(Scene):
             FadeOut(q_label), FadeOut(k_label),
         )
 
+        # ═══════════════════════════════════════════════════
+        # Auto-regressive Generation Demo (v2: full I/O pipeline)
+        def tok_card(text, fill=DARK_BLUE, stroke=GRAY):
+            """Return VGroup(box, label): bold text centered on a rounded box sized to fit it."""
+            t = Text(text, font_size=10, color=WHITE, weight=BOLD)
+            box = RoundedRectangle(
+                width=t.width + 0.2, height=t.height + 0.1,
+                corner_radius=0.04, fill_color=fill, fill_opacity=0.5,
+                stroke_color=stroke, stroke_width=0.5,
+            )
+            t.move_to(box)
+            return VGroup(box, t)
+
+        gen_tokens = ["", "The", "cat", "sat", "on", "the", "mat"]
+
+        gen_title = Text("Auto-regressive Generation", font_size=34, color=BLUE)
+        gen_title.to_edge(UP, buff=0.35)
+        self.play(Write(gen_title))
+
+        # ── Layout constants ──
+        BLK_W = 1.8
+        BLK_H = 0.28
+        TFR_H = 0.48
+        CX_BLK = -1.0
+        Y_EMB = 1.40
+        Y_NORM1 = 0.95
+        Y_TFR = 0.32
+        Y_NORM2 = -0.25
+        Y_HEAD = -0.70
+
+        def mkblk(w, h, y, txt, color, fs=10):
+            """Return VGroup(box, label): a w-by-h rounded block with bold text, centered at (CX_BLK, y)."""
+            b = RoundedRectangle(width=w, height=h, corner_radius=0.06,
+                                 fill_color=DARK_BLUE, fill_opacity=0.25,
+                                 stroke_color=color, stroke_width=1.5)
+            l = Text(txt, font_size=fs, color=color, weight=BOLD).move_to(b)
+            return VGroup(b, l).move_to([CX_BLK, y, 0])
+
+        emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9)
+        norm1 = mkblk(BLK_W, BLK_H, Y_NORM1, "RMS Norm", GREEN, 9)
+        tfr_node = mkblk(BLK_W, TFR_H, Y_TFR, "Transformer", PURPLE, 11)
+        norm2 = mkblk(BLK_W, BLK_H, Y_NORM2, "RMS Norm", GREEN, 9)
+        head = mkblk(BLK_W, BLK_H, Y_HEAD, "LM Head", RED, 9)
+        pipeline = VGroup(emb_node, norm1, tfr_node, norm2, head)
+
+        # Arrows between blocks
+        arrows = VGroup()
+        for a, b in zip(pipeline[:-1], pipeline[1:]):
+            arr = Arrow(a.get_bottom(), b.get_top(),
+                        color=GRAY, stroke_width=1.2, tip_length=0.07)
+            arrows.add(arr)
+
+        # ── KV Cache (right of Transformer) ──
+        KV_X = 1.8
+        kv_size = 0.16
+        kv_gap = 0.04
+        K_Y = Y_TFR + 0.05
+        V_Y = Y_TFR - 0.22
+        cache_lbl = Text("KV Cache", font_size=8, color=GRAY).move_to([KV_X, Y_TFR + TFR_H / 2 + 0.2, 0])
+        k_hdr = Text("K:", font_size=7, color=YELLOW).move_to([KV_X - 0.65, K_Y, 0])
+        v_hdr = Text("V:", font_size=7, color=ORANGE).move_to([KV_X - 0.65, V_Y, 0])
+
+        kv_k = VGroup()
+        kv_v = VGroup()
+        k0 = Square(kv_size, fill_color=YELLOW, fill_opacity=0.4,
+                    stroke_color=YELLOW, stroke_width=0.5).move_to([KV_X, K_Y, 0])
+        v0 = Square(kv_size, fill_color=ORANGE, fill_opacity=0.4,
+                    stroke_color=ORANGE, stroke_width=0.5).move_to([KV_X, V_Y, 0])
+        kv_k.add(k0); kv_v.add(v0)
+        kv_group = VGroup(cache_lbl, k_hdr, v_hdr, kv_k, kv_v)
+
+        # ── Distribution builder ──
+        def build_dist(probs, y_center, max_w=2.8):
+            """Return VGroup(bars, labels): one stacked bar per token, width pct/100 of max_w."""
+            bars = VGroup(); lbls = VGroup()
+            bh = 0.12; bg = 0.02; lx = CX_BLK - max_w / 2
+            items = list(probs.items())
+            n = len(items)
+            y_top = y_center + (n * bh + (n - 1) * bg) / 2
+            for i, (tok, pct) in enumerate(items):
+                w = max_w * pct / 100
+                y = y_top - i * (bh + bg)
+                bar = Rectangle(width=w, height=bh,
+                                fill_color=interpolate_color(BLUE, RED, pct / 100),
+                                fill_opacity=0.85, stroke_color=LIGHT_GRAY, stroke_width=0.3)
+                bar.move_to([lx + w / 2, y, 0])
+                lbl = Text(f"{tok} {pct}%", font_size=7, color=WHITE)
+                lbl.next_to(bar, RIGHT, buff=0.05)
+                bars.add(bar); lbls.add(lbl)
+            return VGroup(bars, lbls)
+
+        dists = [
+            {"The": 72, "cat": 8, "sat": 6, "on": 4, "": 10},
+            {"cat": 65, "sat": 12, "was": 8, "is": 6, "": 9},
+            {"sat": 58, "slept": 15, "ran": 8, "jumped": 6, "": 13},
+            {"on": 55, "down": 12, "quietly": 8, "and": 6, "": 19},
+            {"the": 60, "a": 12, "top": 8, "floor": 5, "": 15},
+            {"mat": 50, "table": 15, "chair": 8, "floor": 6, "": 21},
+        ]
+        Y_DIST = -1.45
+
+        # ── Token sequence row ──
+        SX = -4.0; Y_SEQ = 2.0; GAP = 0.55
+        seq = VGroup()
+        sos = tok_card("", BLUE, BLUE).move_to([SX, Y_SEQ, 0])
+        seq.add(sos)
+
+        # Step label (below everything)
+        step_lbl = Text("Step 0 — [] → ?", font_size=9, color=GRAY)
+        step_lbl.move_to([0, -2.1, 0])
+
+        # ── Show static elements ──
+        self.play(FadeIn(pipeline), FadeIn(arrows))
+        self.play(FadeIn(kv_group))
+        self.play(FadeIn(sos))
+        self.play(Write(step_lbl))
+        self.wait(0.5)
+
+        # ── Generation loop ──
+        for i, tok in enumerate(gen_tokens[1:], start=1):
+            input_str = " ".join(gen_tokens[:i + 1])
+
+            # 1. Highlight last token (the input being processed)
+            last = seq[-1]
+            hl = SurroundingRectangle(last, color=YELLOW, stroke_width=2, buff=0.04)
+            self.play(Create(hl), run_time=0.12)
+
+            # 2. Arrow from last token → Embedding
+            in_arr = Arrow(last.get_bottom(), emb_node.get_top(), color=YELLOW,
+                           stroke_width=2, tip_length=0.08)
+            self.play(FadeIn(in_arr, scale=0.5), run_time=0.1)
+
+            # 3. Cascade through pipeline
+            pipes = [emb_node, norm1, tfr_node, norm2, head]
+            self.play(
+                *[p[0].animate.set_fill_opacity(0.6) for p in pipes],
+                run_time=0.12
+            )
+            self.play(
+                *[p[0].animate.set_fill_opacity(0.25) for p in pipes],
+                run_time=0.1
+            )
+            self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08)
+
+            # 4. Show probability distribution
+            dist_arr = Arrow(head.get_bottom(), head.get_bottom() + DOWN * 0.4,
+                             color=GRAY, stroke_width=1, tip_length=0.06)
+            dist = build_dist(dists[i - 1], Y_DIST)
+            self.play(FadeIn(dist_arr, scale=0.5), FadeIn(dist, scale=0.8), run_time=0.25)
+
+            # 5. Sampling: highlight top bar
+            top_bar = dist[0][0]
+            sample_hl = SurroundingRectangle(top_bar, color=YELLOW, stroke_width=1.5, buff=0.02)
+            samp_lbl = Text("\u2713 argmax", font_size=7, color=YELLOW)
+            samp_lbl.next_to(dist, DOWN, buff=0.05)
+            self.play(Create(sample_hl), Write(samp_lbl), run_time=0.2)
+            self.wait(0.15)
+
+            # 6. Predicted token appears below distribution
+            pred = tok_card(tok, YELLOW, YELLOW)
+            target_x = SX + len(seq) * GAP
+            pred.move_to([CX_BLK, Y_DIST - 0.45, 0])
+            self.play(FadeIn(pred, scale=0.4), run_time=0.15)
+
+            # 7. Token rises to join sequence at top
+            self.play(pred.animate.move_to([target_x, Y_SEQ, 0]), run_time=0.3)
+            self.play(
+                pred[0].animate.set_fill(DARK_BLUE, 0.5).set_stroke(GRAY, 0.5),
+                pred[1].animate.set_color(WHITE),
+                FadeOut(sample_hl), FadeOut(samp_lbl),
+            )
+            seq.add(pred)
+
+            # 8. KV Cache: add K,V of this predicted token
+            kv_x = KV_X + i * (kv_size + kv_gap)
+            k_sq = Square(kv_size, fill_color=YELLOW, fill_opacity=0.4,
+                          stroke_color=YELLOW, stroke_width=0.5).move_to([kv_x, K_Y, 0])
+            v_sq = Square(kv_size, fill_color=ORANGE, fill_opacity=0.4,
+                          stroke_color=ORANGE, stroke_width=0.5).move_to([kv_x, V_Y, 0])
+            self.play(FadeIn(k_sq, scale=1.5), FadeIn(v_sq, scale=1.5), run_time=0.15)
+            kv_k.add(k_sq); kv_v.add(v_sq)
+
+            # 9. Remove distribution and arrow
+            self.play(FadeOut(dist), FadeOut(dist_arr), run_time=0.08)
+
+            # 10. Update step label
+            new_lbl = Text(f"Step {i} — [{input_str}]", font_size=9, color=GRAY)
+            new_lbl.move_to(step_lbl)
+            self.play(Transform(step_lbl, new_lbl))
+            self.wait(0.3)
+
+        self.wait(2.0)
+
         def orth_line(start, end, color=GRAY):
             """Create an L-shaped orthogonal line from start to end."""
 