"""AstrAI promo: Transformer GQA attention animation.

Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow
lines:

    Input → Q/K/V Projections → Repeat KV → SDPA → O Projection → Output

followed by a scaled dot-product attention formula breakdown, an attention
score heatmap for a toy sentence, and an auto-regressive generation demo
with a growing KV cache.
"""
from manim import *
import numpy as np
import math


class Transformer(Scene):
    """Animates the GQA attention mechanism with orthogonal connection lines.

    The scene plays four sections in order:

    1. GQA block diagram (Q/K/V projections, repeat_kv, SDPA, O projection)
       with RoPE / GQA-ratio / repeat_kv callouts.
    2. Scaled dot-product attention formula and a step-by-step breakdown.
    3. Attention-score heatmap for "The cat sat on the mat", before and
       after the causal mask + softmax.
    4. Auto-regressive generation through Embedding → RMS Norm →
       Transformer → RMS Norm → LM Head, with a KV cache that grows by one
       K/V pair per step.
    """

    def construct(self):
        self._show_gqa_diagram()
        sdpa_title = self._show_sdpa_formula()
        # The SDPA title is cross-faded into the heatmap title, so it is
        # handed over instead of being faded out at the end of section 2.
        self._show_attention_heatmap(sdpa_title)
        self._show_generation_demo()
        self.wait(2.0)

    # ═══════════════════════════════════════════════════
    # 1. GQA block diagram
    # ═══════════════════════════════════════════════════
    def _show_gqa_diagram(self):
        """Draw the GQA data-flow diagram, play three callouts, then clear it."""
        title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
        title.to_edge(UP, buff=0.35)
        self.play(Write(title))

        def mk(name, color, w=2.6, h=0.72, fs=10):
            """Outlined, lightly filled box with a centred label."""
            box = Rectangle(
                width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
            )
            lbl = Text(name, font_size=fs, color=color)
            return VGroup(box, lbl)

        def wire(start, end):
            """Straight gray connector segment used throughout the diagram."""
            return Line(start, end, color=GRAY, stroke_width=1.5)

        def pt(x, y):
            return np.array([x, y, 0])

        # ── Layout ──
        inp = Text("x (hidden states)", font_size=15, color=GRAY)
        inp.move_to(UP * 2.6)

        y1 = 1.5
        q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
        k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
        v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
        q_grp.move_to(LEFT * 3.0 + UP * y1)
        k_grp.move_to(UP * y1)
        v_grp.move_to(RIGHT * 3.0 + UP * y1)

        y2 = 0.0
        repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
        repeat_grp.move_to(UP * y2)

        y3 = -1.6
        sdpa_grp = mk(
            "Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
        )
        sdpa_grp.move_to(UP * y3)

        y4 = -2.9
        o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
        o_grp.move_to(UP * y4)

        out = Text("x' (hidden states)", font_size=15, color=GRAY)
        out.next_to(o_grp, DOWN, buff=0.4)

        # ── Animate boxes ──
        self.play(Write(inp))
        all_boxes = [q_grp, k_grp, v_grp, repeat_grp, sdpa_grp, o_grp]
        for g in all_boxes:
            self.play(FadeIn(g, shift=UP * 0.1), run_time=0.2)

        # ── Input trunk → branch → Q/K/V (enter from directly above) ──
        branch_y = q_grp.get_top()[1] + 0.35
        trunk = wire(inp.get_bottom(), pt(0, branch_y))
        self.play(Create(trunk), run_time=0.15)

        qx = q_grp.get_top()[0]
        kx = k_grp.get_top()[0]
        vx = v_grp.get_top()[0]
        branch_left = wire(pt(qx, branch_y), pt(kx, branch_y))
        branch_right = wire(pt(kx, branch_y), pt(vx, branch_y))
        self.play(Create(branch_left), Create(branch_right), run_time=0.2)

        drop_q = wire(pt(qx, branch_y), q_grp.get_top())
        drop_k = wire(pt(kx, branch_y), k_grp.get_top())
        drop_v = wire(pt(vx, branch_y), v_grp.get_top())
        for ln in (drop_q, drop_k, drop_v):
            self.play(Create(ln), run_time=0.12)
        input_lines = VGroup(trunk, branch_left, branch_right, drop_q, drop_k, drop_v)

        # ── K/V → Repeat KV (trunk-branch, enter from above) ──
        kv_junc_y = repeat_grp.get_top()[1] + 0.3
        kbx = k_grp.get_bottom()[0]
        vbx = v_grp.get_bottom()[0]
        kv_lines = VGroup(
            wire(k_grp.get_bottom(), pt(kbx, kv_junc_y)),
            wire(v_grp.get_bottom(), pt(vbx, kv_junc_y)),
            wire(pt(vbx, kv_junc_y), pt(kbx, kv_junc_y)),
            wire(pt(kbx, kv_junc_y), repeat_grp.get_top()),
        )
        self.play(Create(kv_lines), run_time=0.3)

        # ── Q → SDPA (bypasses Repeat KV, enters from above) ──
        qs_junc_y = sdpa_grp.get_top()[1] + 0.3
        line_qs = VMobject(color=GRAY, stroke_width=1.5)
        line_qs.set_points_as_corners([
            q_grp.get_bottom(),
            pt(q_grp.get_bottom()[0], qs_junc_y),
            pt(sdpa_grp.get_top()[0], qs_junc_y),
            sdpa_grp.get_top(),
        ])
        self.play(Create(line_qs), run_time=0.15)

        # ── Remaining hops down the stack ──
        line_rs = orth_line(repeat_grp.get_bottom(), sdpa_grp.get_top(), GRAY)
        self.play(Create(line_rs), run_time=0.15)
        line_so = orth_line(sdpa_grp.get_bottom(), o_grp.get_top(), GRAY)
        self.play(Create(line_so), run_time=0.15)
        line_oo = orth_line(o_grp.get_bottom(), out.get_top(), GRAY)
        self.play(Create(line_oo), run_time=0.15)
        self.play(Write(out))
        self.wait(0.4)

        all_lines = VGroup(input_lines, kv_lines, line_qs, line_rs, line_so, line_oo)

        # ── RoPE highlight ──
        rope_q = SurroundingRectangle(q_grp, color=TEAL, buff=0.12)
        rope_k = SurroundingRectangle(k_grp, color=TEAL, buff=0.12)
        rope_t = Text(
            "RoPE: rotary position encoding\napplied to Q and K",
            font_size=13, color=TEAL,
        )
        rope_t.next_to(VGroup(rope_q, rope_k), UP, buff=0.25)
        self.play(Create(rope_q), Create(rope_k), Write(rope_t))
        self.wait(1.5)
        self.play(FadeOut(rope_q), FadeOut(rope_k), FadeOut(rope_t))

        # ── GQA ratio highlight (24 Q heads share 4 KV heads) ──
        gqa_h = SurroundingRectangle(
            VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
        )
        gqa_t = Text(
            "GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
            font_size=13, color=YELLOW,
        )
        gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
        self.play(Create(gqa_h), Write(gqa_t))
        self.wait(1.8)
        self.play(FadeOut(gqa_h), FadeOut(gqa_t))

        # ── Repeat KV highlight (kept on screen until the final fade) ──
        kv_h = SurroundingRectangle(VGroup(k_grp, v_grp), color=GREEN, buff=0.12)
        kv_t = Text(
            "repeat_kv(): broadcast\n4 heads → 24 heads",
            font_size=12, color=GREEN,
        )
        kv_t.next_to(kv_h, RIGHT, buff=0.5)
        self.play(Create(kv_h), Write(kv_t))
        self.wait(1.5)

        # ── Fade all ──
        self.play(
            *[FadeOut(g) for g in all_boxes],
            FadeOut(all_lines),
            FadeOut(kv_h), FadeOut(kv_t),
            FadeOut(inp), FadeOut(out), FadeOut(title),
        )

    # ═══════════════════════════════════════════════════
    # 2. Scaled Dot-Product Attention — full formula + breakdown
    # ═══════════════════════════════════════════════════
    def _show_sdpa_formula(self):
        """Show the SDPA formula, Q/K/V meanings and a 4-step decomposition.

        Returns:
            The section title mobject, still on screen, so the next section
            can cross-fade it.
        """
        qkv_title = Text("Scaled Dot-Product Attention", font_size=34, color=BLUE)
        qkv_title.to_edge(UP, buff=0.35)
        self.play(Write(qkv_title))

        # Full formula first; it stays on screen during the breakdown.
        full_eq = MathTex(
            r"\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\!V",
            font_size=36, color=WHITE,
        )
        full_eq.next_to(qkv_title, DOWN, buff=0.5)
        self.play(Write(full_eq))
        self.wait(0.6)

        # Q / K / V — brief meanings
        q_txt = Text("Q = Query", font_size=18, color=YELLOW)
        k_txt = Text("K = Key", font_size=18, color=ORANGE)
        v_txt = Text("V = Value", font_size=18, color=GREEN)
        qkv_labels = VGroup(q_txt, k_txt, v_txt).arrange(RIGHT, buff=1.5)
        qkv_labels.next_to(full_eq, DOWN, buff=0.5)
        self.play(Write(qkv_labels))

        q_desc = Text("\"what am I looking for?\"", font_size=11, color=YELLOW)
        q_desc.next_to(q_txt, DOWN, buff=0.10)
        k_desc = Text("\"what do I have?\"", font_size=11, color=ORANGE)
        k_desc.next_to(k_txt, DOWN, buff=0.10)
        v_desc = Text("\"what do I contribute?\"", font_size=11, color=GREEN)
        v_desc.next_to(v_txt, DOWN, buff=0.10)
        self.play(Write(q_desc), Write(k_desc), Write(v_desc))
        self.wait(2.0)
        self.play(FadeOut(qkv_labels), FadeOut(q_desc), FadeOut(k_desc), FadeOut(v_desc))

        # Step-by-step decomposition — full formula stays visible above.
        steps = [
            (MathTex(r"\text{(1) } S = QK^\top", font_size=26, color=YELLOW),
             Text("score matrix — pairwise token similarity", font_size=12, color=GRAY)),
            (MathTex(r"\text{(2) } S \mathbin{/} \sqrt{d_k}", font_size=26, color=ORANGE),
             # The 1/sqrt(d_k) factor keeps large dot products from pushing
             # softmax into saturation, where gradients become vanishingly
             # small (Vaswani et al., 2017, §3.2.1).
             Text("scale — prevents softmax saturation (vanishing gradients)",
                  font_size=12, color=GRAY)),
            (MathTex(r"\text{(3) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right)",
                     font_size=26, color=GREEN),
             Text("normalize — each row sums to 1 (probability)", font_size=12, color=GRAY)),
            (MathTex(r"\text{(4) } \operatorname{softmax}\!\left(S \mathbin{/} \sqrt{d_k}\right) \cdot V",
                     font_size=26, color=BLUE),
             Text("weighted sum — aggregate values by attention", font_size=12, color=GRAY)),
        ]
        step_group = VGroup(*(
            VGroup(eq, desc).arrange(DOWN, buff=0.06) for eq, desc in steps
        ))
        step_group.arrange(DOWN, buff=0.22, aligned_edge=LEFT)
        step_group.next_to(full_eq, DOWN, buff=0.6)
        for sg in step_group:
            self.play(Write(sg), run_time=0.3)
        self.wait(2.5)
        self.play(FadeOut(step_group), FadeOut(full_eq))
        return qkv_title

    # ═══════════════════════════════════════════════════
    # 3. Attention score heatmap — concrete example
    # ═══════════════════════════════════════════════════
    @staticmethod
    def _causal_softmax(scores):
        """Row-wise softmax of a square score matrix under a causal mask.

        Entries with column index > row index are treated as -inf, so they
        come out as exactly 0.0. Each row's unmasked entries sum to 1.
        """
        weights = []
        for i, row in enumerate(scores):
            exps = [math.exp(s) if j <= i else 0.0 for j, s in enumerate(row)]
            total = sum(exps)
            weights.append([e / total for e in exps])
        return weights

    def _show_attention_heatmap(self, prev_title):
        """Show raw scores, then the causal-masked softmax weights, as a heatmap.

        Args:
            prev_title: Title mobject from the previous section, cross-faded
                into this section's title.
        """
        hm_title = Text("Attention Score Heatmap", font_size=34, color=BLUE)
        hm_title.to_edge(UP, buff=0.35)
        hm_sub = Text(
            "\"The cat sat on the mat\" — causal, per-token attention weights",
            font_size=14, color=GRAY,
        ).next_to(hm_title, DOWN, buff=0.12)
        self.play(FadeOut(prev_title), Write(hm_title), Write(hm_sub))

        tokens = ["", "The", "cat", "sat", "on", "the", "mat"]
        n = len(tokens)
        cell_size = 0.52
        gap = 0.04
        pitch = cell_size + gap
        grid_size = n * cell_size + (n - 1) * gap  # square grid: width == height
        grid_left = -grid_size / 2
        grid_top = 1.4

        def cell_x(j):
            """Centre x of column j."""
            return grid_left + j * pitch + cell_size / 2

        def cell_y(i):
            """Centre y of row i."""
            return grid_top - i * pitch - cell_size / 2

        # Pre-mask raw scores (QK^T / sqrt(d_k)) — hand-picked, distance-biased.
        pre_scores = [
            [2.8, 1.5, 0.3, 0.1, 0.0, 0.0, 0.0],
            [1.2, 3.5, 1.8, 0.5, 0.2, 0.1, 0.0],
            [0.4, 2.0, 3.0, 1.5, 0.6, 0.2, 0.1],
            [0.1, 0.6, 2.5, 2.8, 1.2, 0.4, 0.1],
            [0.0, 0.2, 0.8, 2.0, 2.5, 1.5, 0.3],
            [0.0, 0.1, 0.3, 0.9, 1.8, 2.5, 1.2],
            [0.0, 0.0, 0.1, 0.4, 0.8, 1.5, 3.0],
        ]
        post_weights = self._causal_softmax(pre_scores)

        flat_pre = [s for row in pre_scores for s in row]
        pre_min, pre_max = min(flat_pre), max(flat_pre)
        flat_post = [w for row in post_weights for w in row]
        post_min, post_max = min(flat_post), max(flat_post)

        def heat(value, lo, hi):
            """Map value in [lo, hi] onto a BLUE → RED gradient."""
            return interpolate_color(BLUE, RED, (value - lo) / (hi - lo))

        # cells is row-major: cells[i * n + j] is (query i, key j).
        cells = VGroup()
        masked_cells = VGroup()  # strictly upper triangle (future keys)
        for i in range(n):
            for j in range(n):
                sq = Square(
                    side_length=cell_size,
                    fill_color=heat(pre_scores[i][j], pre_min, pre_max),
                    fill_opacity=0.75,
                    stroke_width=0.5,
                    stroke_color=GRAY,
                )
                sq.move_to([cell_x(j), cell_y(i), 0])
                cells.add(sq)
                if j > i:
                    masked_cells.add(sq)
        # One staggered animation instead of n*n sub-frame self.play calls.
        self.play(
            LaggedStart(*[FadeIn(sq, scale=0.6) for sq in cells], lag_ratio=0.05),
            run_time=0.8,
        )

        # Row labels (query) on the left.
        row_lbls = VGroup()
        for i, tok in enumerate(tokens):
            lbl = Text(tok, font_size=12, color=GRAY)
            lbl.next_to([grid_left - 0.15, cell_y(i), 0], LEFT, buff=0.08)
            row_lbls.add(lbl)
        # Anchor the axis label to the label *group* and grid geometry: the
        # first token is an empty string, whose Text has no points, so its
        # bounding box collapses to the origin (inside the grid).
        q_label = Text("Q", font_size=11, color=WHITE, weight=BOLD)
        q_label.move_to([row_lbls.get_left()[0] - 0.3, cell_y(0) + 0.15, 0])
        self.play(*[Write(lbl) for lbl in row_lbls], Write(q_label))

        # Column labels (key) on top.
        col_lbls = VGroup()
        for j, tok in enumerate(tokens):
            lbl = Text(tok, font_size=9, color=GRAY).rotate(PI / 6)
            lbl.next_to([cell_x(j), grid_top + 0.06, 0], UP, buff=0.04)
            col_lbls.add(lbl)
        k_label = Text("K", font_size=11, color=WHITE, weight=BOLD)
        k_label.next_to([cell_x(0), col_lbls.get_top()[1], 0], UP, buff=0.06)
        self.play(*[Write(lbl) for lbl in col_lbls], Write(k_label))
        self.wait(1.0)

        # Causal mask + softmax — dim future tokens, recolour visible weights.
        causal_txt = Text(
            "causal mask + softmax\n(future tokens → 0)", font_size=11, color=RED
        )
        causal_txt.next_to(cells[n - 1], UP, buff=0.25).align_to(cells[n - 1], RIGHT)
        anims = [sq.animate.set_fill(DARK_GRAY, 0.15) for sq in masked_cells]
        for i in range(n):
            for j in range(i + 1):
                color = heat(post_weights[i][j], post_min, post_max)
                anims.append(cells[i * n + j].animate.set_fill(color, 0.75))
        self.play(*anims, Write(causal_txt))
        self.wait(1.2)
        self.play(FadeOut(causal_txt))

        # Highlight the sub-diagonal: query i attending to key i - 1.
        prev_tok_hl = VGroup(*(
            SurroundingRectangle(
                cells[i * n + (i - 1)], color=ORANGE, stroke_width=2, buff=0.04
            )
            for i in range(2, n)
        ))
        hl_text = Text(
            "each token attends to the previous token\n(previous-token pattern)",
            font_size=11, color=ORANGE,
        )
        hl_text.next_to(cells[n * n - 1], RIGHT, buff=0.8)
        self.play(*[Create(h) for h in prev_tok_hl], Write(hl_text))
        self.wait(2.0)
        self.play(FadeOut(prev_tok_hl), FadeOut(hl_text))

        # Fade the whole heatmap.
        self.play(
            FadeOut(hm_title), FadeOut(hm_sub),
            FadeOut(cells), FadeOut(row_lbls), FadeOut(col_lbls),
            FadeOut(q_label), FadeOut(k_label),
        )

    # ═══════════════════════════════════════════════════
    # 4. Auto-regressive generation demo (full I/O pipeline)
    # ═══════════════════════════════════════════════════
    def _show_generation_demo(self):
        """Generate "The cat sat on the mat" one token at a time.

        Each step highlights the last sequence token and pulses the model
        pipeline. It then shows a next-token distribution, marks its argmax,
        moves the predicted token into the sequence row, and appends one
        K/V square pair to the cache.
        """

        def tok_card(text, fill=DARK_BLUE, stroke=GRAY):
            """Small rounded card holding one token string."""
            t = Text(text, font_size=10, color=WHITE, weight=BOLD)
            box = RoundedRectangle(
                width=t.width + 0.2, height=t.height + 0.1, corner_radius=0.04,
                fill_color=fill, fill_opacity=0.5,
                stroke_color=stroke, stroke_width=0.5,
            )
            t.move_to(box)
            return VGroup(box, t)

        gen_tokens = ["", "The", "cat", "sat", "on", "the", "mat"]
        gen_title = Text("Auto-regressive Generation", font_size=34, color=BLUE)
        gen_title.to_edge(UP, buff=0.35)
        self.play(Write(gen_title))

        # ── Layout constants ──
        BLK_W = 1.8
        BLK_H = 0.28
        TFR_H = 0.48
        CX_BLK = -1.0
        Y_EMB = 1.40
        Y_NORM1 = 0.95
        Y_TFR = 0.32
        Y_NORM2 = -0.25
        Y_HEAD = -0.70

        def mkblk(w, h, y, txt, color, fs=10):
            """Pipeline block centred at (CX_BLK, y)."""
            frame = RoundedRectangle(
                width=w, height=h, corner_radius=0.06,
                fill_color=DARK_BLUE, fill_opacity=0.25,
                stroke_color=color, stroke_width=1.5,
            )
            label = Text(txt, font_size=fs, color=color, weight=BOLD).move_to(frame)
            return VGroup(frame, label).move_to([CX_BLK, y, 0])

        emb_node = mkblk(BLK_W, BLK_H, Y_EMB, "Embedding", YELLOW, 9)
        norm1 = mkblk(BLK_W, BLK_H, Y_NORM1, "RMS Norm", GREEN, 9)
        tfr_node = mkblk(BLK_W, TFR_H, Y_TFR, "Transformer", PURPLE, 11)
        norm2 = mkblk(BLK_W, BLK_H, Y_NORM2, "RMS Norm", GREEN, 9)
        head = mkblk(BLK_W, BLK_H, Y_HEAD, "LM Head", RED, 9)
        pipeline = VGroup(emb_node, norm1, tfr_node, norm2, head)

        # Arrows between consecutive blocks.
        arrows = VGroup(*(
            Arrow(a.get_bottom(), b.get_top(), color=GRAY,
                  stroke_width=1.2, tip_length=0.07)
            for a, b in zip(pipeline[:-1], pipeline[1:])
        ))

        # ── KV Cache (right of Transformer) ──
        KV_X = 1.8
        kv_size = 0.16
        kv_gap = 0.04
        K_Y = Y_TFR + 0.05
        V_Y = Y_TFR - 0.22

        def kv_cell(x, y, color):
            """One K or V cache square centred at (x, y)."""
            return Square(
                kv_size, fill_color=color, fill_opacity=0.4,
                stroke_color=color, stroke_width=0.5,
            ).move_to([x, y, 0])

        cache_lbl = Text("KV Cache", font_size=8, color=GRAY).move_to(
            [KV_X, Y_TFR + TFR_H / 2 + 0.2, 0]
        )
        k_hdr = Text("K:", font_size=7, color=YELLOW).move_to([KV_X - 0.65, K_Y, 0])
        v_hdr = Text("V:", font_size=7, color=ORANGE).move_to([KV_X - 0.65, V_Y, 0])
        kv_k = VGroup(kv_cell(KV_X, K_Y, YELLOW))
        kv_v = VGroup(kv_cell(KV_X, V_Y, ORANGE))
        kv_group = VGroup(cache_lbl, k_hdr, v_hdr, kv_k, kv_v)

        # ── Distribution builder ──
        def build_dist(probs, y_center, max_w=2.8):
            """Horizontal bar chart of {token: percent}, left-aligned under the pipeline.

            Returns VGroup(bars, labels); bars appear in dict order, top to bottom.
            """
            bars = VGroup()
            lbls = VGroup()
            bh = 0.12
            bg = 0.02
            lx = CX_BLK - max_w / 2
            items = list(probs.items())
            count = len(items)
            y_top = y_center + (count * bh + (count - 1) * bg) / 2
            for i, (tok, pct) in enumerate(items):
                w = max_w * pct / 100
                y = y_top - i * (bh + bg)
                bar = Rectangle(
                    width=w, height=bh,
                    fill_color=interpolate_color(BLUE, RED, pct / 100),
                    fill_opacity=0.85, stroke_color=LIGHT_GRAY, stroke_width=0.3,
                )
                bar.move_to([lx + w / 2, y, 0])
                lbl = Text(f"{tok} {pct}%", font_size=7, color=WHITE)
                lbl.next_to(bar, RIGHT, buff=0.05)
                bars.add(bar)
                lbls.add(lbl)
            return VGroup(bars, lbls)

        # dists[k] is the next-token distribution shown at step k + 1.
        dists = [
            {"The": 72, "cat": 8, "sat": 6, "on": 4, "": 10},
            {"cat": 65, "sat": 12, "was": 8, "is": 6, "": 9},
            {"sat": 58, "slept": 15, "ran": 8, "jumped": 6, "": 13},
            {"on": 55, "down": 12, "quietly": 8, "and": 6, "": 19},
            {"the": 60, "a": 12, "top": 8, "floor": 5, "": 15},
            {"mat": 50, "table": 15, "chair": 8, "floor": 6, "": 21},
        ]
        Y_DIST = -1.45

        # ── Token sequence row ──
        SX = -4.0
        Y_SEQ = 2.0
        GAP = 0.55
        seq = VGroup()
        sos = tok_card("", BLUE, BLUE).move_to([SX, Y_SEQ, 0])
        seq.add(sos)

        # Step label (below everything).
        step_lbl = Text("Step 0 — [] → ?", font_size=9, color=GRAY)
        step_lbl.move_to([0, -2.1, 0])

        # ── Show static elements ──
        self.play(FadeIn(pipeline), FadeIn(arrows))
        self.play(FadeIn(kv_group))
        self.play(FadeIn(sos))
        self.play(Write(step_lbl))
        self.wait(0.5)

        # ── Generation loop ──
        for i, tok in enumerate(gen_tokens[1:], start=1):
            input_str = " ".join(gen_tokens[:i + 1])

            # 1. Highlight the last token (the input being processed).
            last = seq[-1]
            hl = SurroundingRectangle(last, color=YELLOW, stroke_width=2, buff=0.04)
            self.play(Create(hl), run_time=0.12)

            # 2. Arrow from last token → Embedding.
            in_arr = Arrow(last.get_bottom(), emb_node.get_top(), color=YELLOW,
                           stroke_width=2, tip_length=0.08)
            self.play(FadeIn(in_arr, scale=0.5), run_time=0.1)

            # 3. Pulse every pipeline block.
            self.play(*[p[0].animate.set_fill_opacity(0.6) for p in pipeline], run_time=0.12)
            self.play(*[p[0].animate.set_fill_opacity(0.25) for p in pipeline], run_time=0.1)
            self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08)

            # 4. Show the next-token probability distribution.
            dist_arr = Arrow(head.get_bottom(), head.get_bottom() + DOWN * 0.22,
                             color=GRAY, stroke_width=1, tip_length=0.06)
            dist = build_dist(dists[i - 1], Y_DIST)
            self.play(FadeIn(dist_arr, scale=0.5), FadeIn(dist, scale=0.8), run_time=0.25)

            # 5. Sampling: highlight the top bar (greedy / argmax).
            top_bar = dist[0][0]
            sample_hl = SurroundingRectangle(top_bar, color=YELLOW, stroke_width=1.5, buff=0.02)
            samp_lbl = Text("\u2713 argmax", font_size=7, color=YELLOW)
            samp_lbl.next_to(dist, DOWN, buff=0.05)
            self.play(Create(sample_hl), Write(samp_lbl), run_time=0.2)
            self.wait(0.15)

            # 6. Predicted token appears below the distribution.
            pred = tok_card(tok, YELLOW, YELLOW)
            target_x = SX + len(seq) * GAP
            pred.move_to([CX_BLK, Y_DIST - 0.45, 0])
            self.play(FadeIn(pred, scale=0.4), run_time=0.15)

            # 7. Token rises to join the sequence row, restyled as a normal card.
            self.play(pred.animate.move_to([target_x, Y_SEQ, 0]), run_time=0.3)
            self.play(
                pred[0].animate.set_fill(DARK_BLUE, 0.5).set_stroke(GRAY, 0.5),
                pred[1].animate.set_color(WHITE),
                FadeOut(sample_hl), FadeOut(samp_lbl),
            )
            seq.add(pred)

            # 8. KV cache: append one K and one V square for this step.
            kv_x = KV_X + i * (kv_size + kv_gap)
            k_sq = kv_cell(kv_x, K_Y, YELLOW)
            v_sq = kv_cell(kv_x, V_Y, ORANGE)
            self.play(FadeIn(k_sq, scale=1.5), FadeIn(v_sq, scale=1.5), run_time=0.15)
            kv_k.add(k_sq)
            kv_v.add(v_sq)

            # 9. Remove distribution and arrow.
            self.play(FadeOut(dist), FadeOut(dist_arr), run_time=0.08)

            # 10. Update the step label.
            new_lbl = Text(f"Step {i} — [{input_str}]", font_size=9, color=GRAY)
            new_lbl.move_to(step_lbl)
            self.play(Transform(step_lbl, new_lbl))
            self.wait(0.3)


def orth_line(start, end, color=GRAY):
    """Create an L-shaped orthogonal line from start to end.

    The path goes from ``start`` along x = start[0] to y = end[1], then
    horizontally to ``end``. When start and end share an x coordinate the
    result is a single straight vertical segment.
    """
    corner = np.array([start[0], end[1], 0])
    path = VMobject(color=color, stroke_width=1.5)
    path.set_points_as_corners([start, corner, end])
    return path