fix: architecture layout and rewrite continuous_batching with dynamic pipeline animation
This commit is contained in:
parent
c03abd31fe
commit
a7a79eef96
|
|
@ -58,8 +58,8 @@ class Architecture(Scene):
|
|||
self.wait(0.5)
|
||||
|
||||
hl = SurroundingRectangle(layers[3], color=GREEN, buff=0.12)
|
||||
hl_note = Text("Zero-Copy Prefix Reuse", font_size=22, color=GREEN)
|
||||
hl_note.next_to(hl, RIGHT, buff=0.8)
|
||||
hl_note = Text("Zero-Copy Prefix Reuse", font_size=18, color=GREEN)
|
||||
hl_note.next_to(hl, LEFT, buff=0.4)
|
||||
self.play(Create(hl), Write(hl_note))
|
||||
self.wait(1.5)
|
||||
self.play(FadeOut(hl), FadeOut(hl_note))
|
||||
|
|
|
|||
|
|
@ -1,98 +1,158 @@
|
|||
"""AstrAI promo: Continuous Batching animation.
|
||||
|
||||
Shows how tasks flow through the 4-phase pipeline and get batched together.
|
||||
Shows 4-phase pipeline with multiple requests concurrently at different stages,
|
||||
and position-grouped decode batching — the key advantage over static batching.
|
||||
"""
|
||||
|
||||
from manim import *
|
||||
|
||||
|
||||
class ContinuousBatching(Scene):
|
||||
"""Animates tasks flowing through the prefill->decode pipeline."""
|
||||
|
||||
def construct(self):
|
||||
# ── title ──
|
||||
title = Text("Continuous Batching", font_size=48, color=BLUE)
|
||||
self.play(Write(title))
|
||||
self.wait(0.5)
|
||||
self.play(title.animate.to_edge(UP).scale(0.6))
|
||||
top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN)
|
||||
self.wait(0.3)
|
||||
self.play(title.animate.to_edge(UP).scale(0.55))
|
||||
|
||||
top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.2)
|
||||
self.play(Create(top_bar))
|
||||
|
||||
# ── pipeline stages ──
|
||||
stage_names = ["Waiting\nQueue", "Prefill", "Decode\n(Batched)", "Finished"]
|
||||
stage_color = [GRAY, BLUE, YELLOW, GREEN]
|
||||
# ── 4-phase loop (vertical) ──
|
||||
phase_names = ["Cleanup", "Refill", "Prefill", "Decode (Batched)"]
|
||||
phase_colors = [GRAY, ORANGE, BLUE, YELLOW]
|
||||
phase_descs = [
|
||||
"Evict finished slots",
|
||||
"Admit new requests",
|
||||
"Compute KV cache",
|
||||
"Group by position",
|
||||
]
|
||||
|
||||
stages = VGroup()
|
||||
arrows = VGroup()
|
||||
for i, (name, color) in enumerate(zip(stage_names, stage_color)):
|
||||
box = Rectangle(height=1.5, width=2.5, color=color, fill_opacity=0.12)
|
||||
phases = VGroup()
|
||||
phase_arrows = VGroup()
|
||||
for i, (name, color, desc) in enumerate(zip(phase_names, phase_colors, phase_descs)):
|
||||
box = Rectangle(width=3.2, height=0.7, color=color, fill_opacity=0.12)
|
||||
lbl = Text(name, font_size=18, color=color)
|
||||
grp = VGroup(box, lbl)
|
||||
grp.shift(RIGHT * (i - 1.5) * 3.2 + DOWN * 0.5)
|
||||
stages.add(grp)
|
||||
self.play(Create(grp), run_time=0.35)
|
||||
phases.add(grp)
|
||||
if i > 0:
|
||||
a = Arrow(stages[i - 1].get_right(), stages[i].get_left(), color=GRAY)
|
||||
arrows.add(a)
|
||||
self.play(Create(a), run_time=0.2)
|
||||
|
||||
pipeline = VGroup(stages, arrows)
|
||||
plabel = Text("4-Phase Generation Loop", font_size=16, color=GRAY).next_to(
|
||||
pipeline, DOWN, buff=0.4
|
||||
a = Arrow(
|
||||
phases[i - 1].get_bottom(),
|
||||
phases[i].get_top(),
|
||||
color=GRAY, buff=0.08,
|
||||
max_tip_length_to_length_ratio=0.2,
|
||||
)
|
||||
self.play(Write(plabel))
|
||||
self.wait(0.5)
|
||||
phase_arrows.add(a)
|
||||
|
||||
# ── spawn tasks ──
|
||||
task_colors = [YELLOW, ORANGE, PINK, TEAL, GREEN]
|
||||
tasks = VGroup()
|
||||
box_center = stages[0].get_center()
|
||||
for i, c in enumerate(task_colors):
|
||||
dot = Dot(color=c, radius=0.12)
|
||||
y_off = (i - 2) * 0.2
|
||||
dot.move_to(box_center + RIGHT * y_off * 0.3)
|
||||
lbl = Text(f"R{i+1}", font_size=10, color=c).next_to(dot, UP, buff=0.1)
|
||||
tg = VGroup(dot, lbl)
|
||||
tasks.add(tg)
|
||||
self.play(FadeIn(tg, scale=0.5), run_time=0.12)
|
||||
phases.arrange(DOWN, buff=0.25)
|
||||
phases.shift(LEFT * 3.5 + DOWN * 0.6)
|
||||
|
||||
for i in range(4):
|
||||
self.play(Create(phases[i]))
|
||||
if i > 0:
|
||||
self.play(Create(phase_arrows[i - 1]))
|
||||
self.wait(0.3)
|
||||
|
||||
# cycle arrow back from Decode to Cleanup
|
||||
loop_arrow = CurvedArrow(
|
||||
phases[-1].get_right() + RIGHT * 0.15,
|
||||
phases[0].get_right() + RIGHT * 0.15,
|
||||
color=GRAY, angle=PI / 2,
|
||||
)
|
||||
loop_label = Text("Loop", font_size=12, color=GRAY).next_to(loop_arrow, RIGHT, buff=0.1)
|
||||
self.play(Create(loop_arrow), Write(loop_label))
|
||||
self.wait(0.3)
|
||||
|
||||
# ── animate requests at different stages concurrently ──
|
||||
colors = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE]
|
||||
requests = []
|
||||
|
||||
def make_req(name, color):
|
||||
dot = Dot(color=color, radius=0.13)
|
||||
lbl = Text(name, font_size=14, color=color)
|
||||
lbl.next_to(dot, LEFT, buff=0.15)
|
||||
return VGroup(dot, lbl)
|
||||
|
||||
# R1 in Prefill, R2 in Decode, R3 in Refill (concurrent!)
|
||||
r1 = make_req("R1", colors[0])
|
||||
r1.next_to(phases[2], RIGHT, buff=1.2)
|
||||
r2 = make_req("R2", colors[1])
|
||||
r2.next_to(phases[3], RIGHT, buff=1.2)
|
||||
r3 = make_req("R3", colors[2])
|
||||
r3.next_to(phases[1], RIGHT, buff=1.2)
|
||||
|
||||
r4 = make_req("R4", colors[3])
|
||||
r4.next_to(phases[0], RIGHT, buff=1.2)
|
||||
|
||||
self.play(FadeIn(r1, scale=0.7), FadeIn(r2, scale=0.7),
|
||||
FadeIn(r3, scale=0.7), FadeIn(r4, scale=0.7))
|
||||
self.wait(0.4)
|
||||
|
||||
concurrent_note = Text("3 requests at different phases simultaneously",
|
||||
font_size=15, color=WHITE).next_to(phases, DOWN, buff=0.5)
|
||||
self.play(Write(concurrent_note))
|
||||
self.wait(1.2)
|
||||
self.play(FadeOut(concurrent_note))
|
||||
|
||||
# ── animate rotation through phases ──
|
||||
# R1: Prefill -> Decode, R2: Decode -> Cleanup, R3: Refill -> Prefill, R4: Cleanup -> Refill
|
||||
self.play(
|
||||
r1.animate.next_to(phases[3], RIGHT, buff=1.2),
|
||||
r2.animate.next_to(phases[0], RIGHT, buff=1.2),
|
||||
r3.animate.next_to(phases[2], RIGHT, buff=1.2),
|
||||
r4.animate.next_to(phases[1], RIGHT, buff=1.2),
|
||||
)
|
||||
self.wait(0.3)
|
||||
|
||||
# R2 done -> spawn R5
|
||||
new_r5 = make_req("R5", colors[4])
|
||||
new_r5.next_to(phases[1], RIGHT, buff=1.2)
|
||||
self.play(FadeOut(r2), FadeIn(new_r5, scale=0.7))
|
||||
r2_target = new_r5
|
||||
|
||||
self.wait(0.3)
|
||||
|
||||
# ── animate through stages ──
|
||||
for phase in range(1, 4):
|
||||
target = stages[phase].get_center()
|
||||
anims = [t.animate.move_to(target) for t in tasks]
|
||||
self.play(*anims, run_time=0.5, rate_func=smooth)
|
||||
self.wait(0.15)
|
||||
|
||||
# ── highlight decode batching ──
|
||||
ring = SurroundingRectangle(stages[2], color=YELLOW, buff=0.12)
|
||||
note = Text(
|
||||
"Same-position batch decoding", font_size=16, color=YELLOW
|
||||
).next_to(stages[2], DOWN, buff=0.5)
|
||||
self.play(Create(ring), Write(note))
|
||||
self.wait(1)
|
||||
self.play(FadeOut(ring), FadeOut(note))
|
||||
|
||||
# ── throughput comparison (text) ──
|
||||
# ── highlight decode: position-grouped batching ──
|
||||
# gather multiple requests into decode
|
||||
self.play(
|
||||
*[FadeOut(t) for t in tasks],
|
||||
FadeOut(pipeline),
|
||||
FadeOut(plabel),
|
||||
FadeOut(top_bar),
|
||||
r1.animate.next_to(phases[3], RIGHT, buff=1.2),
|
||||
r3.animate.next_to(phases[3], RIGHT, buff=2.5),
|
||||
FadeOut(r4), FadeOut(new_r5),
|
||||
)
|
||||
self.wait(0.2)
|
||||
|
||||
ring = SurroundingRectangle(phases[3], color=YELLOW, buff=0.12)
|
||||
ring_note = Text("Position-Grouped Decode\nSame pos same matmul batch",
|
||||
font_size=15, color=YELLOW, line_spacing=0.6)
|
||||
ring_note.next_to(ring, DOWN, buff=0.35)
|
||||
hbox = SurroundingRectangle(
|
||||
VGroup(r1, r3).copy().arrange(RIGHT, buff=0.5).move_to(r1),
|
||||
color=GREEN, buff=0.2,
|
||||
)
|
||||
hbox.next_to(phases[3], RIGHT, buff=1.85)
|
||||
|
||||
self.play(Create(ring), Write(ring_note))
|
||||
self.play(Create(hbox))
|
||||
self.wait(1.8)
|
||||
self.play(FadeOut(ring), FadeOut(ring_note), FadeOut(hbox),
|
||||
FadeOut(r1), FadeOut(r3), FadeOut(loop_arrow), FadeOut(loop_label))
|
||||
|
||||
# O(1) slot allocation highlight
|
||||
bitmask_box = VGroup(
|
||||
Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE),
|
||||
Text("free_slots = ~occupied_mask one-bit op", font_size=16, color=GRAY),
|
||||
).arrange(DOWN, buff=0.2).next_to(phases, DOWN, buff=0.8)
|
||||
|
||||
self.play(Write(bitmask_box))
|
||||
self.wait(1.2)
|
||||
|
||||
# ── clear for throughput ──
|
||||
self.play(*[FadeOut(m) for m in self.mobjects if m is not title and m is not top_bar])
|
||||
|
||||
compare = VGroup(
|
||||
Text("Throughput Comparison", font_size=32, color=BLUE),
|
||||
Text(
|
||||
"Static Batch: 1.0× (baseline)",
|
||||
font_size=24, color=RED,
|
||||
),
|
||||
Text(
|
||||
"Continuous Batching: 3.4× (single GPU)",
|
||||
font_size=24, color=GREEN,
|
||||
),
|
||||
Text("Static Batching: 1.0x (baseline)", font_size=22, color=RED),
|
||||
Text("Continuous Batching: 3.4x (single GPU)", font_size=22, color=GREEN),
|
||||
).arrange(DOWN, buff=0.4, aligned_edge=LEFT)
|
||||
self.play(Write(compare))
|
||||
self.wait(2)
|
||||
self.play(FadeOut(compare))
|
||||
self.play(FadeOut(compare), FadeOut(title), FadeOut(top_bar))
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).parent
|
||||
|
||||
SCENES = [
|
||||
("transformer.py", "Transformer"),
|
||||
|
|
@ -12,19 +15,21 @@ SCENES = [
|
|||
|
||||
|
||||
def render(file_name, scene_name, quality="-qh"):
|
||||
script_path = ROOT / file_name
|
||||
media_dir = ROOT / "output"
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"manim",
|
||||
f"promo/{file_name}",
|
||||
str(script_path),
|
||||
scene_name,
|
||||
quality,
|
||||
"--media_dir",
|
||||
"promo/output",
|
||||
str(media_dir),
|
||||
]
|
||||
print(f"Rendering {scene_name}...")
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f" Done → promo/output/{scene_name}.mp4")
|
||||
print(f" Done → {media_dir / 'videos' / scene_name.lower()}.mp4")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue