99 lines
3.5 KiB
Python
99 lines
3.5 KiB
Python
"""AstrAI promo: Continuous Batching animation.

Shows how tasks flow through the 4-phase pipeline and get batched together.
"""
|
||
|
||
from manim import *
|
||
|
||
|
||
class ContinuousBatching(Scene):
    """Animate tasks flowing through the prefill -> decode pipeline.

    The scene draws a title, then a row of four stage boxes (waiting queue,
    prefill, batched decode, finished) joined by arrows. Five request dots
    travel through the stages side by side, and the decode stage is
    highlighted while the batch is inside it. The scene ends with a
    text-only throughput comparison.
    """

    def construct(self):
        # ── title ──
        title = Text("Continuous Batching", font_size=48, color=BLUE)
        self.play(Write(title))
        self.wait(0.5)
        # Scale first, then snap to the edge. Scaling after ``to_edge`` shrinks
        # the title about its centre and pulls it away from the top margin.
        self.play(title.animate.scale(0.6).to_edge(UP))
        top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN)
        self.play(Create(top_bar))

        # ── pipeline stages ──
        stage_names = ["Waiting\nQueue", "Prefill", "Decode\n(Batched)", "Finished"]
        stage_colors = [GRAY, BLUE, YELLOW, GREEN]
        decode_idx = 2  # index of the "Decode (Batched)" stage in stage_names
        n_stages = len(stage_names)

        stages = VGroup()
        arrows = VGroup()
        for i, (name, color) in enumerate(zip(stage_names, stage_colors)):
            box = Rectangle(height=1.5, width=2.5, color=color, fill_opacity=0.12)
            # Anchor the label near the top of the box so the lower half is
            # free for the task row and the two never overlap.
            lbl = Text(name, font_size=18, color=color).next_to(
                box.get_top(), DOWN, buff=0.15
            )
            grp = VGroup(box, lbl)
            # Centre the row horizontally (x = -4.8, -1.6, 1.6, 4.8 for 4 stages).
            grp.shift(RIGHT * (i - (n_stages - 1) / 2) * 3.2 + DOWN * 0.5)
            stages.add(grp)
            self.play(Create(grp), run_time=0.35)
            if i > 0:
                # Adjacent boxes are only 0.7 apart. The default arrow buff
                # (0.25 per side) would leave a ~0.2 stub, so use a small buff.
                a = Arrow(
                    stages[i - 1].get_right(), grp.get_left(), buff=0.1, color=GRAY
                )
                arrows.add(a)
                self.play(Create(a), run_time=0.2)

        pipeline = VGroup(stages, arrows)
        plabel = Text("4-Phase Generation Loop", font_size=16, color=GRAY).next_to(
            pipeline, DOWN, buff=0.4
        )
        self.play(Write(plabel))
        self.wait(0.5)

        def task_slot(stage):
            """Return the point in the lower half of *stage*'s box where tasks sit."""
            return stage[0].get_bottom() + UP * 0.4

        # ── spawn tasks ──
        task_colors = [YELLOW, ORANGE, PINK, TEAL, GREEN]
        tasks = VGroup()
        for i, c in enumerate(task_colors):
            dot = Dot(color=c, radius=0.12)
            tag = Text(f"R{i+1}", font_size=10, color=c).next_to(dot, UP, buff=0.1)
            tasks.add(VGroup(dot, tag))
        # Lay the requests out side by side. Offsets smaller than a dot's
        # diameter (0.24) would draw the dots and labels on top of each other.
        tasks.arrange(RIGHT, buff=0.15).move_to(task_slot(stages[0]))
        for tg in tasks:
            self.play(FadeIn(tg, scale=0.5), run_time=0.12)

        self.wait(0.3)

        # ── animate through stages ──
        for idx in range(1, len(stages)):
            # Move the batch as one group so the requests keep their layout.
            # Moving each task to the same centre stacks them on one point.
            self.play(
                tasks.animate.move_to(task_slot(stages[idx])),
                run_time=0.5,
                rate_func=smooth,
            )
            self.wait(0.15)

            if idx == decode_idx:
                # ── highlight decode batching while the batch is inside it ──
                ring = SurroundingRectangle(stages[idx], color=YELLOW, buff=0.12)
                # Above the ring: below the pipeline it would collide with plabel.
                note = Text(
                    "Same-position batch decoding", font_size=16, color=YELLOW
                ).next_to(ring, UP, buff=0.2)
                self.play(Create(ring), Write(note))
                self.wait(1)
                self.play(FadeOut(ring), FadeOut(note))

        # ── throughput comparison (text) ──
        self.play(
            FadeOut(tasks),
            FadeOut(pipeline),
            FadeOut(plabel),
            FadeOut(top_bar),
        )

        compare = VGroup(
            Text("Throughput Comparison", font_size=32, color=BLUE),
            Text(
                "Static Batch: 1.0× (baseline)",
                font_size=24, color=RED,
            ),
            Text(
                "Continuous Batching: 3.4× (single GPU)",
                font_size=24, color=GREEN,
            ),
        ).arrange(DOWN, buff=0.4, aligned_edge=LEFT)
        self.play(Write(compare))
        self.wait(2)
        self.play(FadeOut(compare))
|