159 lines
6.1 KiB
Python
159 lines
6.1 KiB
Python
"""AstrAI promo: Continuous Batching animation.
|
|
|
|
Shows 4-phase pipeline with multiple requests concurrently at different stages,
|
|
and position-grouped decode batching — the key advantage over static batching.
|
|
"""
|
|
|
|
from manim import *
|
|
|
|
|
|
class ContinuousBatching(Scene):
    """Promo scene explaining continuous batching.

    Beats, in order:

    1. Build the vertical Cleanup -> Refill -> Prefill -> Decode loop.
    2. Park four requests in four different phases at once, advance them one
       step, and replace the finished one (R2) with a newly admitted one (R5).
    3. Highlight position-grouped decode batching and bitmask slot allocation.
    4. Show a throughput comparison against static batching.
    """

    # (label, color, one-line description) for each phase, in loop order.
    _PHASES = (
        ("Cleanup", GRAY, "Evict finished slots"),
        ("Refill", ORANGE, "Admit new requests"),
        ("Prefill", BLUE, "Compute KV cache"),
        ("Decode (Batched)", YELLOW, "Group by position"),
    )

    # Gap between a phase box and the first request marker parked to its right.
    _REQ_BUFF = 1.2
    # Gap for a second request sharing the same phase row (keeps markers apart).
    _REQ_BUFF_2ND = 2.5

    @classmethod
    def _build_phase_stack(cls):
        """Return ``(phases, arrows)`` for the vertical phase loop.

        ``phases[i]`` is ``VGroup(box, VGroup(name, description))``, laid out
        top to bottom in ``_PHASES`` order. ``arrows[i]`` points from
        ``phases[i]`` down to ``phases[i + 1]``.
        """
        phases = VGroup()
        for name, color, desc in cls._PHASES:
            box = Rectangle(width=3.2, height=0.7, color=color, fill_opacity=0.12)
            caption = VGroup(
                Text(name, font_size=18, color=color),
                Text(desc, font_size=12, color=GRAY),
            ).arrange(DOWN, buff=0.06)
            phases.add(VGroup(box, caption))

        # Lay the boxes out BEFORE creating the arrows: the arrows are not part
        # of `phases`, so they would not follow the arrange/shift below and
        # would stay where the boxes were at construction time (the origin).
        # The 0.4 gap leaves room for a visible arrow between boxes.
        phases.arrange(DOWN, buff=0.4)
        phases.shift(LEFT * 3.5 + DOWN * 0.4)

        boxes = list(phases)
        arrows = VGroup(*(
            Arrow(
                upper.get_bottom(),
                lower.get_top(),
                color=GRAY,
                buff=0.05,
                # These arrows are short; relax the default caps so the tip
                # and stroke stay visible.
                max_tip_length_to_length_ratio=0.4,
                max_stroke_width_to_length_ratio=10,
            )
            for upper, lower in zip(boxes, boxes[1:])
        ))
        return phases, arrows

    @staticmethod
    def _make_request(name, color):
        """Return a request marker: a colored dot with its name to the left."""
        dot = Dot(color=color, radius=0.13)
        label = Text(name, font_size=14, color=color).next_to(dot, LEFT, buff=0.15)
        return VGroup(dot, label)

    def construct(self):
        """Play the full continuous-batching animation."""
        # ── title ──
        title = Text("Continuous Batching", font_size=48, color=BLUE)
        self.play(Write(title))
        self.wait(0.3)
        self.play(title.animate.to_edge(UP).scale(0.55))

        top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.2)
        self.play(Create(top_bar))

        # ── 4-phase loop (vertical) ──
        phases, phase_arrows = self._build_phase_stack()
        for i, phase in enumerate(phases):
            self.play(Create(phase))
            if i > 0:
                self.play(Create(phase_arrows[i - 1]))
            self.wait(0.3)

        # cycle arrow back from Decode to Cleanup
        loop_arrow = CurvedArrow(
            phases[-1].get_right() + RIGHT * 0.15,
            phases[0].get_right() + RIGHT * 0.15,
            color=GRAY, angle=PI / 2,
        )
        loop_label = Text("Loop", font_size=12, color=GRAY).next_to(loop_arrow, RIGHT, buff=0.1)
        self.play(Create(loop_arrow), Write(loop_label))
        self.wait(0.3)

        # ── requests at different phases concurrently ──
        colors = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE]
        near = self._REQ_BUFF
        far = self._REQ_BUFF_2ND

        # One request per phase: R1 Prefill, R2 Decode, R3 Refill, R4 Cleanup.
        r1 = self._make_request("R1", colors[0]).next_to(phases[2], RIGHT, buff=near)
        r2 = self._make_request("R2", colors[1]).next_to(phases[3], RIGHT, buff=near)
        r3 = self._make_request("R3", colors[2]).next_to(phases[1], RIGHT, buff=near)
        r4 = self._make_request("R4", colors[3]).next_to(phases[0], RIGHT, buff=near)
        self.play(*(FadeIn(req, scale=0.7) for req in (r1, r2, r3, r4)))
        self.wait(0.4)

        concurrent_note = Text(
            "4 requests at different phases simultaneously",
            font_size=15, color=WHITE,
        ).next_to(phases, DOWN, buff=0.5)
        self.play(Write(concurrent_note))
        self.wait(1.2)
        self.play(FadeOut(concurrent_note))

        # ── advance every request one phase ──
        # R1: Prefill -> Decode, R2: Decode -> Cleanup, R3: Refill -> Prefill, R4: Cleanup -> Refill
        self.play(
            r1.animate.next_to(phases[3], RIGHT, buff=near),
            r2.animate.next_to(phases[0], RIGHT, buff=near),
            r3.animate.next_to(phases[2], RIGHT, buff=near),
            r4.animate.next_to(phases[1], RIGHT, buff=near),
        )
        self.wait(0.3)

        # R2 is done: it leaves and R5 is admitted in Refill. R4 already sits
        # in Refill, so R5 is parked further right instead of on top of it.
        r5 = self._make_request("R5", colors[4]).next_to(phases[1], RIGHT, buff=far)
        self.play(FadeOut(r2), FadeIn(r5, scale=0.7))
        self.wait(0.3)

        # ── highlight decode: position-grouped batching ──
        # R1 is already in Decode; bring R3 alongside it and clear the others.
        self.play(
            r3.animate.next_to(phases[3], RIGHT, buff=far),
            FadeOut(r4), FadeOut(r5),
        )
        self.wait(0.2)

        ring = SurroundingRectangle(phases[3], color=YELLOW, buff=0.12)
        ring_note = Text(
            "Position-Grouped Decode\nSame pos → same matmul batch",
            font_size=15, color=YELLOW, line_spacing=0.6,
        ).next_to(ring, DOWN, buff=0.35)
        # Frame the two co-batched requests where they actually are.
        batch_box = SurroundingRectangle(VGroup(r1, r3), color=GREEN, buff=0.2)

        self.play(Create(ring), Write(ring_note))
        self.play(Create(batch_box))
        self.wait(1.8)
        self.play(
            FadeOut(ring), FadeOut(ring_note), FadeOut(batch_box),
            FadeOut(r1), FadeOut(r3), FadeOut(loop_arrow), FadeOut(loop_label),
        )

        # ── O(1) slot allocation highlight ──
        # buff 0.5 (not 0.8) keeps the two-line caption clear of the frame's bottom edge.
        bitmask_box = VGroup(
            Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE),
            Text("free_slots = ~occupied_mask (one bitwise op)", font_size=16, color=GRAY),
        ).arrange(DOWN, buff=0.2).next_to(phases, DOWN, buff=0.5)
        self.play(Write(bitmask_box))
        self.wait(1.2)

        # ── clear for throughput ──
        self.play(*[FadeOut(m) for m in self.mobjects if m is not title and m is not top_bar])

        compare = VGroup(
            Text("Throughput Comparison", font_size=32, color=BLUE),
            Text("Static Batching: 1.0x (baseline)", font_size=22, color=RED),
            Text("Continuous Batching: 3.4x (single GPU)", font_size=22, color=GREEN),
        ).arrange(DOWN, buff=0.4, aligned_edge=LEFT)
        self.play(Write(compare))
        self.wait(2)
        self.play(FadeOut(compare), FadeOut(title), FadeOut(top_bar))
|