"""AstrAI promo: Continuous Batching.

A single Manim scene (``ContinuousBatching``) that walks through:

* the request lifecycle FSM (Refill → Prefill → Decode → Cleanup), drawn as
  PENDING / PREFILL / DECODE / FINISHED lanes with a snapshot of requests A–G;
* position-grouped decode batching;
* O(1) slot allocation with a bitmask (``free_slots = ~occupied_mask``);
* a Gantt timeline contrasting static and continuous batching;
* an animated throughput comparison.
"""
|
|
|
|
from manim import *
|
|
|
|
# Use one serif face as the default for every Text mobject in the scene.
Text.set_default(font="Times New Roman")

# Colour associated with each FSM state name.
PAL = dict(
    Cleanup=GRAY,
    Refill=ORANGE,
    Prefill=BLUE,
    Decode=YELLOW,
)

# Per-request accent colours; request cards pick one by index.
BATCH_COLORS = [
    YELLOW, ORANGE, PINK, TEAL,
    GREEN, PURPLE, GOLD, MAROON,
]
|
|
|
|
|
|
class ContinuousBatching(Scene):
    """Manim promo scene explaining continuous batching.

    Sections, played in order by :meth:`construct`:

    0. Title card with an underline rule.
    1. Request pipeline: an FSM caption (Refill → Prefill → Decode → Cleanup)
       above PENDING / PREFILL / DECODE / FINISHED lanes, plus a snapshot of
       request cards A–G.
    2. Highlight of the DECODE lane for position-grouped batching.
    3. O(1) slot allocation by inverting an occupancy bitmask.
    4. Gantt timelines: static batching (idle gaps) vs. continuous batching.
    5. Animated throughput bars (static = 1.0x, continuous = ``SPEEDUP``x).
    """

    # Throughput multiple of continuous over static batching shown in section 5.
    # Both the bar length ratio and the "N.Nx" label are derived from it.
    SPEEDUP = 3.4

    def construct(self):
        """Play all sections in order, then fade out whatever remains."""
        title, bar = self._show_title()
        dec_lane, lane_group = self._show_pipeline(bar)
        self._show_decode_highlight(dec_lane)
        self._show_bitmask(lane_group)
        self._show_gantt(title, bar)
        self._show_throughput()
        self.play(*[FadeOut(m) for m in self.mobjects])

    # ── 0. Title ──────────────────────────────────────────────
    def _show_title(self):
        """Write the title, dock it at the top edge and draw a rule under it.

        Returns:
            ``(title, bar)``; both stay on screen until the Gantt section ends.
        """
        title = Text("Continuous Batching", font_size=48, color=BLUE)
        self.play(Write(title))
        self.wait(0.4)
        self.play(title.animate.to_edge(UP).scale(0.55))
        bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN, buff=0.15)
        self.play(Create(bar))
        return title, bar

    # ── 1. Request pipeline ───────────────────────────────────
    def _show_pipeline(self, bar):
        """Draw the FSM caption, the four state lanes, their arrows and request cards.

        The request cards fade in one by one, hold, then fade out; the caption,
        lanes and arrows stay on screen.

        Args:
            bar: The title rule; the FSM caption is placed just below it.

        Returns:
            ``(dec_lane, lane_group)``: the DECODE lane and a VGroup of all four
            lanes, used as layout anchors by later sections.
        """
        LANE_W, LANE_H = 2.6, 0.95
        X_P, X_Pr, X_D, X_F = -4.95, -1.65, 1.65, 4.95  # lane centre x
        YL = 0.3                                          # lane centre y
        P_CLR, F_CLR = GRAY, RED

        def lane(x, label, clr, sub):
            """Return a rounded lane box with a title and caption, centred at (x, YL)."""
            box = RoundedRectangle(
                width=LANE_W, height=LANE_H, corner_radius=0.12,
                color=clr, fill_opacity=0.10, stroke_width=2.2,
            )
            t = Text(label, font_size=20, color=clr)
            s = Text(sub, font_size=10, color=LIGHT_GRAY)
            inner = VGroup(t, s).arrange(DOWN, buff=0.04).move_to(box)
            return VGroup(box, inner).move_to([x, YL, 0])

        # FSM caption "Refill → Prefill → Decode → Cleanup", coloured via PAL.
        fsm_states = VGroup()
        for i, state in enumerate(("Refill", "Prefill", "Decode", "Cleanup")):
            if i:  # separator arrow between consecutive states
                fsm_states.add(Text("→", font_size=13, color=LIGHT_GRAY))
            fsm_states.add(Text(state, font_size=13, color=PAL[state]))
        fsm_states.arrange(RIGHT, buff=0.06)
        fsm_states.next_to(bar, DOWN, buff=0.3)

        pend_lane = lane(X_P, "PENDING", P_CLR, "waiting queue")
        pref_lane = lane(X_Pr, "PREFILL", BLUE, "first token")
        dec_lane = lane(X_D, "DECODE", YELLOW, "per-token gen")
        fin_lane = lane(X_F, "FINISHED", F_CLR, "sequence done")
        lane_group = VGroup(pend_lane, pref_lane, dec_lane, fin_lane)

        # Entry arrow into PENDING.
        ea = Arrow(pend_lane.get_left() + LEFT * 0.9, pend_lane.get_left(),
                   color=GREEN, stroke_width=2.5,
                   max_tip_length_to_length_ratio=0.15)
        el = Text("New Req", font_size=11, color=GREEN).next_to(ea, UP, buff=0.04)

        # PENDING → PREFILL (Refill).
        ra = Arrow(pend_lane.get_right(), pref_lane.get_left(),
                   color=ORANGE, buff=0.06,
                   max_tip_length_to_length_ratio=0.15)
        rl = Text("Refill", font_size=10, color=ORANGE).next_to(ra, UP, buff=0.04)

        # PREFILL → DECODE (unlabelled).
        ta = Arrow(pref_lane.get_right(), dec_lane.get_left(),
                   color=LIGHT_GRAY, buff=0.06,
                   max_tip_length_to_length_ratio=0.15)

        # DECODE → FINISHED (Cleanup).
        ca = Arrow(dec_lane.get_right(), fin_lane.get_left(),
                   color=GRAY, buff=0.06,
                   max_tip_length_to_length_ratio=0.15)
        cl = Text("Cleanup", font_size=10, color=GRAY).next_to(ca, UP, buff=0.04)

        # Exit arrow out of FINISHED.
        xa = Arrow(fin_lane.get_right(), fin_lane.get_right() + RIGHT * 0.9,
                   color=RED, stroke_width=2.5,
                   max_tip_length_to_length_ratio=0.15)
        xl = Text("Exit", font_size=11, color=RED).next_to(xa, UP, buff=0.04)

        self.play(Write(fsm_states))
        self.play(Create(pend_lane), Create(pref_lane), Create(dec_lane), Create(fin_lane))
        self.wait(0.3)
        self.play(Create(ea), Write(el),
                  Create(ra), Write(rl),
                  Create(ta),
                  Create(ca), Write(cl),
                  Create(xa), Write(xl))
        self.wait(0.5)

        # ── Request cards: one snapshot of requests spread across the states ──
        TOK_W, TOK_H = 0.58, 0.38
        TOK_Y = -1.5  # cards sit in one row below the lanes

        def mk_tok(name, col, state, n_tok):
            """Return a request card: a coloured chip with its name over a 'STATE Nt' caption."""
            card = RoundedRectangle(
                width=TOK_W, height=TOK_H, corner_radius=0.06,
                color=col, fill_opacity=0.38, stroke_width=1.6,
            )
            t = Text(name, font_size=13, color=col).move_to(card)
            info = Text(f"{state} {n_tok}t", font_size=7, color=col)
            return VGroup(VGroup(card, t), info).arrange(DOWN, buff=0.03)

        def slots(x, n):
            """Return n evenly spaced card centres spanning 72% of a lane centred at x."""
            sp = LANE_W * 0.72 / max(n, 1)
            sx = x - (n - 1) * sp / 2
            return [np.array([sx + i * sp, TOK_Y, 0]) for i in range(n)]

        P_SLOTS = slots(X_P, 2)    # G, F
        Pr_SLOTS = slots(X_Pr, 1)  # E
        D_SLOTS = slots(X_D, 3)    # D, A, B
        F_SLOTS = slots(X_F, 1)    # C

        # (name, BATCH_COLORS index, card centre, state caption, token count)
        tok_specs = [
            ("G", 6, P_SLOTS[0], "PENDING", 0),
            ("F", 5, P_SLOTS[1], "PENDING", 0),
            ("E", 4, Pr_SLOTS[0], "PREFILL", 128),
            ("D", 3, D_SLOTS[0], "DECODE", 5),
            ("A", 0, D_SLOTS[1], "DECODE", 9),
            ("B", 1, D_SLOTS[2], "DECODE", 13),
            ("C", 2, F_SLOTS[0], "FINISHED", 16),
        ]
        cards = [
            mk_tok(name, BATCH_COLORS[ci], state, n).move_to(pos)
            for name, ci, pos, state, n in tok_specs
        ]
        for card in cards:
            self.play(FadeIn(card, scale=0.7), run_time=0.18)
        self.wait(2.0)
        self.play(*[FadeOut(card) for card in cards])

        return dec_lane, lane_group

    # ── 2. Position-grouped decode highlight ──────────────────
    def _show_decode_highlight(self, dec_lane):
        """Frame the DECODE lane and caption it with the position-grouping idea."""
        ring = SurroundingRectangle(dec_lane, color=YELLOW, buff=0.12, stroke_width=3)
        ring_txt = Text(
            "Position-Grouped Batching\nSame decode position → single matmul",
            font_size=14, color=YELLOW, line_spacing=0.6,
        ).next_to(dec_lane, DOWN, buff=0.5)
        self.play(Create(ring), Write(ring_txt))
        self.wait(2.0)
        self.play(FadeOut(ring), FadeOut(ring_txt))

    # ── 3. O(1) bitmask slot allocation ───────────────────────
    def _show_bitmask(self, lane_group):
        """Show a 16-cell occupancy mask turning into its bitwise complement.

        Args:
            lane_group: The pipeline lanes; the explanation is laid out below them.
        """
        bitmask_title = Text(
            "O(1) Slot Allocation via Bitmask",
            font_size=22, color=ORANGE,
        ).next_to(lane_group, DOWN, buff=0.75)
        bitmask_desc = Text(
            "free_slots = ~occupied_mask (one-clock op)",
            font_size=15, color=GRAY,
        ).next_to(bitmask_title, DOWN, buff=0.15)
        self.play(Write(bitmask_title), Write(bitmask_desc))
        self.wait(1.5)

        n_bits = 16
        occupied = (2, 5, 9, 13)  # indices of set bits in occupied_mask
        bit_size = 0.18

        def bit_cell(filled):
            """Return one mask cell; a set bit is drawn with a grey fill."""
            square = Square(
                side_length=bit_size * 2, color=GRAY,
                fill_opacity=0.0, stroke_width=1.2,
            )
            if filled:
                square.set_fill(GRAY, opacity=0.5)
            return square

        bits_group = VGroup(*[bit_cell(i in occupied) for i in range(n_bits)])
        bits_group.arrange(RIGHT, buff=0.06)
        bits_group.next_to(bitmask_desc, DOWN, buff=0.3)

        occupied_lbl = (
            Text("occupied_mask", font_size=11, color=RED)
            .next_to(bits_group, LEFT, buff=0.4)
        )
        self.play(Create(bits_group), Write(occupied_lbl))

        # Bitwise NOT: every cell flips, so the free slots become the filled ones.
        flipped = VGroup(*[
            bit_cell(i not in occupied).move_to(sq)
            for i, sq in enumerate(bits_group)
        ])
        free_lbl = (
            Text("free_slots", font_size=11, color=GREEN)
            .next_to(flipped, LEFT, buff=0.4)
            .align_to(occupied_lbl, LEFT)
        )

        # Transform mutates bits_group / occupied_lbl in place, so those are
        # the objects to fade out afterwards.
        self.play(Transform(bits_group, flipped),
                  Transform(occupied_lbl, free_lbl))
        self.wait(1.2)
        self.play(FadeOut(bits_group), FadeOut(occupied_lbl),
                  FadeOut(bitmask_title), FadeOut(bitmask_desc))

    # ── 4. Gantt timeline: static vs continuous ───────────────
    def _show_gantt(self, title, bar):
        """Compare static and continuous batching on side-by-side Gantt charts.

        First clears every mobject except ``title`` and ``bar``. At the end it
        fades out everything it drew together with ``title`` and ``bar``.
        """
        self.play(
            *[FadeOut(m) for m in self.mobjects if m is not title and m is not bar],
        )
        self.wait(0.2)

        CELL = 0.44          # width of one time tick
        BH = 0.32            # bar height
        BGAP = 0.10          # vertical gap between rows
        ROW = BH + BGAP      # row pitch
        TICKS = 12
        PANEL_W = TICKS * CELL
        L_OX = -5.8          # x of t=0 for the static panel
        R_OX = 1.0           # x of t=0 for the continuous panel
        GY = 2.0             # y of both time axes
        NOTE_Y = -1.6        # y of the per-panel notes
        # Summary lines get their own row under the notes so they do not
        # cover the request rows (which extend down to about y = -1.1).
        COUNT_Y = -2.1

        def gbox(ox, y, start, span, color, fill=0.75):
            """Return a borderless bar covering ticks [start, start + span) on row y."""
            x = ox + start * CELL
            w = span * CELL
            return Rectangle(
                width=w, height=BH, color=color,
                fill_opacity=fill, stroke_width=0,
            ).move_to([x + w / 2, y, 0])

        def batch_box(ox, y_gpu, y_last_req, start, span, color, label_txt):
            """Return (outline, label) framing rows y_gpu..y_last_req over the given ticks."""
            w = span * CELL
            top = y_gpu + BH / 2 + 0.06
            bot = y_last_req - BH / 2 - 0.06
            h = top - bot
            cx = ox + (start + span / 2) * CELL
            cy = (top + bot) / 2
            rect = Rectangle(
                width=w, height=h, color=color,
                stroke_width=1.8, fill_opacity=0.04,
            )
            rect.move_to([cx, cy, 0])
            lbl = Text(label_txt, font_size=12, color=color).next_to(rect, UP, buff=0.06)
            return rect, lbl

        def taxis(ox, ty):
            """Return a time axis at height ty: line, a tick per step, numbers every 3 ticks."""
            line = Line(
                [ox, ty, 0], [ox + PANEL_W, ty, 0],
                color=GRAY, stroke_width=1.2,
            )
            ticks_vg = VGroup()
            for t in range(TICKS + 1):
                ti = Line(DOWN * 0.06, UP * 0.06, color=GRAY, stroke_width=0.8)
                ti.move_to([ox + t * CELL, ty, 0])
                ticks_vg.add(ti)
            nums_vg = VGroup()
            for t in range(0, TICKS + 1, 3):
                n = Text(str(t), font_size=11, color=GRAY).next_to(
                    [ox + t * CELL, ty, 0], DOWN, buff=0.10,
                )
                nums_vg.add(n)
            return VGroup(line, ticks_vg, nums_vg)

        # ── Left panel: static batching ──
        s_title = Text("Static Batching", font_size=26, color=RED)
        s_title.move_to([L_OX + PANEL_W / 2, GY + 0.65, 0])
        s_note = Text(
            "requests wait → batch together → all run same length · GPU idle gaps",
            font_size=13, color=RED,
        ).move_to([L_OX + PANEL_W / 2, NOTE_Y, 0])
        self.play(Write(s_title))
        self.wait(0.25)

        st_axis = taxis(L_OX, GY)
        self.play(Create(st_axis))

        s_y_gpu = GY - ROW
        gpu_l = Text("GPU", font_size=14, color=WHITE)
        gpu_l.move_to([L_OX - 0.55, s_y_gpu, 0])
        self.play(Write(gpu_l))

        # GPU row: idle while requests queue, busy while a batch runs.
        s_gpu_bars = [
            gbox(L_OX, s_y_gpu, 0, 2, RED, 0.45),   # idle
            gbox(L_OX, s_y_gpu, 2, 4, GREEN),       # batch 1
            gbox(L_OX, s_y_gpu, 6, 2, RED, 0.45),   # idle
            gbox(L_OX, s_y_gpu, 8, 4, GREEN),       # batch 2
        ]
        for seg in s_gpu_bars:
            self.play(GrowFromEdge(seg, LEFT), run_time=0.09)

        s_idle1 = Text("IDLE", font_size=10, color=RED).move_to([L_OX + 1 * CELL, s_y_gpu, 0])
        s_idle2 = Text("IDLE", font_size=10, color=RED).move_to([L_OX + 7 * CELL, s_y_gpu, 0])
        self.play(Write(s_idle1), Write(s_idle2))

        # (name, colour, wait start, wait end, run start, run end) in ticks.
        s_req_defs = [
            ("A", ORANGE, 0, 2, 2, 6),
            ("B", BLUE, 1, 2, 2, 6),
            ("C", PINK, 2, 2, 2, 6),
            ("D", ORANGE, 4, 8, 8, 12),
            ("E", BLUE, 6, 8, 8, 12),
            ("F", PINK, 8, 8, 8, 12),
        ]
        s_bars = []
        for i, (name, col, wait_start, wait_end, run_start, run_end) in enumerate(s_req_defs):
            y = s_y_gpu - (i + 1) * ROW
            lbl = Text(f"Req {name}", font_size=12, color=col)
            lbl.move_to([L_OX - 0.55, y, 0])
            items = [lbl]
            anims = [FadeIn(lbl)]
            if wait_end - wait_start > 0.02:  # skip zero-length waits
                wbar = gbox(L_OX, y, wait_start, wait_end - wait_start, GRAY, 0.28)
                items.append(wbar)
                anims.append(GrowFromEdge(wbar, LEFT))
            rbar = gbox(L_OX, y, run_start, run_end - run_start, col, 0.60)
            items.append(rbar)
            anims.append(GrowFromEdge(rbar, LEFT))
            s_bars.extend(items)
            self.play(*anims, run_time=0.09)

        # Batch 1 frame reaches the 3rd request row, batch 2 frame the 6th.
        s_y_last3 = s_y_gpu - 3 * ROW
        s_y_last6 = s_y_gpu - 6 * ROW
        b1_rect, b1_lbl = batch_box(L_OX, s_y_gpu, s_y_last3, 2, 4, RED, "Batch 1")
        b2_rect, b2_lbl = batch_box(L_OX, s_y_gpu, s_y_last6, 8, 4, RED, "Batch 2")
        self.play(Create(b1_rect), Write(b1_lbl))
        self.play(Create(b2_rect), Write(b2_lbl))
        self.wait(0.8)

        # ── Right panel: continuous batching ──
        c_title = Text("Continuous Batching", font_size=26, color=GREEN)
        c_title.move_to([R_OX + PANEL_W / 2, GY + 0.65, 0])
        c_note = Text(
            "no waiting · no padding · GPU never idle",
            font_size=13, color=GREEN,
        ).move_to([R_OX + PANEL_W / 2, NOTE_Y, 0])
        self.play(Write(c_title))
        self.wait(0.25)

        ct_axis = taxis(R_OX, GY)
        self.play(Create(ct_axis))

        c_y_gpu = GY - ROW
        cgpu_l = Text("GPU", font_size=14, color=WHITE)
        cgpu_l.move_to([R_OX - 0.55, c_y_gpu, 0])
        self.play(Write(cgpu_l))

        # The GPU bar spans the whole window: it is never idle.
        c_gpu = gbox(R_OX, c_y_gpu, 0, TICKS, GREEN, 0.75)
        self.play(GrowFromEdge(c_gpu, LEFT), run_time=0.5)

        # (name, colour, start tick, span): each request runs as soon as it arrives.
        c_reqs = [
            ("A", ORANGE, 0, 4),
            ("B", BLUE, 1, 4),
            ("C", PINK, 2, 4),
            ("D", ORANGE, 4, 4),
            ("E", BLUE, 6, 4),
            ("F", PINK, 8, 4),
        ]
        c_bars = []
        for i, (name, col, start, span) in enumerate(c_reqs):
            y = c_y_gpu - (i + 1) * ROW
            lbl = Text(f"Req {name}", font_size=12, color=col)
            lbl.move_to([R_OX - 0.55, y, 0])
            bar_rect = gbox(R_OX, y, start, span, col, 0.60)
            c_bars.extend([lbl, bar_rect])
            self.play(FadeIn(lbl), GrowFromEdge(bar_rect, LEFT), run_time=0.09)
        self.wait(0.3)

        c_y_last = c_y_gpu - len(c_reqs) * ROW
        c_box_rect, c_box_lbl = batch_box(
            R_OX, c_y_gpu, c_y_last, 0, TICKS, GREEN, "Always Serving",
        )
        self.play(Create(c_box_rect), Write(c_box_lbl))
        self.wait(1.0)

        # ── Summary lines, centred under each panel below its note ──
        s_count = Text(
            "6 reqs · 2 batches · GPU idle gaps", font_size=16, color=RED,
        ).move_to([L_OX + PANEL_W / 2, COUNT_Y, 0])
        c_count = Text(
            "6 reqs · continuous · GPU never idle", font_size=16, color=GREEN,
        ).move_to([R_OX + PANEL_W / 2, COUNT_Y, 0])
        self.play(Write(s_note), Write(c_note))
        self.wait(0.3)
        self.play(Write(s_count), Write(c_count))
        self.wait(2.5)
        self.play(FadeOut(s_count), FadeOut(c_count))

        gantt_mobs = [
            title, bar, s_title, s_note, c_title, c_note,
            gpu_l, cgpu_l, s_idle1, s_idle2, st_axis, ct_axis,
            *s_gpu_bars, c_gpu, *s_bars, *c_bars,
            b1_rect, b1_lbl, b2_rect, b2_lbl, c_box_rect, c_box_lbl,
        ]
        self.play(*[FadeOut(m) for m in gantt_mobs])
        self.wait(0.2)

    # ── 5. Throughput comparison ──────────────────────────────
    def _show_throughput(self):
        """Grow two bars comparing static (1.0x) and continuous (SPEEDUP x) throughput.

        Each bar is drawn inside a full-width outline. The continuous bar fills
        its outline; the static bar fills ``1 / SPEEDUP`` of it. Everything drawn
        here is left on screen for :meth:`construct` to clear.
        """
        compare_title = Text("Throughput Comparison", font_size=30, color=BLUE)
        self.play(Write(compare_title))
        self.wait(0.2)
        self.play(compare_title.animate.to_edge(UP).scale(0.55))
        self.wait(0.2)

        bar_max_w = 5.0
        bar_h = 0.55
        row_gap = 0.8
        ratio = 1.0 / self.SPEEDUP  # static throughput relative to continuous

        s_label = Text("Static Batching", font_size=24, color=RED)
        s_rect = Rectangle(width=bar_max_w, height=bar_h, color=RED, stroke_width=1.5)
        s_bar_rect = Rectangle(
            width=bar_max_w * ratio, height=bar_h,
            color=RED, fill_opacity=0.55, stroke_width=0,
        )
        s_num = Text("1.0x", font_size=24, color=RED)

        c_label = Text("Continuous Batching", font_size=24, color=GREEN)
        c_rect = Rectangle(width=bar_max_w, height=bar_h, color=GREEN, stroke_width=1.5)
        c_bar_rect = Rectangle(
            width=bar_max_w, height=bar_h,
            color=GREEN, fill_opacity=0.55, stroke_width=0,
        )
        c_num = Text(f"{self.SPEEDUP:.1f}x", font_size=24, color=GREEN)

        # Outlines sit symmetrically above/below the origin; fills hug their
        # outline's top-left corner so they grow from the left edge.
        s_rect.move_to(ORIGIN + UP * (row_gap / 2 + bar_h / 2))
        c_rect.move_to(ORIGIN + DOWN * (row_gap / 2 + bar_h / 2))
        s_bar_rect.align_to(s_rect, LEFT).align_to(s_rect, UP)
        c_bar_rect.align_to(c_rect, LEFT).align_to(c_rect, UP)

        s_label.next_to(s_rect, LEFT, buff=0.4)
        c_label.next_to(c_rect, LEFT, buff=0.4)
        s_num.next_to(s_rect, RIGHT, buff=0.4)
        c_num.next_to(c_rect, RIGHT, buff=0.4)

        self.play(
            Create(s_rect), Create(c_rect),
            Write(s_label), Write(c_label),
        )
        self.wait(0.3)

        self.play(GrowFromEdge(s_bar_rect, LEFT), rate_func=linear, run_time=0.6)
        self.wait(0.3)
        self.play(GrowFromEdge(c_bar_rect, LEFT), rate_func=linear, run_time=0.6)
        self.wait(0.3)

        self.play(Write(s_num), Write(c_num))
        self.wait(2.5)
|