From 9de0bad3d452d6bffd92b2206d36a28b5eff672e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 25 May 2026 19:19:48 +0800 Subject: [PATCH] fix transformer: GQA text overflow, heatmap sizing, auto-regressive pos labels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Shrink GQA title (42→34) to fit screen - Move GQA annotation from left overflow to right-bottom of V box - Enlarge heatmap cells (0.52→0.65) and labels (12→14, 9→10), lift grid up - Remove Repeat KV section (shorten scene ~2s) - Add position labels to auto-regressive token sequence - Add layer stack effect behind transformer block - Upgrade font sizes and spacing throughout for readability --- transformer.py | 103 +++++++++++++++++++++++++------------------------ 1 file changed, 53 insertions(+), 50 deletions(-) diff --git a/transformer.py b/transformer.py index 037aa59..200b59e 100644 --- a/transformer.py +++ b/transformer.py @@ -15,12 +15,12 @@ class Transformer(Scene): """Animates the GQA attention mechanism with orthogonal connection lines.""" def construct(self): - title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE) + title = Text("Grouped-Query Attention (GQA)", font_size=34, color=BLUE) title.to_edge(UP, buff=0.35) self.play(Write(title)) # ── Helper: box ── - def mk(name, color, w=2.6, h=0.72, fs=10): + def mk(name, color, w=3.0, h=0.85, fs=13): box = Rectangle( width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5 ) @@ -28,32 +28,32 @@ class Transformer(Scene): return VGroup(box, lbl) # ── Layout ── - inp = Text("x (hidden states)", font_size=15, color=GRAY) + inp = Text("x (hidden states)", font_size=20, color=GRAY) inp.move_to(UP * 2.5) y1 = 1.6 q_grp = mk("Q Projection\n1536 → 24×64", YELLOW) k_grp = mk("K Projection\n1536 → 4×64", YELLOW) v_grp = mk("V Projection\n1536 → 4×64", YELLOW) - q_grp.move_to(LEFT * 3.0 + UP * y1) + q_grp.move_to(LEFT * 3.6 + UP * y1) k_grp.move_to(UP * y1) - v_grp.move_to(RIGHT * 3.0 + UP * y1) + v_grp.move_to(RIGHT * 3.6 + UP * y1) y2 = 0.4 - repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10) + repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.8, 0.80) repeat_grp.move_to(UP * y2) y3 = -1.0 sdpa_grp = mk( - "Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10 + "Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 3.2, 0.85, ) sdpa_grp.move_to(UP * y3) y4 = -2.2 - o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10) + o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.6, 0.80) o_grp.move_to(UP * y4) - out = Text("x' (hidden states)", font_size=15, color=GRAY) + out = Text("x' (hidden states)", font_size=20, color=GRAY) out.next_to(o_grp, DOWN, buff=0.4) # ── Animate boxes ── @@ -169,31 +169,18 @@ class Transformer(Scene): VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2 ) gqa_t = Text( - "GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%", - font_size=13, color=YELLOW, + "GQA 6:1\n24 Q-heads → 4 KV-heads\nKV cache -83%", + font_size=11, color=YELLOW, ) - gqa_t.next_to(gqa_h, RIGHT, buff=0.5) + gqa_t.next_to(v_grp, DOWN, buff=0.4).shift(RIGHT * 0.4) self.play(Create(gqa_h), Write(gqa_t)) self.wait(1.8) self.play(FadeOut(gqa_h), FadeOut(gqa_t)) - # ── Repeat KV highlight ── - kv_h = SurroundingRectangle( - VGroup(k_grp, v_grp), color=GREEN, buff=0.12 - ) - kv_t = Text( - "repeat_kv(): broadcast\n4 heads → 24 heads", - font_size=12, color=GREEN, - ) - kv_t.next_to(kv_h, RIGHT, buff=0.5) - self.play(Create(kv_h), Write(kv_t)) - self.wait(1.5) - # ── Fade all ── self.play( *[FadeOut(g) for g in all_boxes], FadeOut(all_lines), - FadeOut(kv_h), FadeOut(kv_t), FadeOut(inp), FadeOut(out), FadeOut(title), ) @@ -268,11 +255,11 @@ class Transformer(Scene): tokens = ["", "The", "cat", "sat", "on", "the", "mat"] n = len(tokens) - cell_size = 0.52 - gap = 0.04 + cell_size = 0.65 + gap = 0.05 grid_high = n * cell_size + (n - 1) * gap grid_left = -grid_high / 2 - grid_top = 1.4 + grid_top = 1.7 # pre-mask raw scores (QK^T / sqrt(d_k)) — random-varied, distance-biased pre_scores = [ @@ -321,7 +308,7 @@ class Transformer(Scene): # row labels (query) on the left row_lbls = VGroup() for i, tok in enumerate(tokens): - lbl = Text(tok, font_size=12, color=GRAY) + lbl = Text(tok, font_size=14, color=GRAY) y = grid_top - i * (cell_size + gap) - cell_size / 2 lbl.next_to([grid_left - 0.15, y, 0], LEFT, buff=0.08) row_lbls.add(lbl) @@ -332,7 +319,7 @@ class Transformer(Scene): # column labels (key) on top col_lbls = VGroup() for j, tok in enumerate(tokens): - lbl = Text(tok, font_size=9, color=GRAY).rotate(PI / 6) + lbl = Text(tok, font_size=10, color=GRAY).rotate(PI / 6) x = grid_left + j * (cell_size + gap) + cell_size / 2 lbl.next_to([x, grid_top + 0.06, 0], UP, buff=0.04) col_lbls.add(lbl) @@ -379,10 +366,10 @@ class Transformer(Scene): # ═══════════════════════════════════════════════════ # Auto-regressive Generation Demo (v2: full I/O pipeline) def tok_card(text, fill=DARK_BLUE, stroke=GRAY): - t = Text(text, font_size=10, color=WHITE) + t = Text(text, font_size=14, color=WHITE) box = RoundedRectangle( - width=t.width + 0.2, height=t.height + 0.1, - corner_radius=0.04, fill_color=fill, fill_opacity=0.5, + width=t.width + 0.3, height=t.height + 0.16, + corner_radius=0.06, fill_color=fill, fill_opacity=0.5, stroke_color=stroke, stroke_width=0.5, ) t.move_to(box) @@ -399,11 +386,10 @@ class Transformer(Scene): BLK_H = 0.28 TFR_H = 0.55 CX_BLK = -1.0 - Y_EMB = 1.40 - Y_NORM1 = 0.95 - Y_TFR = 0.32 - Y_NORM2 = -0.25 - Y_HEAD = -0.70 + Y_EMB = 1.95 + Y_TFR = 1.10 + Y_NORM2 = 0.35 + Y_HEAD = -0.15 def mkblk(w, h, y, txt, color, fs=10): b = RoundedRectangle(width=w, height=h, corner_radius=0.06, @@ -413,11 +399,19 @@ class Transformer(Scene): return VGroup(b, l).move_to([CX_BLK, y, 0]) emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9) - norm1 = mkblk(BLK_W, BLK_H, Y_NORM1, "RMS Norm", GREEN, 9) - tfr_node = mkblk(2.4, TFR_H, Y_TFR, "Transformer Block\n× 24", PURPLE, 9) + tfr_node = mkblk(2.4, TFR_H, Y_TFR, "Transformer Block\n× 24", PURPLE, 12) norm2 = mkblk(BLK_W, BLK_H, Y_NORM2, "RMS Norm", GREEN, 9) head = mkblk(BLK_W, BLK_H, Y_HEAD, "LM Head", RED, 9) - pipeline = VGroup(emb_node, norm1, tfr_node, norm2, head) + pipeline = VGroup(emb_node, tfr_node, norm2, head) + + # Stack effect for ×24 layers behind Transformer block + layer_stack = VGroup() + for i in range(3, 0, -1): + shadow = RoundedRectangle( + width=2.4, height=TFR_H, corner_radius=0.06, + stroke_color=BLUE_D, stroke_width=1.0, fill_opacity=0, + ).move_to([CX_BLK + i * 0.04, Y_TFR - i * 0.04, 0]) + layer_stack.add(shadow) # Arrows between blocks arrows = VGroup() @@ -446,9 +440,9 @@ class Transformer(Scene): kv_group = VGroup(cache_lbl, k_hdr, v_hdr, kv_k, kv_v) # ── Distribution builder ── - def build_dist(probs, y_center, max_w=2.8): + def build_dist(probs, y_center, max_w=3.0): bars = VGroup(); lbls = VGroup() - bh = 0.12; bg = 0.02; lx = CX_BLK - max_w / 2 + bh = 0.18; bg = 0.03; lx = CX_BLK - max_w / 2 items = list(probs.items()) n = len(items) y_top = y_center + (n * bh + (n - 1) * bg) / 2 @@ -459,7 +453,7 @@ class Transformer(Scene): fill_color=interpolate_color(BLUE, RED, pct / 100), fill_opacity=0.85, stroke_color=LIGHT_GRAY, stroke_width=0.3) bar.move_to([lx + w / 2, y, 0]) - lbl = Text(f"{tok} {pct}%", font_size=7, color=WHITE) + lbl = Text(f"{tok} {pct}%", font_size=10, color=WHITE) lbl.next_to(bar, RIGHT, buff=0.05) bars.add(bar); lbls.add(lbl) return VGroup(bars, lbls) @@ -472,22 +466,28 @@ class Transformer(Scene): {"the": 60, "a": 12, "top": 8, "floor": 5, "": 15}, {"mat": 50, "table": 15, "chair": 8, "floor": 6, "": 21}, ] - Y_DIST = -1.45 + Y_DIST = -1.65 # ── Token sequence row ── - SX = -4.0; Y_SEQ = 2.0; GAP = 0.55 + SX = -4.0; Y_SEQ = 2.4; GAP = 0.70 seq = VGroup() sos = tok_card("", BLUE, BLUE).move_to([SX, Y_SEQ, 0]) seq.add(sos) + # Position labels under each token + pos_lbls = VGroup() + pos0 = Text("0", font_size=9, color=DARK_GRAY).next_to(sos, DOWN, buff=0.08) + pos_lbls.add(pos0) + # Step label (below everything) step_lbl = Text("Step 0 — [] → ?", font_size=9, color=GRAY) step_lbl.move_to([0, -2.1, 0]) # ── Show static elements ── self.play(FadeIn(pipeline), FadeIn(arrows)) + self.play(FadeIn(layer_stack)) self.play(FadeIn(kv_group)) - self.play(FadeIn(sos)) + self.play(FadeIn(sos), Write(pos0)) self.play(Write(step_lbl)) self.wait(0.5) @@ -506,13 +506,12 @@ class Transformer(Scene): self.play(FadeIn(in_arr, scale=0.5), run_time=0.1) # 3. Cascade through pipeline - pipes = [emb_node, norm1, tfr_node, norm2, head] self.play( - *[p[0].animate.set_fill_opacity(0.6) for p in pipes], + *[p[0].animate.set_fill_opacity(0.6) for p in pipeline], run_time=0.12 ) self.play( - *[p[0].animate.set_fill_opacity(0.25) for p in pipes], + *[p[0].animate.set_fill_opacity(0.25) for p in pipeline], run_time=0.1 ) self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08) @@ -539,12 +538,16 @@ class Transformer(Scene): # 7. Token rises to join sequence at top self.play(pred.animate.move_to([target_x, Y_SEQ, 0]), run_time=0.3) + pos_lbl = Text(str(i), font_size=9, color=DARK_GRAY) + pos_lbl.next_to(pred, DOWN, buff=0.08) self.play( pred[0].animate.set_fill(DARK_BLUE, 0.5).set_stroke(GRAY, 0.5), pred[1].animate.set_color(WHITE), + Write(pos_lbl), FadeOut(sample_hl), FadeOut(samp_lbl), ) seq.add(pred) + pos_lbls.add(pos_lbl) # 8. KV Cache: add K,V of this predicted token kv_x = KV_X + i * (kv_size + kv_gap)