video-promo/continuous_batching.py

159 lines
6.1 KiB
Python

"""AstrAI promo: Continuous Batching animation.
Shows 4-phase pipeline with multiple requests concurrently at different stages,
and position-grouped decode batching — the key advantage over static batching.
"""
from manim import *
class ContinuousBatching(Scene):
    """Promo scene illustrating continuous batching as a four-phase loop.

    The animation plays these beats in order:

    1. Title and divider line.
    2. The Cleanup -> Refill -> Prefill -> Decode loop, drawn top to bottom.
    3. Request markers parked at different phases at once, then advancing.
    4. A highlight of the decode phase batching two requests together.
    5. A slot-allocation callout, followed by a throughput comparison.
    """

    def construct(self):
        """Build and play the whole animation (Manim entry point)."""
        # Horizontal gap between a phase box and a request marker beside it.
        # A second marker sharing the same row uses the wider offset so the
        # two markers never overlap.
        req_buff = 1.2
        second_req_buff = 2.5

        title = Text("Continuous Batching", font_size=48, color=BLUE)
        self.play(Write(title))
        self.wait(0.3)
        self.play(title.animate.to_edge(UP).scale(0.55))
        top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.2)
        self.play(Create(top_bar))

        # ── 4-phase loop (vertical) ──
        phase_specs = [
            ("Cleanup", GRAY),             # evict finished slots
            ("Refill", ORANGE),            # admit new requests
            ("Prefill", BLUE),             # compute KV cache
            ("Decode (Batched)", YELLOW),  # group by position
        ]
        phases = VGroup(*[
            VGroup(
                Rectangle(width=3.2, height=0.7, color=color, fill_opacity=0.12),
                Text(name, font_size=18, color=color),
            )
            for name, color in phase_specs
        ])
        phases.arrange(DOWN, buff=0.25)
        phases.shift(LEFT * 3.5 + DOWN * 0.6)
        cleanup, refill, prefill, decode = phases

        # Connectors are created only AFTER the boxes are laid out: an Arrow
        # captures its endpoints when constructed and would otherwise stay at
        # the origin while the boxes move away. The boxes are just 0.25 apart,
        # so the arrow spans the full gap (buff=0) and the tip may take a
        # larger share of the length to remain legible.
        boxes = list(phases)
        phase_arrows = VGroup(*[
            Arrow(
                upper.get_bottom(),
                lower.get_top(),
                color=GRAY,
                buff=0,
                max_tip_length_to_length_ratio=0.4,
            )
            for upper, lower in zip(boxes, boxes[1:])
        ])

        # Reveal each box, then the connector leading into it.
        for idx, phase in enumerate(boxes):
            self.play(Create(phase))
            if idx > 0:
                self.play(Create(phase_arrows[idx - 1]))
        self.wait(0.3)

        # Cycle arrow back from Decode to Cleanup.
        loop_arrow = CurvedArrow(
            decode.get_right() + RIGHT * 0.15,
            cleanup.get_right() + RIGHT * 0.15,
            color=GRAY,
            angle=PI / 2,
        )
        loop_label = Text("Loop", font_size=12, color=GRAY).next_to(loop_arrow, RIGHT, buff=0.1)
        self.play(Create(loop_arrow), Write(loop_label))
        self.wait(0.3)

        # ── animate requests at different stages concurrently ──
        def make_req(name, color, phase, buff=req_buff):
            """Return a dot-plus-label marker parked to the right of *phase*."""
            dot = Dot(color=color, radius=0.13)
            lbl = Text(name, font_size=14, color=color)
            lbl.next_to(dot, LEFT, buff=0.15)
            return VGroup(dot, lbl).next_to(phase, RIGHT, buff=buff)

        # R1 in Prefill, R2 in Decode, R3 in Refill, R4 in Cleanup (concurrent!)
        r1 = make_req("R1", YELLOW, prefill)
        r2 = make_req("R2", ORANGE, decode)
        r3 = make_req("R3", PINK, refill)
        r4 = make_req("R4", TEAL, cleanup)
        in_flight = [r1, r2, r3, r4]
        self.play(*[FadeIn(req, scale=0.7) for req in in_flight])
        self.wait(0.4)

        # The count is derived from the markers so the caption cannot drift
        # out of sync with what is actually on screen.
        concurrent_note = Text(
            f"{len(in_flight)} requests at different phases simultaneously",
            font_size=15,
            color=WHITE,
        ).next_to(phases, DOWN, buff=0.5)
        self.play(Write(concurrent_note))
        self.wait(1.2)
        self.play(FadeOut(concurrent_note))

        # ── animate rotation through phases ──
        # R1: Prefill -> Decode, R2: Decode -> Cleanup,
        # R3: Refill -> Prefill, R4: Cleanup -> Refill
        self.play(
            r1.animate.next_to(decode, RIGHT, buff=req_buff),
            r2.animate.next_to(cleanup, RIGHT, buff=req_buff),
            r3.animate.next_to(prefill, RIGHT, buff=req_buff),
            r4.animate.next_to(refill, RIGHT, buff=req_buff),
        )
        self.wait(0.3)

        # R2 done -> spawn R5. R4 already occupies the first Refill position,
        # so R5 goes in the second column instead of being drawn on top of it.
        new_r5 = make_req("R5", GREEN, refill, buff=second_req_buff)
        self.play(FadeOut(r2), FadeIn(new_r5, scale=0.7))
        self.wait(0.3)

        # ── highlight decode: position-grouped batching ──
        # Gather multiple requests into decode. R1 is already parked there,
        # so only R3 needs to move.
        self.play(
            r3.animate.next_to(decode, RIGHT, buff=second_req_buff),
            FadeOut(r4),
            FadeOut(new_r5),
        )
        self.wait(0.2)

        ring = SurroundingRectangle(decode, color=YELLOW, buff=0.12)
        ring_note = Text(
            "Position-Grouped Decode\nSame pos same matmul batch",
            font_size=15,
            color=YELLOW,
            line_spacing=0.6,
        )
        ring_note.next_to(ring, DOWN, buff=0.35)
        # Frame the markers where they actually are, so the box always
        # encloses both batched requests.
        hbox = SurroundingRectangle(VGroup(r1, r3), color=GREEN, buff=0.2)
        self.play(Create(ring), Write(ring_note))
        self.play(Create(hbox))
        self.wait(1.8)
        self.play(
            FadeOut(ring), FadeOut(ring_note), FadeOut(hbox),
            FadeOut(r1), FadeOut(r3), FadeOut(loop_arrow), FadeOut(loop_label),
        )

        # O(1) slot allocation highlight
        bitmask_box = VGroup(
            Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE),
            Text("free_slots = ~occupied_mask one-bit op", font_size=16, color=GRAY),
        ).arrange(DOWN, buff=0.2).next_to(phases, DOWN, buff=0.8)
        self.play(Write(bitmask_box))
        self.wait(1.2)

        # ── clear for throughput ──
        # Fade everything still on screen except the persistent header.
        self.play(*[
            FadeOut(m)
            for m in self.mobjects
            if m is not title and m is not top_bar
        ])
        compare = VGroup(
            Text("Throughput Comparison", font_size=32, color=BLUE),
            Text("Static Batching: 1.0x (baseline)", font_size=22, color=RED),
            Text("Continuous Batching: 3.4x (single GPU)", font_size=22, color=GREEN),
        ).arrange(DOWN, buff=0.4, aligned_edge=LEFT)
        self.play(Write(compare))
        self.wait(2)
        self.play(FadeOut(compare), FadeOut(title), FadeOut(top_bar))