refactor: transformer — heatmap two-phase scores+mask, auto-regressive full I/O pipeline with Emb, RMS Norm, LM Head, distribution
This commit is contained in:
parent
496f964979
commit
0018868ee3
272
transformer.py
272
transformer.py
|
|
@ -6,6 +6,7 @@ Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow line
|
||||||
|
|
||||||
from manim import *
|
from manim import *
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
class Transformer(Scene):
|
class Transformer(Scene):
|
||||||
|
|
@ -271,39 +272,48 @@ class Transformer(Scene):
|
||||||
grid_left = -grid_high / 2
|
grid_left = -grid_high / 2
|
||||||
grid_top = 1.4
|
grid_top = 1.4
|
||||||
|
|
||||||
# attention weights (after softmax + causal mask)
|
# pre-mask raw scores (QK^T / sqrt(d_k)) — random-varied, distance-biased
|
||||||
weights = [
|
pre_scores = [
|
||||||
[1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
|
[2.8, 1.5, 0.3, 0.1, 0.0, 0.0, 0.0],
|
||||||
[0.05, 0.95, 0.00, 0.00, 0.00, 0.00, 0.00],
|
[1.2, 3.5, 1.8, 0.5, 0.2, 0.1, 0.0],
|
||||||
[0.02, 0.20, 0.78, 0.00, 0.00, 0.00, 0.00],
|
[0.4, 2.0, 3.0, 1.5, 0.6, 0.2, 0.1],
|
||||||
[0.01, 0.05, 0.40, 0.54, 0.00, 0.00, 0.00],
|
[0.1, 0.6, 2.5, 2.8, 1.2, 0.4, 0.1],
|
||||||
[0.00, 0.02, 0.07, 0.35, 0.56, 0.00, 0.00],
|
[0.0, 0.2, 0.8, 2.0, 2.5, 1.5, 0.3],
|
||||||
[0.00, 0.01, 0.03, 0.10, 0.30, 0.56, 0.00],
|
[0.0, 0.1, 0.3, 0.9, 1.8, 2.5, 1.2],
|
||||||
[0.00, 0.00, 0.01, 0.05, 0.12, 0.35, 0.47],
|
[0.0, 0.0, 0.1, 0.4, 0.8, 1.5, 3.0],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# compute post-softmax weights with causal mask (j > i → -inf)
|
||||||
|
post_weights = []
|
||||||
|
for i in range(n):
|
||||||
|
row = pre_scores[i]
|
||||||
|
masked = [-float('inf') if j > i else row[j] for j in range(n)]
|
||||||
|
exps = [math.exp(v) for v in masked]
|
||||||
|
exp_sum = sum(exps)
|
||||||
|
post_weights.append([e / exp_sum for e in exps])
|
||||||
|
|
||||||
|
flat_pre = [w for row in pre_scores for w in row]
|
||||||
|
pre_min, pre_max = min(flat_pre), max(flat_pre)
|
||||||
|
flat_post = [w for row in post_weights for w in row]
|
||||||
|
post_min, post_max = min(flat_post), max(flat_post)
|
||||||
cells = VGroup()
|
cells = VGroup()
|
||||||
|
masked_cells = VGroup()
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
for j in range(n):
|
for j in range(n):
|
||||||
w = weights[i][j]
|
pw = pre_scores[i][j]
|
||||||
if j > i:
|
pw_normed = (pw - pre_min) / (pre_max - pre_min)
|
||||||
color = DARK_GRAY
|
pw_color = interpolate_color(BLUE, RED, pw_normed)
|
||||||
fill_op = 0.15
|
|
||||||
elif w < 0.001:
|
|
||||||
color = DARKER_GRAY
|
|
||||||
fill_op = 0.2
|
|
||||||
else:
|
|
||||||
color = interpolate_color(BLUE, RED, w)
|
|
||||||
fill_op = 0.75
|
|
||||||
sq = Square(
|
sq = Square(
|
||||||
side_length=cell_size, fill_color=color,
|
side_length=cell_size, fill_color=pw_color,
|
||||||
fill_opacity=fill_op, stroke_width=0.5,
|
fill_opacity=0.75, stroke_width=0.5,
|
||||||
stroke_color=GRAY,
|
stroke_color=GRAY,
|
||||||
)
|
)
|
||||||
x = grid_left + j * (cell_size + gap) + cell_size / 2
|
x = grid_left + j * (cell_size + gap) + cell_size / 2
|
||||||
y = grid_top - i * (cell_size + gap) - cell_size / 2
|
y = grid_top - i * (cell_size + gap) - cell_size / 2
|
||||||
sq.move_to([x, y, 0])
|
sq.move_to([x, y, 0])
|
||||||
cells.add(sq)
|
cells.add(sq)
|
||||||
|
if j > i:
|
||||||
|
masked_cells.add(sq)
|
||||||
self.play(FadeIn(sq, scale=0.6), run_time=0.015)
|
self.play(FadeIn(sq, scale=0.6), run_time=0.015)
|
||||||
|
|
||||||
# row labels (query) on the left
|
# row labels (query) on the left
|
||||||
|
|
@ -329,25 +339,21 @@ class Transformer(Scene):
|
||||||
self.play(*[Write(l) for l in col_lbls], Write(k_label))
|
self.play(*[Write(l) for l in col_lbls], Write(k_label))
|
||||||
self.wait(1.0)
|
self.wait(1.0)
|
||||||
|
|
||||||
# causal mask — per-cell red overlay aligned to grid
|
# causal mask + softmax — zero out future tokens, recompute weights
|
||||||
mask_overlays = VGroup()
|
causal_txt = Text("causal mask + softmax\n(future tokens → 0)", font_size=11, color=RED) \
|
||||||
|
.next_to(cells[n - 1], UP, buff=0.25).align_to(cells[n - 1], RIGHT)
|
||||||
|
anims = [sq.animate.set_fill(DARK_GRAY, 0.15) for sq in masked_cells]
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
for j in range(n):
|
for j in range(n):
|
||||||
if j > i:
|
if j <= i:
|
||||||
x = grid_left + j * (cell_size + gap) + cell_size / 2
|
idx = i * n + j
|
||||||
y = grid_top - i * (cell_size + gap) - cell_size / 2
|
aw = post_weights[i][j]
|
||||||
sq = Square(
|
aw_normed = (aw - post_min) / (post_max - post_min)
|
||||||
side_length=cell_size, fill_color=RED,
|
aw_color = interpolate_color(BLUE, RED, aw_normed)
|
||||||
fill_opacity=0.10, stroke_width=0.5,
|
anims.append(cells[idx].animate.set_fill(aw_color, 0.75))
|
||||||
stroke_color=RED, stroke_opacity=0.3,
|
self.play(*anims, Write(causal_txt))
|
||||||
)
|
self.wait(1.2)
|
||||||
sq.move_to([x, y, 0])
|
self.play(FadeOut(causal_txt))
|
||||||
mask_overlays.add(sq)
|
|
||||||
causal_txt = Text("causal mask\n(future tokens hidden)", font_size=11, color=RED) \
|
|
||||||
.next_to(cells[6], UP, buff=0.25).align_to(cells[6], RIGHT)
|
|
||||||
self.play(FadeIn(mask_overlays), Write(causal_txt))
|
|
||||||
self.wait(1.5)
|
|
||||||
self.play(FadeOut(mask_overlays), FadeOut(causal_txt))
|
|
||||||
|
|
||||||
# highlight key patterns
|
# highlight key patterns
|
||||||
h1 = SurroundingRectangle(cells[2 * n + 1], color=ORANGE, stroke_width=2, buff=0.04)
|
h1 = SurroundingRectangle(cells[2 * n + 1], color=ORANGE, stroke_width=2, buff=0.04)
|
||||||
|
|
@ -368,6 +374,196 @@ class Transformer(Scene):
|
||||||
FadeOut(q_label), FadeOut(k_label),
|
FadeOut(q_label), FadeOut(k_label),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
# Auto-regressive Generation Demo (v2: full I/O pipeline)
|
||||||
|
def tok_card(text, fill=DARK_BLUE, stroke=GRAY):
|
||||||
|
t = Text(text, font_size=10, color=WHITE, weight=BOLD)
|
||||||
|
box = RoundedRectangle(
|
||||||
|
width=t.width + 0.2, height=t.height + 0.1,
|
||||||
|
corner_radius=0.04, fill_color=fill, fill_opacity=0.5,
|
||||||
|
stroke_color=stroke, stroke_width=0.5,
|
||||||
|
)
|
||||||
|
t.move_to(box)
|
||||||
|
return VGroup(box, t)
|
||||||
|
|
||||||
|
gen_tokens = ["<s>", "The", "cat", "sat", "on", "the", "mat"]
|
||||||
|
|
||||||
|
gen_title = Text("Auto-regressive Generation", font_size=34, color=BLUE)
|
||||||
|
gen_title.to_edge(UP, buff=0.35)
|
||||||
|
self.play(Write(gen_title))
|
||||||
|
|
||||||
|
# ── Layout constants ──
|
||||||
|
BLK_W = 1.8
|
||||||
|
BLK_H = 0.28
|
||||||
|
TFR_H = 0.48
|
||||||
|
CX_BLK = -1.0
|
||||||
|
Y_EMB = 1.40
|
||||||
|
Y_NORM1 = 0.95
|
||||||
|
Y_TFR = 0.32
|
||||||
|
Y_NORM2 = -0.25
|
||||||
|
Y_HEAD = -0.70
|
||||||
|
|
||||||
|
def mkblk(w, h, y, txt, color, fs=10):
|
||||||
|
b = RoundedRectangle(width=w, height=h, corner_radius=0.06,
|
||||||
|
fill_color=DARK_BLUE, fill_opacity=0.25,
|
||||||
|
stroke_color=color, stroke_width=1.5)
|
||||||
|
l = Text(txt, font_size=fs, color=color, weight=BOLD).move_to(b)
|
||||||
|
return VGroup(b, l).move_to([CX_BLK, y, 0])
|
||||||
|
|
||||||
|
emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9)
|
||||||
|
norm1 = mkblk(BLK_W, BLK_H, Y_NORM1, "RMS Norm", GREEN, 9)
|
||||||
|
tfr_node = mkblk(BLK_W, TFR_H, Y_TFR, "Transformer", PURPLE, 11)
|
||||||
|
norm2 = mkblk(BLK_W, BLK_H, Y_NORM2, "RMS Norm", GREEN, 9)
|
||||||
|
head = mkblk(BLK_W, BLK_H, Y_HEAD, "LM Head", RED, 9)
|
||||||
|
pipeline = VGroup(emb_node, norm1, tfr_node, norm2, head)
|
||||||
|
|
||||||
|
# Arrows between blocks
|
||||||
|
arrows = VGroup()
|
||||||
|
for a, b in zip(pipeline[:-1], pipeline[1:]):
|
||||||
|
arr = Arrow(a.get_bottom(), b.get_top(),
|
||||||
|
color=GRAY, stroke_width=1.2, tip_length=0.07)
|
||||||
|
arrows.add(arr)
|
||||||
|
|
||||||
|
# ── KV Cache (right of Transformer) ──
|
||||||
|
KV_X = 1.8
|
||||||
|
kv_size = 0.16
|
||||||
|
kv_gap = 0.04
|
||||||
|
K_Y = Y_TFR + 0.05
|
||||||
|
V_Y = Y_TFR - 0.22
|
||||||
|
cache_lbl = Text("KV Cache", font_size=8, color=GRAY).move_to([KV_X, Y_TFR + TFR_H / 2 + 0.2, 0])
|
||||||
|
k_hdr = Text("K:", font_size=7, color=YELLOW).move_to([KV_X - 0.65, K_Y, 0])
|
||||||
|
v_hdr = Text("V:", font_size=7, color=ORANGE).move_to([KV_X - 0.65, V_Y, 0])
|
||||||
|
|
||||||
|
kv_k = VGroup()
|
||||||
|
kv_v = VGroup()
|
||||||
|
k0 = Square(kv_size, fill_color=YELLOW, fill_opacity=0.4,
|
||||||
|
stroke_color=YELLOW, stroke_width=0.5).move_to([KV_X, K_Y, 0])
|
||||||
|
v0 = Square(kv_size, fill_color=ORANGE, fill_opacity=0.4,
|
||||||
|
stroke_color=ORANGE, stroke_width=0.5).move_to([KV_X, V_Y, 0])
|
||||||
|
kv_k.add(k0); kv_v.add(v0)
|
||||||
|
kv_group = VGroup(cache_lbl, k_hdr, v_hdr, kv_k, kv_v)
|
||||||
|
|
||||||
|
# ── Distribution builder ──
|
||||||
|
def build_dist(probs, y_center, max_w=2.8):
|
||||||
|
bars = VGroup(); lbls = VGroup()
|
||||||
|
bh = 0.12; bg = 0.02; lx = CX_BLK - max_w / 2
|
||||||
|
items = list(probs.items())
|
||||||
|
n = len(items)
|
||||||
|
y_top = y_center + (n * bh + (n - 1) * bg) / 2
|
||||||
|
for i, (tok, pct) in enumerate(items):
|
||||||
|
w = max_w * pct / 100
|
||||||
|
y = y_top - i * (bh + bg)
|
||||||
|
bar = Rectangle(width=w, height=bh,
|
||||||
|
fill_color=interpolate_color(BLUE, RED, pct / 100),
|
||||||
|
fill_opacity=0.85, stroke_color=LIGHT_GRAY, stroke_width=0.3)
|
||||||
|
bar.move_to([lx + w / 2, y, 0])
|
||||||
|
lbl = Text(f"{tok} {pct}%", font_size=7, color=WHITE)
|
||||||
|
lbl.next_to(bar, RIGHT, buff=0.05)
|
||||||
|
bars.add(bar); lbls.add(lbl)
|
||||||
|
return VGroup(bars, lbls)
|
||||||
|
|
||||||
|
dists = [
|
||||||
|
{"The": 72, "cat": 8, "sat": 6, "on": 4, "<unk>": 10},
|
||||||
|
{"cat": 65, "sat": 12, "was": 8, "is": 6, "<unk>": 9},
|
||||||
|
{"sat": 58, "slept": 15, "ran": 8, "jumped": 6, "<unk>": 13},
|
||||||
|
{"on": 55, "down": 12, "quietly": 8, "and": 6, "<unk>": 19},
|
||||||
|
{"the": 60, "a": 12, "top": 8, "floor": 5, "<unk>": 15},
|
||||||
|
{"mat": 50, "table": 15, "chair": 8, "floor": 6, "<unk>": 21},
|
||||||
|
]
|
||||||
|
Y_DIST = -1.45
|
||||||
|
|
||||||
|
# ── Token sequence row ──
|
||||||
|
SX = -4.0; Y_SEQ = 2.0; GAP = 0.55
|
||||||
|
seq = VGroup()
|
||||||
|
sos = tok_card("<s>", BLUE, BLUE).move_to([SX, Y_SEQ, 0])
|
||||||
|
seq.add(sos)
|
||||||
|
|
||||||
|
# Step label (below everything)
|
||||||
|
step_lbl = Text("Step 0 — [<s>] → ?", font_size=9, color=GRAY)
|
||||||
|
step_lbl.move_to([0, -2.1, 0])
|
||||||
|
|
||||||
|
# ── Show static elements ──
|
||||||
|
self.play(FadeIn(pipeline), FadeIn(arrows))
|
||||||
|
self.play(FadeIn(kv_group))
|
||||||
|
self.play(FadeIn(sos))
|
||||||
|
self.play(Write(step_lbl))
|
||||||
|
self.wait(0.5)
|
||||||
|
|
||||||
|
# ── Generation loop ──
|
||||||
|
for i, tok in enumerate(gen_tokens[1:], start=1):
|
||||||
|
input_str = " ".join(gen_tokens[:i + 1])
|
||||||
|
|
||||||
|
# 1. Highlight last token (the input being processed)
|
||||||
|
last = seq[-1]
|
||||||
|
hl = SurroundingRectangle(last, color=YELLOW, stroke_width=2, buff=0.04)
|
||||||
|
self.play(Create(hl), run_time=0.12)
|
||||||
|
|
||||||
|
# 2. Arrow from last token → Embedding
|
||||||
|
in_arr = Arrow(last.get_bottom(), emb_node.get_top(), color=YELLOW,
|
||||||
|
stroke_width=2, tip_length=0.08)
|
||||||
|
self.play(FadeIn(in_arr, scale=0.5), run_time=0.1)
|
||||||
|
|
||||||
|
# 3. Cascade through pipeline
|
||||||
|
pipes = [emb_node, norm1, tfr_node, norm2, head]
|
||||||
|
self.play(
|
||||||
|
*[p[0].animate.set_fill_opacity(0.6) for p in pipes],
|
||||||
|
run_time=0.12
|
||||||
|
)
|
||||||
|
self.play(
|
||||||
|
*[p[0].animate.set_fill_opacity(0.25) for p in pipes],
|
||||||
|
run_time=0.1
|
||||||
|
)
|
||||||
|
self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08)
|
||||||
|
|
||||||
|
# 4. Show probability distribution
|
||||||
|
dist_arr = Arrow(head.get_bottom(), head.get_bottom() + DOWN * 0.4,
|
||||||
|
color=GRAY, stroke_width=1, tip_length=0.06)
|
||||||
|
dist = build_dist(dists[i - 1], Y_DIST)
|
||||||
|
self.play(FadeIn(dist_arr, scale=0.5), FadeIn(dist, scale=0.8), run_time=0.25)
|
||||||
|
|
||||||
|
# 5. Sampling: highlight top bar
|
||||||
|
top_bar = dist[0][0]
|
||||||
|
sample_hl = SurroundingRectangle(top_bar, color=YELLOW, stroke_width=1.5, buff=0.02)
|
||||||
|
samp_lbl = Text("\u2713 argmax", font_size=7, color=YELLOW)
|
||||||
|
samp_lbl.next_to(dist, DOWN, buff=0.05)
|
||||||
|
self.play(Create(sample_hl), Write(samp_lbl), run_time=0.2)
|
||||||
|
self.wait(0.15)
|
||||||
|
|
||||||
|
# 6. Predicted token appears below distribution
|
||||||
|
pred = tok_card(tok, YELLOW, YELLOW)
|
||||||
|
target_x = SX + len(seq) * GAP
|
||||||
|
pred.move_to([CX_BLK, Y_DIST - 0.45, 0])
|
||||||
|
self.play(FadeIn(pred, scale=0.4), run_time=0.15)
|
||||||
|
|
||||||
|
# 7. Token rises to join sequence at top
|
||||||
|
self.play(pred.animate.move_to([target_x, Y_SEQ, 0]), run_time=0.3)
|
||||||
|
self.play(
|
||||||
|
pred[0].animate.set_fill(DARK_BLUE, 0.5).set_stroke(GRAY, 0.5),
|
||||||
|
pred[1].animate.set_color(WHITE),
|
||||||
|
FadeOut(sample_hl), FadeOut(samp_lbl),
|
||||||
|
)
|
||||||
|
seq.add(pred)
|
||||||
|
|
||||||
|
# 8. KV Cache: add K,V of this predicted token
|
||||||
|
kv_x = KV_X + i * (kv_size + kv_gap)
|
||||||
|
k_sq = Square(kv_size, fill_color=YELLOW, fill_opacity=0.4,
|
||||||
|
stroke_color=YELLOW, stroke_width=0.5).move_to([kv_x, K_Y, 0])
|
||||||
|
v_sq = Square(kv_size, fill_color=ORANGE, fill_opacity=0.4,
|
||||||
|
stroke_color=ORANGE, stroke_width=0.5).move_to([kv_x, V_Y, 0])
|
||||||
|
self.play(FadeIn(k_sq, scale=1.5), FadeIn(v_sq, scale=1.5), run_time=0.15)
|
||||||
|
kv_k.add(k_sq); kv_v.add(v_sq)
|
||||||
|
|
||||||
|
# 9. Remove distribution and arrow
|
||||||
|
self.play(FadeOut(dist), FadeOut(dist_arr), run_time=0.08)
|
||||||
|
|
||||||
|
# 10. Update step label
|
||||||
|
new_lbl = Text(f"Step {i} — [{input_str}]", font_size=9, color=GRAY)
|
||||||
|
new_lbl.move_to(step_lbl)
|
||||||
|
self.play(Transform(step_lbl, new_lbl))
|
||||||
|
self.wait(0.3)
|
||||||
|
|
||||||
|
self.wait(2.0)
|
||||||
|
|
||||||
|
|
||||||
def orth_line(start, end, color=GRAY):
|
def orth_line(start, end, color=GRAY):
|
||||||
"""Create an L-shaped orthogonal line from start to end."""
|
"""Create an L-shaped orthogonal line from start to end."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue