rewrite continuous_batching as state-machine with batch token rotation, in/out flow, bitmask anim, and throughput bars

This commit is contained in:
ViperEkura 2026-05-06 21:53:15 +08:00
parent a7a79eef96
commit fc68fc9107
1 changed files with 292 additions and 125 deletions

View File

@ -1,158 +1,325 @@
"""AstrAI promo: Continuous Batching animation. """AstrAI promo: Continuous Batching — state-machine driven batch rotation.
Shows 4-phase pipeline with multiple requests concurrently at different stages, Shows a 4-state FSM (Cleanup Refill Prefill Decode Loop Cleanup)
and position-grouped decode batching the key advantage over static batching. with coloured batch tokens flowing through states, entering & leaving continuously.
""" """
from manim import * from manim import *
# ── palette ──
PHASE_COLORS = {
"Cleanup": GRAY,
"Refill": ORANGE,
"Prefill": BLUE,
"Decode": YELLOW,
}
BATCH_COLORS = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE, GOLD, MAROON]
class ContinuousBatching(Scene): class ContinuousBatching(Scene):
def construct(self): def construct(self):
# ═══════════════════════════════════════════════════
# 0. Title
# ═══════════════════════════════════════════════════
title = Text("Continuous Batching", font_size=48, color=BLUE) title = Text("Continuous Batching", font_size=48, color=BLUE)
self.play(Write(title)) self.play(Write(title))
self.wait(0.3) self.wait(0.4)
self.play(title.animate.to_edge(UP).scale(0.55)) self.play(title.animate.to_edge(UP).scale(0.55))
bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15)
self.play(Create(bar))
top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.2) # ═══════════════════════════════════════════════════
self.play(Create(top_bar)) # 1. Build state-machine layout (vertical, 4 states)
# ═══════════════════════════════════════════════════
state_names = ["Cleanup", "Refill", "Prefill", "Decode"]
# ── 4-phase loop (vertical) ── states = VGroup()
phase_names = ["Cleanup", "Refill", "Prefill", "Decode (Batched)"] trans_arrows = VGroup()
phase_colors = [GRAY, ORANGE, BLUE, YELLOW] for i, name in enumerate(state_names):
phase_descs = [ box = RoundedRectangle(
"Evict finished slots", width=3.6, height=0.8, corner_radius=0.15,
"Admit new requests", color=PHASE_COLORS[name], fill_opacity=0.12, stroke_width=2.5,
"Compute KV cache", )
"Group by position", lbl = Text(name, font_size=20, color=PHASE_COLORS[name])
] states.add(VGroup(box, lbl))
phases = VGroup() states.arrange(DOWN, buff=0.3)
phase_arrows = VGroup() states.shift(LEFT * 3.8 + DOWN * 0.5)
for i, (name, color, desc) in enumerate(zip(phase_names, phase_colors, phase_descs)):
box = Rectangle(width=3.2, height=0.7, color=color, fill_opacity=0.12)
lbl = Text(name, font_size=18, color=color)
grp = VGroup(box, lbl)
phases.add(grp)
if i > 0:
a = Arrow(
phases[i - 1].get_bottom(),
phases[i].get_top(),
color=GRAY, buff=0.08,
max_tip_length_to_length_ratio=0.2,
)
phase_arrows.add(a)
phases.arrange(DOWN, buff=0.25) for i in range(1, 4):
phases.shift(LEFT * 3.5 + DOWN * 0.6) a = Arrow(
states[i - 1].get_bottom(), states[i].get_top(),
color=LIGHT_GRAY, buff=0.06,
max_tip_length_to_length_ratio=0.22,
)
trans_arrows.add(a)
for i in range(4): for i in range(4):
self.play(Create(phases[i])) self.play(Create(states[i]))
if i > 0: if i > 0:
self.play(Create(phase_arrows[i - 1])) self.play(Create(trans_arrows[i - 1]))
self.wait(0.3)
# cycle arrow back from Decode to Cleanup # loop arrow — Decode returns to Cleanup (multiturn decoding)
loop_arrow = CurvedArrow( loop = CurvedArrow(
phases[-1].get_right() + RIGHT * 0.15, states[-1].get_right() + RIGHT * 0.2,
phases[0].get_right() + RIGHT * 0.15, states[0].get_right() + RIGHT * 0.2,
color=GRAY, angle=PI / 2, color=LIGHT_GRAY, angle=PI / 2,
) )
loop_label = Text("Loop", font_size=12, color=GRAY).next_to(loop_arrow, RIGHT, buff=0.1) loop_lbl = Text("per token", font_size=11, color=GRAY).next_to(loop, RIGHT, buff=0.08)
self.play(Create(loop_arrow), Write(loop_label)) self.play(Create(loop), Write(loop_lbl))
self.wait(0.3)
# ── animate requests at different stages concurrently ──
colors = [YELLOW, ORANGE, PINK, TEAL, GREEN, PURPLE]
requests = []
def make_req(name, color):
dot = Dot(color=color, radius=0.13)
lbl = Text(name, font_size=14, color=color)
lbl.next_to(dot, LEFT, buff=0.15)
return VGroup(dot, lbl)
# R1 in Prefill, R2 in Decode, R3 in Refill (concurrent!)
r1 = make_req("R1", colors[0])
r1.next_to(phases[2], RIGHT, buff=1.2)
r2 = make_req("R2", colors[1])
r2.next_to(phases[3], RIGHT, buff=1.2)
r3 = make_req("R3", colors[2])
r3.next_to(phases[1], RIGHT, buff=1.2)
r4 = make_req("R4", colors[3])
r4.next_to(phases[0], RIGHT, buff=1.2)
self.play(FadeIn(r1, scale=0.7), FadeIn(r2, scale=0.7),
FadeIn(r3, scale=0.7), FadeIn(r4, scale=0.7))
self.wait(0.4) self.wait(0.4)
concurrent_note = Text("3 requests at different phases simultaneously", # ═══════════════════════════════════════════════════
font_size=15, color=WHITE).next_to(phases, DOWN, buff=0.5) # 2. Boot tokens — initial batches placed at mid-cycle
self.play(Write(concurrent_note)) # ═══════════════════════════════════════════════════
self.wait(1.2) def make_token(name: str, col: str) -> VGroup:
self.play(FadeOut(concurrent_note)) card = RoundedRectangle(width=0.65, height=0.38, corner_radius=0.08,
color=col, fill_opacity=0.35, stroke_width=1.8)
txt = Text(name, font_size=13, color=col)
return VGroup(card, txt)
# ── animate rotation through phases ── tokens = {
# R1: Prefill -> Decode, R2: Decode -> Cleanup, R3: Refill -> Prefill, R4: Cleanup -> Refill "A": make_token("A", BATCH_COLORS[0]),
self.play( "B": make_token("B", BATCH_COLORS[1]),
r1.animate.next_to(phases[3], RIGHT, buff=1.2), "C": make_token("C", BATCH_COLORS[2]),
r2.animate.next_to(phases[0], RIGHT, buff=1.2), "D": make_token("D", BATCH_COLORS[3]),
r3.animate.next_to(phases[2], RIGHT, buff=1.2), }
r4.animate.next_to(phases[1], RIGHT, buff=1.2), # assign to states
) tokens["A"].move_to(states[2]).shift(RIGHT * 1.5) # Prefill
self.wait(0.3) tokens["B"].move_to(states[3]).shift(RIGHT * 1.5) # Decode
tokens["C"].move_to(states[1]).shift(RIGHT * 1.5) # Refill
tokens["D"].move_to(states[0]).shift(RIGHT * 1.5) # Cleanup
# R2 done -> spawn R5 for t in tokens.values():
new_r5 = make_req("R5", colors[4]) self.play(FadeIn(t, scale=0.7), run_time=0.25)
new_r5.next_to(phases[1], RIGHT, buff=1.2)
self.play(FadeOut(r2), FadeIn(new_r5, scale=0.7))
r2_target = new_r5
self.wait(0.3)
# ── highlight decode: position-grouped batching ──
# gather multiple requests into decode
self.play(
r1.animate.next_to(phases[3], RIGHT, buff=1.2),
r3.animate.next_to(phases[3], RIGHT, buff=2.5),
FadeOut(r4), FadeOut(new_r5),
)
self.wait(0.2) self.wait(0.2)
ring = SurroundingRectangle(phases[3], color=YELLOW, buff=0.12) note = Text("4 batches distributed across 4 states", font_size=16, color=WHITE) \
ring_note = Text("Position-Grouped Decode\nSame pos same matmul batch", .next_to(states, DOWN, buff=0.55)
font_size=15, color=YELLOW, line_spacing=0.6) self.play(Write(note))
ring_note.next_to(ring, DOWN, buff=0.35) self.wait(1.0)
hbox = SurroundingRectangle( self.play(FadeOut(note))
VGroup(r1, r3).copy().arrange(RIGHT, buff=0.5).move_to(r1),
color=GREEN, buff=0.2, # ═══════════════════════════════════════════════════
# 3. Tick 1 — all tokens advance one state
# A: Prefill → Decode B: Decode → Cleanup
# C: Refill → Prefill D: Cleanup → Refill
# ═══════════════════════════════════════════════════
slots = [
states[0].get_center() + RIGHT * 1.5, # Cleanup
states[1].get_center() + RIGHT * 1.5, # Refill
states[2].get_center() + RIGHT * 1.5, # Prefill
states[3].get_center() + RIGHT * 1.5, # Decode
]
self.play(
tokens["A"].animate.move_to(slots[3]), # → Decode
tokens["B"].animate.move_to(slots[0]), # → Cleanup
tokens["C"].animate.move_to(slots[2]), # → Prefill
tokens["D"].animate.move_to(slots[1]), # → Refill
) )
hbox.next_to(phases[3], RIGHT, buff=1.85) self.wait(0.3)
self.play(Create(ring), Write(ring_note)) # B finished → replace with new token E
self.play(Create(hbox)) self.play(FadeOut(tokens["B"], scale=0.6))
self.wait(1.8) tokens["E"] = make_token("E", BATCH_COLORS[4])
self.play(FadeOut(ring), FadeOut(ring_note), FadeOut(hbox), tokens["E"].move_to(states[1]).shift(RIGHT * 1.5) # Refill
FadeOut(r1), FadeOut(r3), FadeOut(loop_arrow), FadeOut(loop_label)) self.play(FadeIn(tokens["E"], scale=0.7))
self.wait(0.25)
# O(1) slot allocation highlight # ═══════════════════════════════════════════════════
bitmask_box = VGroup( # 4. Tick 2 — advance again
Text("O(1) Slot Allocation via Bitmask", font_size=22, color=ORANGE), # ═══════════════════════════════════════════════════
Text("free_slots = ~occupied_mask one-bit op", font_size=16, color=GRAY), self.play(
).arrange(DOWN, buff=0.2).next_to(phases, DOWN, buff=0.8) tokens["A"].animate.move_to(slots[0]), # Decode → Cleanup
tokens["D"].animate.move_to(slots[2]), # Refill → Prefill
tokens["C"].animate.move_to(slots[3]), # Prefill → Decode
tokens["E"].animate.move_to(slots[1]), # (entered) → keeps Refill
)
self.wait(0.3)
self.play(Write(bitmask_box)) # A finished → replace with F
self.play(FadeOut(tokens["A"], scale=0.6))
tokens["F"] = make_token("F", BATCH_COLORS[5])
tokens["F"].move_to(states[1]).shift(RIGHT * 1.5)
self.play(FadeIn(tokens["F"], scale=0.7))
self.wait(0.25)
# ═══════════════════════════════════════════════════
# 5. Tick 3 — faster cycle, show pipeline never drains
# ═══════════════════════════════════════════════════
self.play(
tokens["C"].animate.move_to(slots[0]), # Decode → Cleanup
tokens["D"].animate.move_to(slots[3]), # Prefill → Decode
tokens["E"].animate.move_to(slots[2]), # Refill → Prefill
tokens["F"].animate.move_to(slots[1]), # → Refill
)
self.wait(0.25)
# C done → G enters
self.play(FadeOut(tokens["C"], scale=0.6))
tokens["G"] = make_token("G", BATCH_COLORS[6])
tokens["G"].move_to(states[1]).shift(RIGHT * 1.5)
self.play(FadeIn(tokens["G"], scale=0.7))
self.wait(0.35)
# drop note: constant throughput
flow_note = Text("Pipeline never drains — constant throughput",
font_size=15, color=GREEN).next_to(states, DOWN, buff=0.55)
self.play(Write(flow_note))
self.wait(1.5)
self.play(FadeOut(flow_note))
# clear tokens
self.play(*[FadeOut(t) for t in tokens.values()])
# ═══════════════════════════════════════════════════
# 6. Position-Grouped Decode highlight
# ═══════════════════════════════════════════════════
# show multiple tokens grouped at Decode
d_pos = states[3].get_center()
d_tokens = [
make_token("T" + str(i), BATCH_COLORS[i]) for i in range(4)
]
positions = [
d_pos + RIGHT * 1.2 + UP * 0.45,
d_pos + RIGHT * 1.2,
d_pos + RIGHT * 2.5 + UP * 0.45,
d_pos + RIGHT * 2.5,
]
for i in range(4):
d_tokens[i].move_to(positions[i])
self.play(FadeIn(d_tokens[i], scale=0.6), run_time=0.2)
ring = SurroundingRectangle(states[3], color=YELLOW, buff=0.12, stroke_width=3)
ring_txt = Text(
"Position-Grouped Batching\nSame decode position → single matmul",
font_size=14, color=YELLOW, line_spacing=0.6,
).next_to(states[3], DOWN, buff=0.5)
self.play(Create(ring), Write(ring_txt))
self.wait(2.0)
self.play(FadeOut(ring), FadeOut(ring_txt),
*[FadeOut(t) for t in d_tokens])
# ═══════════════════════════════════════════════════
# 7. O(1) Bitmask Slot Allocation
# ═══════════════════════════════════════════════════
bitmask_title = Text("O(1) Slot Allocation via Bitmask",
font_size=22, color=ORANGE).next_to(states, DOWN, buff=0.75)
bitmask_desc = Text("free_slots = ~occupied_mask (one-clock op)",
font_size=15, color=GRAY).next_to(bitmask_title, DOWN, buff=0.15)
self.play(Write(bitmask_title), Write(bitmask_desc))
self.wait(1.5)
# animate bitmask bits flipping
bits_group = VGroup()
bit_size = 0.18
for i in range(16):
square = Square(side_length=bit_size * 2, color=GRAY,
fill_opacity=0.0, stroke_width=1.2)
if i in (2, 5, 9, 13):
square.set_fill(GRAY, opacity=0.5)
bits_group.add(square)
bits_group.arrange(RIGHT, buff=0.06)
bits_group.next_to(bitmask_desc, DOWN, buff=0.3)
occupied_lbl = Text("occupied_mask", font_size=11, color=RED).next_to(bits_group, LEFT, buff=0.4)
self.play(Create(bits_group), Write(occupied_lbl))
# flip to ~occupied
flipped = VGroup()
for i, sq in enumerate(bits_group):
copy_sq = Square(side_length=bit_size * 2, color=GRAY,
fill_opacity=0.0, stroke_width=1.2).move_to(sq)
if i not in (2, 5, 9, 13):
copy_sq.set_fill(GRAY, opacity=0.5)
flipped.add(copy_sq)
free_lbl = Text("free_slots", font_size=11, color=GREEN) \
.next_to(flipped, LEFT, buff=0.4).align_to(occupied_lbl, LEFT)
self.play(Transform(bits_group, flipped),
Transform(occupied_lbl, free_lbl))
self.wait(1.2) self.wait(1.2)
self.play(FadeOut(bits_group), FadeOut(occupied_lbl),
FadeOut(bitmask_title), FadeOut(bitmask_desc))
# ── clear for throughput ── # ═══════════════════════════════════════════════════
self.play(*[FadeOut(m) for m in self.mobjects if m is not title and m is not top_bar]) # 8. Throughput comparison with animated bars
# ═══════════════════════════════════════════════════
self.play(
*[FadeOut(m) for m in self.mobjects if m is not title and m is not bar],
FadeOut(loop), FadeOut(loop_lbl),
)
for s in states:
self.play(FadeOut(s), run_time=0.15)
for a in trans_arrows:
self.play(FadeOut(a), run_time=0.15)
compare = VGroup( self.wait(0.3)
Text("Throughput Comparison", font_size=32, color=BLUE),
Text("Static Batching: 1.0x (baseline)", font_size=22, color=RED), # bars
Text("Continuous Batching: 3.4x (single GPU)", font_size=22, color=GREEN), bar_base = LEFT * 2
).arrange(DOWN, buff=0.4, aligned_edge=LEFT) bar_max_w = 4.5
self.play(Write(compare)) bar_h = 0.5
self.wait(2)
self.play(FadeOut(compare), FadeOut(title), FadeOut(top_bar)) static_bar = Rectangle(width=0, height=bar_h, color=RED,
fill_opacity=0.7, stroke_width=0).move_to(bar_base, LEFT)
static_value = always_redraw(
lambda: DecimalNumber(
static_bar.get_width() / bar_max_w * 1.0,
num_decimal_places=1, font_size=24, color=WHITE,
).next_to(static_bar, RIGHT, buff=0.2)
)
static_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.2)
static_rect.move_to(bar_base, LEFT).shift(RIGHT * bar_max_w / 2)
cb_bar = Rectangle(width=0, height=bar_h, color=GREEN,
fill_opacity=0.7, stroke_width=0).move_to(bar_base, LEFT)
cb_value = always_redraw(
lambda: DecimalNumber(
cb_bar.get_width() / bar_max_w * 3.4,
num_decimal_places=1, font_size=24, color=WHITE,
).next_to(cb_bar, RIGHT, buff=0.2)
)
cb_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.2)
cb_rect.move_to(bar_base, LEFT).shift(RIGHT * bar_max_w / 2)
static_label = Text("Static Batching ", font_size=22, color=RED) \
.next_to(static_rect, LEFT, buff=0.3)
cb_label = Text("Continuous Batching", font_size=22, color=GREEN) \
.next_to(cb_rect, LEFT, buff=0.3)
labels = VGroup(static_label, cb_label).arrange(DOWN, buff=0.7, aligned_edge=LEFT)
labels.shift(LEFT * 2)
bar_group = VGroup(
static_rect, static_bar, static_value,
cb_rect, cb_bar, cb_value,
static_label, cb_label,
)
compare_title = Text("Throughput", font_size=30, color=BLUE).next_to(bar_group, UP, buff=0.6)
self.play(Write(compare_title))
self.play(Create(static_rect), Create(cb_rect),
Write(static_label), Write(cb_label))
# grow static bar
self.play(GrowFromEdge(static_bar, LEFT),
rate_func=linear, run_time=0.6)
self.wait(0.3)
# grow cb bar significantly faster
self.play(GrowFromEdge(cb_bar, LEFT),
rate_func=linear, run_time=0.6)
self.wait(0.4)
# labels under bars
static_num = Text("1.0x", font_size=22, color=RED) \
.next_to(static_rect, RIGHT, buff=0.3)
cb_num = Text("3.4x", font_size=22, color=GREEN) \
.next_to(cb_rect, RIGHT, buff=0.3)
self.play(Write(static_num), Write(cb_num))
self.wait(2.5)
self.play(*[FadeOut(m) for m in self.mobjects])