fix: shift GQA layout down 0.4 to avoid title-input overlap

This commit is contained in:
ViperEkura 2026-05-07 14:51:31 +08:00
parent 6b26ec33ab
commit 57abefa47f
1 changed files with 6 additions and 6 deletions

View File

@ -27,9 +27,9 @@ class Transformer(Scene):
# ── Layout ──
inp = Text("x (hidden states)", font_size=15, color=GRAY)
inp.move_to(UP * 3.2)
inp.move_to(UP * 2.6)
y1 = 1.9
y1 = 1.5
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
@ -37,17 +37,17 @@ class Transformer(Scene):
k_grp.move_to(UP * y1)
v_grp.move_to(RIGHT * 3.0 + UP * y1)
y2 = 0.4
y2 = 0.0
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
repeat_grp.move_to(UP * y2)
y3 = -1.2
y3 = -1.6
sdpa_grp = mk(
"Scaled Dot-Product\nAttention Q·K^T/√d", BLUE, 2.8, 0.74, 10
)
sdpa_grp.move_to(UP * y3)
y4 = -2.6
y4 = -2.9
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
o_grp.move_to(UP * y4)
@ -516,7 +516,7 @@ class Transformer(Scene):
self.play(FadeOut(in_arr), FadeOut(hl), run_time=0.08)
# 4. Show probability distribution
dist_arr = Arrow(head.get_bottom(), head.get_bottom() + DOWN * 0.4,
dist_arr = Arrow(head.get_bottom(), head.get_bottom() + DOWN * 0.22,
color=GRAY, stroke_width=1, tip_length=0.06)
dist = build_dist(dists[i - 1], Y_DIST)
self.play(FadeIn(dist_arr, scale=0.5), FadeIn(dist, scale=0.8), run_time=0.25)