fix: architecture layout and rewrite continuous_batching with dynamic pipeline animation

This commit is contained in:
ViperEkura 2026-05-06 21:48:18 +08:00
parent c03abd31fe
commit a7a79eef96
3 changed files with 137 additions and 72 deletions

View File

@ -58,8 +58,8 @@ class Architecture(Scene):
self.wait(0.5) self.wait(0.5)
hl = SurroundingRectangle(layers[3], color=GREEN, buff=0.12) hl = SurroundingRectangle(layers[3], color=GREEN, buff=0.12)
hl_note = Text("Zero-Copy Prefix Reuse", font_size=22, color=GREEN) hl_note = Text("Zero-Copy Prefix Reuse", font_size=18, color=GREEN)
hl_note.next_to(hl, RIGHT, buff=0.8) hl_note.next_to(hl, LEFT, buff=0.4)
self.play(Create(hl), Write(hl_note)) self.play(Create(hl), Write(hl_note))
self.wait(1.5) self.wait(1.5)
self.play(FadeOut(hl), FadeOut(hl_note)) self.play(FadeOut(hl), FadeOut(hl_note))

View File

@ -1,98 +1,158 @@
"""AstrAI promo: Continuous Batching animation. """AstrAI promo: Continuous Batching animation.
Shows how tasks flow through the 4-phase pipeline and get batched together. Shows 4-phase pipeline with multiple requests concurrently at different stages,
and position-grouped decode batching the key advantage over static batching.
""" """
from manim import * from manim import *
class ContinuousBatching(Scene): class ContinuousBatching(Scene):
"""Animates tasks flowing through the prefill->decode pipeline."""
def construct(self): def construct(self):
# ── title ──
title = Text("Continuous Batching", font_size=48, color=BLUE) title = Text("Continuous Batching", font_size=48, color=BLUE)
self.play(Write(title)) self.play(Write(title))
self.wait(0.5) self.wait(0.3)
self.play(title.animate.to_edge(UP).scale(0.6)) self.play(title.animate.to_edge(UP).scale(0.55))
top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN)
top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.2)
self.play(Create(top_bar)) self.play(Create(top_bar))
# ── pipeline stages ── # ── 4-phase loop (vertical) ──
stage_names = ["Waiting\nQueue", "Prefill", "Decode\n(Batched)", "Finished"] phase_names = ["Cleanup", "Refill", "Prefill", "Decode (Batched)"]
stage_color = [GRAY, BLUE, YELLOW, GREEN] phase_colors = [GRAY, ORANGE, BLUE, YELLOW]
phase_descs = [
"Evict finished slots",
"Admit new requests",
"Compute KV cache",
"Group by position",
]
stages = VGroup() phases = VGroup()
arrows = VGroup() phase_arrows = VGroup()
for i, (name, color) in enumerate(zip(stage_names, stage_color)): for i, (name, color, desc) in enumerate(zip(phase_names, phase_colors, phase_descs)):
box = Rectangle(height=1.5, width=2.5, color=color, fill_opacity=0.12) box = Rectangle(width=3.2, height=0.7, color=color, fill_opacity=0.12)
lbl = Text(name, font_size=18, color=color) lbl = Text(name, font_size=18, color=color)
grp = VGroup(box, lbl) grp = VGroup(box, lbl)
grp.shift(RIGHT * (i - 1.5) * 3.2 + DOWN * 0.5) phases.add(grp)
stages.add(grp)
self.play(Create(grp), run_time=0.35)
if i > 0: if i > 0:
a = Arrow(stages[i - 1].get_right(), stages[i].get_left(), color=GRAY) a = Arrow(
arrows.add(a) phases[i - 1].get_bottom(),
self.play(Create(a), run_time=0.2) phases[i].get_top(),
color=GRAY, buff=0.08,
pipeline = VGroup(stages, arrows) max_tip_length_to_length_ratio=0.2,
plabel = Text("4-Phase Generation Loop", font_size=16, color=GRAY).next_to(
pipeline, DOWN, buff=0.4
) )
self.play(Write(plabel)) phase_arrows.add(a)
self.wait(0.5)
# ── spawn tasks ── phases.arrange(DOWN, buff=0.25)
task_colors = [YELLOW, ORANGE, PINK, TEAL, GREEN] phases.shift(LEFT * 3.5 + DOWN * 0.6)
tasks = VGroup()
box_center = stages[0].get_center() for i in range(4):
for i, c in enumerate(task_colors): self.play(Create(phases[i]))
dot = Dot(color=c, radius=0.12) if i > 0:
y_off = (i - 2) * 0.2 self.play(Create(phase_arrows[i - 1]))
dot.move_to(box_center + RIGHT * y_off * 0.3) self.wait(0.3)
lbl = Text(f"R{i+1}", font_size=10, color=c).next_to(dot, UP, buff=0.1)
tg = VGroup(dot, lbl) # cycle arrow back from Decode to Cleanup
tasks.add(tg) loop_arrow = CurvedArrow(
self.play(FadeIn(tg, scale=0.5), run_time=0.12) phases[-1].get_right() + RIGHT * 0.15,
phases[0].get_right() + RIGHT * 0.15,
color=GRAY, angle=PI / 2,
)
loop_label = Text("Loop", font_size=12, color=GRAY).next_to(loop_arrow, RIGHT, buff=0.1)
self.play(Create(loop_arrow), Write(loop_label))
self.wait(0.3)
# ── animate requests at different stages concurrently ──
colors = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE]
requests = []
def make_req(name, color):
dot = Dot(color=color, radius=0.13)
lbl = Text(name, font_size=14, color=color)
lbl.next_to(dot, LEFT, buff=0.15)
return VGroup(dot, lbl)
# R1 in Prefill, R2 in Decode, R3 in Refill (concurrent!)
r1 = make_req("R1", colors[0])
r1.next_to(phases[2], RIGHT, buff=1.2)
r2 = make_req("R2", colors[1])
r2.next_to(phases[3], RIGHT, buff=1.2)
r3 = make_req("R3", colors[2])
r3.next_to(phases[1], RIGHT, buff=1.2)
r4 = make_req("R4", colors[3])
r4.next_to(phases[0], RIGHT, buff=1.2)
self.play(FadeIn(r1, scale=0.7), FadeIn(r2, scale=0.7),
FadeIn(r3, scale=0.7), FadeIn(r4, scale=0.7))
self.wait(0.4)
concurrent_note = Text("3 requests at different phases simultaneously",
font_size=15, color=WHITE).next_to(phases, DOWN, buff=0.5)
self.play(Write(concurrent_note))
self.wait(1.2)
self.play(FadeOut(concurrent_note))
# ── animate rotation through phases ──
# R1: Prefill -> Decode, R2: Decode -> Cleanup, R3: Refill -> Prefill, R4: Cleanup -> Refill
self.play(
r1.animate.next_to(phases[3], RIGHT, buff=1.2),
r2.animate.next_to(phases[0], RIGHT, buff=1.2),
r3.animate.next_to(phases[2], RIGHT, buff=1.2),
r4.animate.next_to(phases[1], RIGHT, buff=1.2),
)
self.wait(0.3)
# R2 done -> spawn R5
new_r5 = make_req("R5", colors[4])
new_r5.next_to(phases[1], RIGHT, buff=1.2)
self.play(FadeOut(r2), FadeIn(new_r5, scale=0.7))
r2_target = new_r5
self.wait(0.3) self.wait(0.3)
# ── animate through stages ── # ── highlight decode: position-grouped batching ──
for phase in range(1, 4): # gather multiple requests into decode
target = stages[phase].get_center()
anims = [t.animate.move_to(target) for t in tasks]
self.play(*anims, run_time=0.5, rate_func=smooth)
self.wait(0.15)
# ── highlight decode batching ──
ring = SurroundingRectangle(stages[2], color=YELLOW, buff=0.12)
note = Text(
"Same-position batch decoding", font_size=16, color=YELLOW
).next_to(stages[2], DOWN, buff=0.5)
self.play(Create(ring), Write(note))
self.wait(1)
self.play(FadeOut(ring), FadeOut(note))
# ── throughput comparison (text) ──
self.play( self.play(
*[FadeOut(t) for t in tasks], r1.animate.next_to(phases[3], RIGHT, buff=1.2),
FadeOut(pipeline), r3.animate.next_to(phases[3], RIGHT, buff=2.5),
FadeOut(plabel), FadeOut(r4), FadeOut(new_r5),
FadeOut(top_bar),
) )
self.wait(0.2)
ring = SurroundingRectangle(phases[3], color=YELLOW, buff=0.12)
ring_note = Text("Position-Grouped Decode\nSame pos same matmul batch",
font_size=15, color=YELLOW, line_spacing=0.6)
ring_note.next_to(ring, DOWN, buff=0.35)
hbox = SurroundingRectangle(
VGroup(r1, r3).copy().arrange(RIGHT, buff=0.5).move_to(r1),
color=GREEN, buff=0.2,
)
hbox.next_to(phases[3], RIGHT, buff=1.85)
self.play(Create(ring), Write(ring_note))
self.play(Create(hbox))
self.wait(1.8)
self.play(FadeOut(ring), FadeOut(ring_note), FadeOut(hbox),
FadeOut(r1), FadeOut(r3), FadeOut(loop_arrow), FadeOut(loop_label))
# O(1) slot allocation highlight
bitmask_box = VGroup(
Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE),
Text("free_slots = ~occupied_mask one-bit op", font_size=16, color=GRAY),
).arrange(DOWN, buff=0.2).next_to(phases, DOWN, buff=0.8)
self.play(Write(bitmask_box))
self.wait(1.2)
# ── clear for throughput ──
self.play(*[FadeOut(m) for m in self.mobjects if m is not title and m is not top_bar])
compare = VGroup( compare = VGroup(
Text("Throughput Comparison", font_size=32, color=BLUE), Text("Throughput Comparison", font_size=32, color=BLUE),
Text( Text("Static Batching: 1.0x (baseline)", font_size=22, color=RED),
"Static Batch: 1.0× (baseline)", Text("Continuous Batching: 3.4x (single GPU)", font_size=22, color=GREEN),
font_size=24, color=RED,
),
Text(
"Continuous Batching: 3.4× (single GPU)",
font_size=24, color=GREEN,
),
).arrange(DOWN, buff=0.4, aligned_edge=LEFT) ).arrange(DOWN, buff=0.4, aligned_edge=LEFT)
self.play(Write(compare)) self.play(Write(compare))
self.wait(2) self.wait(2)
self.play(FadeOut(compare)) self.play(FadeOut(compare), FadeOut(title), FadeOut(top_bar))

View File

@ -2,6 +2,9 @@
import subprocess import subprocess
import sys import sys
from pathlib import Path
ROOT = Path(__file__).parent
SCENES = [ SCENES = [
("transformer.py", "Transformer"), ("transformer.py", "Transformer"),
@ -12,19 +15,21 @@ SCENES = [
def render(file_name, scene_name, quality="-qh"): def render(file_name, scene_name, quality="-qh"):
script_path = ROOT / file_name
media_dir = ROOT / "output"
cmd = [ cmd = [
sys.executable, sys.executable,
"-m", "-m",
"manim", "manim",
f"promo/{file_name}", str(script_path),
scene_name, scene_name,
quality, quality,
"--media_dir", "--media_dir",
"promo/output", str(media_dir),
] ]
print(f"Rendering {scene_name}...") print(f"Rendering {scene_name}...")
subprocess.run(cmd, check=True) subprocess.run(cmd, check=True)
print(f" Done → promo/output/{scene_name}.mp4") print(f" Done → {media_dir / 'videos' / scene_name.lower()}.mp4")
if __name__ == "__main__": if __name__ == "__main__":