video-promo/transformer.py

574 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""AstrAI promo: Transformer GQA attention animation.
Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow lines:
Input → Q/K/V Projections → Repeat KV → SDPA → O Projection → Output
"""
from manim import *
import numpy as np
import math
class Transformer(Scene):
    """Promo scene: GQA diagram, SDPA formula, causal heatmap, decoding demo.

    ``construct`` plays four segments back to back. Each segment lives in its
    own private ``_show_*`` method so it can be read (and re-timed) in
    isolation; the on-screen result and the order of animations are the same
    as when everything was inlined in ``construct``.
    """

    # Example sentence shared by the heatmap and the generation demo.
    _TOKENS = ("<s>", "The", "cat", "sat", "on", "the", "mat")

    def construct(self):
        """Play the four segments in order."""
        self._show_gqa_diagram()
        sdpa_title = self._show_sdpa_formula()
        # The heatmap segment fades the previous title out while writing its own.
        self._show_attention_heatmap(sdpa_title)
        self._show_generation_demo()

    # ═══════════════════════════════════════════════════
    # Segment 1. GQA block diagram with orthogonal wiring
    # ═══════════════════════════════════════════════════
    def _show_gqa_diagram(self):
        """Draw the GQA block diagram, wire it up and play three callouts.

        Every mobject added by this segment is faded out before returning.
        """
        title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
        title.to_edge(UP, buff=0.35)
        self.play(Write(title))

        # ── Helpers ──
        def mk(name, color, w=2.6, h=0.72, fs=10):
            """Return a VGroup holding a tinted rectangle and a text label."""
            box = Rectangle(
                width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
            )
            lbl = Text(name, font_size=fs, color=color)
            return VGroup(box, lbl)

        def wire(start, end):
            """Return a straight connector in the diagram's shared line style."""
            return Line(start, end, color=GRAY, stroke_width=1.5)

        def at(x, y):
            """Return the point (x, y) on the z = 0 plane."""
            return np.array([x, y, 0])

        # ── Layout ──
        inp = Text("x (hidden states)", font_size=15, color=GRAY)
        inp.move_to(UP * 2.6)

        y1 = 1.5
        q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
        k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
        v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
        q_grp.move_to(LEFT * 3.0 + UP * y1)
        k_grp.move_to(UP * y1)
        v_grp.move_to(RIGHT * 3.0 + UP * y1)

        y2 = 0.0
        repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
        repeat_grp.move_to(UP * y2)

        y3 = -1.6
        sdpa_grp = mk(
            "Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
        )
        sdpa_grp.move_to(UP * y3)

        y4 = -2.9
        o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
        o_grp.move_to(UP * y4)

        out = Text("x' (hidden states)", font_size=15, color=GRAY)
        out.next_to(o_grp, DOWN, buff=0.4)

        # ── Animate boxes ──
        self.play(Write(inp))
        all_boxes = [q_grp, k_grp, v_grp, repeat_grp, sdpa_grp, o_grp]
        for g in all_boxes:
            self.play(FadeIn(g, shift=UP * 0.1), run_time=0.2)

        # ── Input trunk → branch → Q/K/V (enter from directly above) ──
        q_top, k_top, v_top = q_grp.get_top(), k_grp.get_top(), v_grp.get_top()
        bus_y = q_top[1] + 0.35  # height of the horizontal bus above the boxes
        trunk = wire(inp.get_bottom(), at(0, bus_y))
        self.play(Create(trunk), run_time=0.15)

        branch_left = wire(at(q_top[0], bus_y), at(k_top[0], bus_y))
        branch_right = wire(at(k_top[0], bus_y), at(v_top[0], bus_y))
        self.play(Create(branch_left), Create(branch_right), run_time=0.2)

        drops = [wire(at(top[0], bus_y), top) for top in (q_top, k_top, v_top)]
        for ln in drops:
            self.play(Create(ln), run_time=0.12)
        input_lines = VGroup(trunk, branch_left, branch_right, *drops)

        # ── K/V → Repeat KV (trunk-branch, enter from above) ──
        k_bot, v_bot = k_grp.get_bottom(), v_grp.get_bottom()
        kv_junc_y = repeat_grp.get_top()[1] + 0.3
        kv_lines = VGroup(
            wire(k_bot, at(k_bot[0], kv_junc_y)),
            wire(v_bot, at(v_bot[0], kv_junc_y)),
            wire(at(v_bot[0], kv_junc_y), at(k_bot[0], kv_junc_y)),
            wire(at(k_bot[0], kv_junc_y), repeat_grp.get_top()),
        )
        self.play(Create(kv_lines), run_time=0.3)

        # ── Q → SDPA (bypasses Repeat KV, from above) ──
        q_bot = q_grp.get_bottom()
        sdpa_top = sdpa_grp.get_top()
        qs_junc_y = sdpa_top[1] + 0.3
        line_qs = VMobject(color=GRAY, stroke_width=1.5)
        line_qs.set_points_as_corners([
            q_bot,
            at(q_bot[0], qs_junc_y),
            at(sdpa_top[0], qs_junc_y),
            sdpa_top,
        ])
        self.play(Create(line_qs), run_time=0.15)

        # ── Central column: Repeat KV → SDPA → O Projection → output ──
        line_rs = orth_line(repeat_grp.get_bottom(), sdpa_grp.get_top(), GRAY)
        self.play(Create(line_rs), run_time=0.15)
        line_so = orth_line(sdpa_grp.get_bottom(), o_grp.get_top(), GRAY)
        self.play(Create(line_so), run_time=0.15)
        line_oo = orth_line(o_grp.get_bottom(), out.get_top(), GRAY)
        self.play(Create(line_oo), run_time=0.15)
        self.play(Write(out))
        self.wait(0.4)

        all_lines = VGroup(
            input_lines, kv_lines, line_qs,
            line_rs, line_so, line_oo,
        )

        # ── RoPE highlight ──
        rope_q = SurroundingRectangle(q_grp, color=TEAL, buff=0.12)
        rope_k = SurroundingRectangle(k_grp, color=TEAL, buff=0.12)
        rope_t = Text(
            "RoPE: rotary position encoding\napplied to Q and K",
            font_size=13, color=TEAL,
        )
        rope_t.next_to(VGroup(rope_q, rope_k), UP, buff=0.25)
        self.play(Create(rope_q), Create(rope_k), Write(rope_t))
        self.wait(1.5)
        self.play(FadeOut(rope_q), FadeOut(rope_k), FadeOut(rope_t))

        # ── GQA ratio highlight ──
        gqa_h = SurroundingRectangle(
            VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
        )
        gqa_t = Text(
            "GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
            font_size=13, color=YELLOW,
        )
        gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
        self.play(Create(gqa_h), Write(gqa_t))
        self.wait(1.8)
        self.play(FadeOut(gqa_h), FadeOut(gqa_t))

        # ── Repeat KV highlight (stays up until the final fade) ──
        kv_h = SurroundingRectangle(
            VGroup(k_grp, v_grp), color=GREEN, buff=0.12
        )
        kv_t = Text(
            "repeat_kv(): broadcast\n4 heads → 24 heads",
            font_size=12, color=GREEN,
        )
        kv_t.next_to(kv_h, RIGHT, buff=0.5)
        self.play(Create(kv_h), Write(kv_t))
        self.wait(1.5)

        # ── Fade all ──
        self.play(
            *[FadeOut(g) for g in all_boxes],
            FadeOut(all_lines),
            FadeOut(kv_h), FadeOut(kv_t),
            FadeOut(inp), FadeOut(out), FadeOut(title),
        )

    # ═══════════════════════════════════════════════════
    # Segment 2. Scaled Dot-Product Attention — full formula + breakdown
    # ═══════════════════════════════════════════════════
    def _show_sdpa_formula(self):
        """Write the attention formula, the Q/K/V glosses and a 4-step breakdown.

        Returns:
            The segment title mobject, which is left on screen so the caller
            can cross-fade it into the next segment's title.
        """
        qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
        qkv_title.to_edge(UP, buff=0.35)
        self.play(Write(qkv_title))

        # Full formula first, stays on screen
        full_eq = MathTex(
            r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
            font_size=36, color=WHITE,
        )
        full_eq.next_to(qkv_title, DOWN, buff=0.5)
        self.play(Write(full_eq))
        self.wait(0.6)

        # Q / K / V — brief meanings
        q_txt = Text("Q = Query", font_size=18, color=YELLOW)
        k_txt = Text("K = Key", font_size=18, color=ORANGE)
        v_txt = Text("V = Value", font_size=18, color=GREEN)
        qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
        qkv_labels.next_to(full_eq, DOWN, buff=0.5)
        self.play(Write(qkv_labels))

        q_desc = Text("\"what am I looking for?\"", font_size=11, color=YELLOW)
        q_desc.next_to(q_txt, DOWN, buff=0.10)
        k_desc = Text("\"what do I have?\"", font_size=11, color=ORANGE)
        k_desc.next_to(k_txt, DOWN, buff=0.10)
        v_desc = Text("\"what do I contribute?\"", font_size=11, color=GREEN)
        v_desc.next_to(v_txt, DOWN, buff=0.10)
        self.play(Write(q_desc), Write(k_desc), Write(v_desc))
        self.wait(2.0)
        self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))

        # Step-by-step decomposition — full formula stays visible above
        steps = [
            (MathTex(r"\text{(1) } S = QK^\top", font_size=26, color=YELLOW),
             Text("score matrix — pairwise token similarity", font_size=12, color=GRAY)),
            (MathTex(r"\text{(2) } S \mathbin{/} \sqrt{d_k}", font_size=26, color=ORANGE),
             Text("scale — prevents gradient explosion", font_size=12, color=GRAY)),
            (MathTex(r"\text{(3) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right)", font_size=26, color=GREEN),
             Text("normalize — each row sums to 1 (probability)", font_size=12, color=GRAY)),
            (MathTex(r"\text{(4) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right) \cdot V", font_size=26, color=BLUE),
             Text("weighted sum — aggregate values by attention", font_size=12, color=GRAY)),
        ]
        # Each step is an equation with its caption stacked directly below it.
        step_group = VGroup(
            *[VGroup(eq, desc).arrange(DOWN, buff=0.06) for eq, desc in steps]
        )
        step_group.arrange(DOWN, buff=0.22, aligned_edge=LEFT)
        step_group.next_to(full_eq, DOWN, buff=0.6)
        for sg in step_group:
            self.play(Write(sg), run_time=0.3)
        self.wait(2.5)
        self.play(FadeOut(step_group), FadeOut(full_eq))
        return qkv_title

    # ═══════════════════════════════════════════════════
    # Segment 3. Attention score heatmap — concrete example
    # ═══════════════════════════════════════════════════
    def _show_attention_heatmap(self, prev_title):
        """Show raw scores as a grid, then apply the causal mask and softmax.

        Args:
            prev_title: Title mobject of the preceding segment; it is faded
                out in the same ``play`` call that writes this segment's title.
        """
        hm_title = Text("Attention Score Heatmap", font_size=34, color=BLUE)
        hm_title.to_edge(UP, buff=0.35)
        hm_sub = Text(
            "\"The cat sat on the mat\" — causal, per-token attention weights",
            font_size=14, color=GRAY,
        ).next_to(hm_title, DOWN, buff=0.12)
        self.play(FadeOut(prev_title), Write(hm_title), Write(hm_sub))

        tokens = self._TOKENS
        n = len(tokens)
        cell_size = 0.52
        gap = 0.04
        pitch = cell_size + gap  # centre-to-centre distance of adjacent cells
        grid_high = n * cell_size + (n - 1) * gap
        grid_left = -grid_high / 2
        grid_top = 1.4

        def col_x(j):
            """Return the x coordinate of the centre of column j."""
            return grid_left + j * pitch + cell_size / 2

        def row_y(i):
            """Return the y coordinate of the centre of row i."""
            return grid_top - i * pitch - cell_size / 2

        def heat_color(value, lo, hi):
            """Map value in [lo, hi] onto the BLUE → RED ramp."""
            span = hi - lo
            # A flat matrix has no range to normalise over; use the low end.
            alpha = (value - lo) / span if span else 0.0
            return interpolate_color(BLUE, RED, alpha)

        # Hand-written pre-mask scores, row = query token, column = key token.
        pre_scores = [
            [2.8, 1.5, 0.3, 0.1, 0.0, 0.0, 0.0],
            [1.2, 3.5, 1.8, 0.5, 0.2, 0.1, 0.0],
            [0.4, 2.0, 3.0, 1.5, 0.6, 0.2, 0.1],
            [0.1, 0.6, 2.5, 2.8, 1.2, 0.4, 0.1],
            [0.0, 0.2, 0.8, 2.0, 2.5, 1.5, 0.3],
            [0.0, 0.1, 0.3, 0.9, 1.8, 2.5, 1.2],
            [0.0, 0.0, 0.1, 0.4, 0.8, 1.5, 3.0],
        ]

        # Post-softmax weights with causal mask (j > i → -inf, so exp() is 0).
        post_weights = []
        for i, row in enumerate(pre_scores):
            exps = [
                math.exp(float("-inf") if j > i else score)
                for j, score in enumerate(row)
            ]
            exp_sum = sum(exps)
            post_weights.append([e / exp_sum for e in exps])

        flat_pre = [w for row in pre_scores for w in row]
        pre_min, pre_max = min(flat_pre), max(flat_pre)
        flat_post = [w for row in post_weights for w in row]
        post_min, post_max = min(flat_post), max(flat_post)

        # Cells are stored row-major, so cell (i, j) is cells[i * n + j].
        cells = VGroup()
        masked_cells = VGroup()
        for i in range(n):
            for j in range(n):
                sq = Square(
                    side_length=cell_size,
                    fill_color=heat_color(pre_scores[i][j], pre_min, pre_max),
                    fill_opacity=0.75, stroke_width=0.5,
                    stroke_color=GRAY,
                )
                sq.move_to([col_x(j), row_y(i), 0])
                cells.add(sq)
                if j > i:
                    masked_cells.add(sq)
                self.play(FadeIn(sq, scale=0.6), run_time=0.015)

        # row labels (query) on the left
        row_lbls = VGroup()
        for i, tok in enumerate(tokens):
            lbl = Text(tok, font_size=12, color=GRAY)
            lbl.next_to([grid_left - 0.15, row_y(i), 0], LEFT, buff=0.08)
            row_lbls.add(lbl)
        q_label = Text("Q", font_size=11, color=WHITE, weight=BOLD)
        q_label.move_to(row_lbls[0].get_left() + LEFT * 0.3).shift(UP * 0.15)
        self.play(*[Write(l) for l in row_lbls], Write(q_label))

        # column labels (key) on top
        col_lbls = VGroup()
        for j, tok in enumerate(tokens):
            lbl = Text(tok, font_size=9, color=GRAY).rotate(PI / 6)
            lbl.next_to([col_x(j), grid_top + 0.06, 0], UP, buff=0.04)
            col_lbls.add(lbl)
        k_label = Text("K", font_size=11, color=WHITE, weight=BOLD)
        k_label.next_to(col_lbls[0], UP, buff=0.06)
        self.play(*[Write(l) for l in col_lbls], Write(k_label))
        self.wait(1.0)

        # causal mask + softmax — grey out future tokens, recolour the rest
        top_right = cells[n - 1]
        causal_txt = Text(
            "causal mask + softmax\n(future tokens → 0)", font_size=11, color=RED
        ).next_to(top_right, UP, buff=0.25).align_to(top_right, RIGHT)
        anims = [sq.animate.set_fill(DARK_GRAY, 0.15) for sq in masked_cells]
        for i in range(n):
            for j in range(i + 1):
                aw_color = heat_color(post_weights[i][j], post_min, post_max)
                anims.append(cells[i * n + j].animate.set_fill(aw_color, 0.75))
        self.play(*anims, Write(causal_txt))
        self.wait(1.2)
        self.play(FadeOut(causal_txt))

        # highlight key patterns: cell (i, i - 1) for every row from the third on
        highlights = [
            SurroundingRectangle(
                cells[i * n + (i - 1)], color=ORANGE, stroke_width=2, buff=0.04
            )
            for i in range(2, n)
        ]
        hl_text = Text(
            "previous token attends to next\n(causal sequence learning)",
            font_size=11, color=ORANGE,
        ).next_to(cells[(n - 1) * n + (n - 1)], RIGHT, buff=0.8)
        self.play(*[Create(h) for h in highlights], Write(hl_text))
        self.wait(2.0)
        self.play(*[FadeOut(h) for h in highlights], FadeOut(hl_text))

        # fade all heatmap
        self.play(
            FadeOut(hm_title), FadeOut(hm_sub),
            FadeOut(cells), FadeOut(row_lbls), FadeOut(col_lbls),
            FadeOut(q_label), FadeOut(k_label),
        )

    # ═══════════════════════════════════════════════════
    # Segment 4. Auto-regressive Generation Demo (v2: full I/O pipeline)
    # ═══════════════════════════════════════════════════
    def _show_generation_demo(self):
        """Animate token-by-token decoding through a model stack with a KV cache.

        The final frame is held for two seconds; nothing is faded out.
        """
        def tok_card(text, fill=DARK_BLUE, stroke=GRAY):
            """Return a small rounded card sized to fit ``text``."""
            t = Text(text, font_size=10, color=WHITE, weight=BOLD)
            box = RoundedRectangle(
                width=t.width + 0.2, height=t.height + 0.1,
                corner_radius=0.04, fill_color=fill, fill_opacity=0.5,
                stroke_color=stroke, stroke_width=0.5,
            )
            t.move_to(box)
            return VGroup(box, t)

        gen_tokens = self._TOKENS
        gen_title = Text("Auto-regressive Generation", font_size=34, color=BLUE)
        gen_title.to_edge(UP, buff=0.35)
        self.play(Write(gen_title))

        # ── Layout constants ──
        BLK_W = 1.8
        BLK_H = 0.28
        TFR_H = 0.48
        CX_BLK = -1.0
        Y_EMB = 1.40
        Y_NORM1 = 0.95
        Y_TFR = 0.32
        Y_NORM2 = -0.25
        Y_HEAD = -0.70

        def mkblk(w, h, y, txt, color, fs=10):
            """Return a labelled pipeline block centred at (CX_BLK, y)."""
            b = RoundedRectangle(width=w, height=h, corner_radius=0.06,
                                 fill_color=DARK_BLUE, fill_opacity=0.25,
                                 stroke_color=color, stroke_width=1.5)
            l = Text(txt, font_size=fs, color=color, weight=BOLD).move_to(b)
            return VGroup(b, l).move_to([CX_BLK, y, 0])

        emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9)
        norm1 = mkblk(BLK_W, BLK_H, Y_NORM1, "RMS Norm", GREEN, 9)
        tfr_node = mkblk(BLK_W, TFR_H, Y_TFR, "Transformer", PURPLE, 11)
        norm2 = mkblk(BLK_W, BLK_H, Y_NORM2, "RMS Norm", GREEN, 9)
        head = mkblk(BLK_W, BLK_H, Y_HEAD, "LM Head", RED, 9)
        pipes = [emb_node, norm1, tfr_node, norm2, head]
        pipeline = VGroup(*pipes)

        # Arrows between consecutive blocks
        arrows = VGroup(*[
            Arrow(a.get_bottom(), b.get_top(),
                  color=GRAY, stroke_width=1.2, tip_length=0.07)
            for a, b in zip(pipes[:-1], pipes[1:])
        ])

        # ── KV Cache (right of Transformer) ──
        KV_X = 1.8
        kv_size = 0.16
        kv_gap = 0.04
        K_Y = Y_TFR + 0.05
        V_Y = Y_TFR - 0.22

        def kv_square(color, x, y):
            """Return one cache slot square of the given colour at (x, y)."""
            return Square(kv_size, fill_color=color, fill_opacity=0.4,
                          stroke_color=color, stroke_width=0.5).move_to([x, y, 0])

        cache_lbl = Text("KV Cache", font_size=8, color=GRAY).move_to(
            [KV_X, Y_TFR + TFR_H / 2 + 0.2, 0]
        )
        k_hdr = Text("K:", font_size=7, color=YELLOW).move_to([KV_X - 0.65, K_Y, 0])
        v_hdr = Text("V:", font_size=7, color=ORANGE).move_to([KV_X - 0.65, V_Y, 0])
        # The cache starts with one slot per row, shown together with "<s>".
        kv_k = VGroup(kv_square(YELLOW, KV_X, K_Y))
        kv_v = VGroup(kv_square(ORANGE, KV_X, V_Y))
        kv_group = VGroup(cache_lbl, k_hdr, v_hdr, kv_k, kv_v)

        # ── Distribution builder ──
        def build_dist(probs, y_center, max_w=2.8):
            """Return VGroup(bars, labels) for a {token: percent} mapping.

            Bars are left-aligned, stacked top to bottom in dict order and
            vertically centred on ``y_center``; 100 % maps to ``max_w``.
            """
            bars = VGroup()
            lbls = VGroup()
            bh = 0.12
            bg = 0.02
            lx = CX_BLK - max_w / 2
            items = list(probs.items())
            count = len(items)
            y_top = y_center + (count * bh + (count - 1) * bg) / 2
            for idx, (tok, pct) in enumerate(items):
                w = max_w * pct / 100
                y = y_top - idx * (bh + bg)
                bar = Rectangle(width=w, height=bh,
                                fill_color=interpolate_color(BLUE, RED, pct / 100),
                                fill_opacity=0.85, stroke_color=LIGHT_GRAY,
                                stroke_width=0.3)
                bar.move_to([lx + w / 2, y, 0])
                lbl = Text(f"{tok} {pct}%", font_size=7, color=WHITE)
                lbl.next_to(bar, RIGHT, buff=0.05)
                bars.add(bar)
                lbls.add(lbl)
            return VGroup(bars, lbls)

        # One hand-written distribution per generated token; the first entry
        # of each dict is the token that gets picked.
        dists = [
            {"The": 72, "cat": 8, "sat": 6, "on": 4, "<unk>": 10},
            {"cat": 65, "sat": 12, "was": 8, "is": 6, "<unk>": 9},
            {"sat": 58, "slept": 15, "ran": 8, "jumped": 6, "<unk>": 13},
            {"on": 55, "down": 12, "quietly": 8, "and": 6, "<unk>": 19},
            {"the": 60, "a": 12, "top": 8, "floor": 5, "<unk>": 15},
            {"mat": 50, "table": 15, "chair": 8, "floor": 6, "<unk>": 21},
        ]
        Y_DIST = -1.45

        # ── Token sequence row ──
        SX = -4.0
        Y_SEQ = 2.0
        GAP = 0.55
        seq = VGroup()
        sos = tok_card("<s>", BLUE, BLUE).move_to([SX, Y_SEQ, 0])
        seq.add(sos)

        # Step label (below everything)
        step_lbl = Text("Step 0 — [<s>] → ?", font_size=9, color=GRAY)
        step_lbl.move_to([0, -2.1, 0])

        # ── Show static elements ──
        self.play(FadeIn(pipeline), FadeIn(arrows))
        self.play(FadeIn(kv_group))
        self.play(FadeIn(sos))
        self.play(Write(step_lbl))
        self.wait(0.5)

        # ── Generation loop ──
        for i, (tok, probs) in enumerate(zip(gen_tokens[1:], dists), start=1):
            # 1. Highlight last token (the input being processed)
            last = seq[-1]
            hl = SurroundingRectangle(last, color=YELLOW, stroke_width=2, buff=0.04)
            self.play(Create(hl), run_time=0.12)

            # 2. Arrow from last token → Embedding
            in_arr = Arrow(last.get_bottom(), emb_node.get_top(), color=YELLOW,
                           stroke_width=2, tip_length=0.08)
            self.play(FadeIn(in_arr, scale=0.5), run_time=0.1)

            # 3. Flash the pipeline blocks. set_fill(opacity=...) is the
            #    documented VMobject call for changing fill opacity.
            self.play(
                *[p[0].animate.set_fill(opacity=0.6) for p in pipes],
                run_time=0.12,
            )
            self.play(
                *[p[0].animate.set_fill(opacity=0.25) for p in pipes],
                run_time=0.1,
            )
            self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08)

            # 4. Show probability distribution
            dist_arr = Arrow(head.get_bottom(), head.get_bottom() + DOWN * 0.22,
                             color=GRAY, stroke_width=1, tip_length=0.06)
            dist = build_dist(probs, Y_DIST)
            self.play(FadeIn(dist_arr, scale=0.5), FadeIn(dist, scale=0.8), run_time=0.25)

            # 5. Sampling: highlight top bar
            top_bar = dist[0][0]
            sample_hl = SurroundingRectangle(top_bar, color=YELLOW, stroke_width=1.5, buff=0.02)
            samp_lbl = Text("\u2713 argmax", font_size=7, color=YELLOW)
            samp_lbl.next_to(dist, DOWN, buff=0.05)
            self.play(Create(sample_hl), Write(samp_lbl), run_time=0.2)
            self.wait(0.15)

            # 6. Predicted token appears below distribution
            pred = tok_card(tok, YELLOW, YELLOW)
            target_x = SX + len(seq) * GAP
            pred.move_to([CX_BLK, Y_DIST - 0.45, 0])
            self.play(FadeIn(pred, scale=0.4), run_time=0.15)

            # 7. Token rises to join sequence at top, then takes the neutral style
            self.play(pred.animate.move_to([target_x, Y_SEQ, 0]), run_time=0.3)
            self.play(
                pred[0].animate.set_fill(DARK_BLUE, 0.5).set_stroke(GRAY, 0.5),
                pred[1].animate.set_color(WHITE),
                FadeOut(sample_hl), FadeOut(samp_lbl),
            )
            seq.add(pred)

            # 8. KV Cache: add K,V of this predicted token
            kv_x = KV_X + i * (kv_size + kv_gap)
            k_sq = kv_square(YELLOW, kv_x, K_Y)
            v_sq = kv_square(ORANGE, kv_x, V_Y)
            self.play(FadeIn(k_sq, scale=1.5), FadeIn(v_sq, scale=1.5), run_time=0.15)
            kv_k.add(k_sq)
            kv_v.add(v_sq)

            # 9. Remove distribution and arrow
            self.play(FadeOut(dist), FadeOut(dist_arr), run_time=0.08)

            # 10. Update step label with every token produced so far
            input_str = " ".join(gen_tokens[:i + 1])
            new_lbl = Text(f"Step {i} — [{input_str}]", font_size=9, color=GRAY)
            new_lbl.move_to(step_lbl)
            self.play(Transform(step_lbl, new_lbl))
            self.wait(0.3)

        self.wait(2.0)
def orth_line(start, end, color=GRAY, stroke_width=1.5):
    """Create an L-shaped orthogonal line from start to end.

    The path runs vertically from ``start`` to the height of ``end`` and then
    horizontally to ``end``. The corner is placed on the z = 0 plane.

    Args:
        start: Starting point as an (x, y, z) sequence.
        end: Ending point as an (x, y, z) sequence.
        color: Stroke colour of the path.
        stroke_width: Stroke width of the path.

    Returns:
        A ``VMobject`` whose points are the polyline through the corners.
    """
    start = np.asarray(start, dtype=float)
    end = np.asarray(end, dtype=float)
    corner = np.array([start[0], end[1], 0.0])

    corners = [start]
    # When start and end are already axis-aligned the corner coincides with
    # one endpoint; skipping it avoids a zero-length segment in the path.
    if not (np.allclose(corner, start) or np.allclose(corner, end)):
        corners.append(corner)
    corners.append(end)

    path = VMobject(color=color, stroke_width=stroke_width)
    path.set_points_as_corners(corners)
    return path