refactor: architecture — boxes left, descriptions right, 4-layer layout

This commit is contained in:
ViperEkura 2026-05-07 12:18:33 +08:00
parent 4d96a84fc5
commit 496f964979
1 changed file with 99 additions and 89 deletions

View File

@@ -1,117 +1,127 @@
"""AstrAI promo: Full architecture overview — layer by layer introduction.""" """AstrAI promo: 4-layer architecture — boxes left, explanations right."""
from manim import * from manim import *
class Architecture(Scene): class Architecture(Scene):
"""Reveals AstrAI's 5-layer inference stack, introducing each layer.""" """Boxes on left, description text on right for each layer."""
def construct(self): def construct(self):
title = Text("AstrAI Architecture", font_size=44, color=BLUE) title = Text("AstrAI Architecture", font_size=42, color=BLUE)
title.to_edge(UP, buff=0.25)
self.play(Write(title)) self.play(Write(title))
self.play(title.animate.to_edge(UP, buff=0.3))
W = 8.0 W, BH = 5.2, 1.15
BX = -3.6
TX = 3.6
def box(h=1.05, color=GRAY, fill=0.08): def make_box(header, color, bits, src):
return Rectangle(width=W, height=h, color=color, fill_opacity=fill, stroke_width=1.5) b = Rectangle(width=W, height=BH, color=color, fill_opacity=0.1, stroke_width=1.5)
h = Text(header, font_size=16, color=color, weight=BOLD)
items = [h]
for line in bits:
items.append(Text(line, font_size=10, color=WHITE))
items.append(Text(src, font_size=9, color=GRAY))
c = VGroup(*items).arrange(DOWN, buff=0.04)
c.move_to(b.get_center())
return VGroup(b, c)
def layer_header(text, color): L1 = make_box("HTTP API Server", GREEN,
return Text(text, font_size=20, color=color, weight=BOLD) ["FastAPI · OpenAI-Compatible",
"/v1/chat/completions · SSE streaming"],
"astrai/inference/server.py")
def sub(text): L2 = make_box("Inference Engine", BLUE,
return Text(text, font_size=12, color=WHITE) ["generate() · batch mode · streaming",
"4-phase daemon: Cleanup → Refill → Prefill → Decode",
"Position-grouped decode · Bitmask O(1) slots"],
"astrai/inference/engine.py · scheduler.py")
def intro(title, detail, color, oneline=None): L3 = make_box("Prefix Cache + KV Cache", ORANGE,
"""Animate a layer: box + title → details → brief pause.""" ["Radix Tree prefix matching · LRU eviction",
b = box(color=color, fill=0.1) "Slot versioning · GPU copy_() zero-copy reuse"],
content = VGroup(title) "astrai/inference/scheduler.py")
if oneline:
content.add(oneline) L4 = make_box("Transformer Model", PURPLE,
if detail: ["24× DecoderBlock · GQA 6:1 · RoPE",
items = [title] "SwiGLU MLP · Dim 1536 · bfloat16"],
if oneline: "astrai/model/transformer.py")
items.append(oneline)
items.extend(detail) layers = VGroup(L1, L2, L3, L4)
content = VGroup(*items) layers.arrange(DOWN, buff=0.08)
else: layers.move_to([BX, 0, 0])
content = VGroup(title) if not oneline else VGroup(title, oneline) layers.next_to(title, DOWN, buff=0.25)
content.arrange(DOWN, buff=0.15)
content.move_to(b.get_center()) # Description panels (right side)
grp = VGroup(b, content) descs_text = [
["HTTP API Server",
"Receives chat requests via",
"OpenAI-compatible endpoints.",
"Streams generated tokens back",
"through Server-Sent Events."],
["Inference Engine",
"Orchestrates the full generation",
"pipeline with a background daemon.",
"4-phase loop: Cleanup tasks,",
"Refill batch, Prefill prompts,",
"Decode tokens one by one."],
["Prefix Cache + KV Cache",
"Caches key-value states using",
"a Radix Tree for O(n) prefix lookup.",
"Reuses matched prefixes via GPU",
"memcpy — zero recomputation."],
["Transformer Model (1B params)",
"Decoder-only Transformer with",
"Grouped-Query Attention (GQA 6:1).",
"RoPE rotary encoding, SwiGLU",
"activation, 100K vocabulary."],
]
def make_desc(lines, color):
els = [Text(lines[0], font_size=20, color=color, weight=BOLD)]
for ln in lines[1:]:
els.append(Text(ln, font_size=14, color=WHITE))
grp = VGroup(*els).arrange(DOWN, buff=0.1, aligned_edge=LEFT)
return grp return grp
layers = [] COLORS = [GREEN, BLUE, ORANGE, PURPLE]
descs = [make_desc(lns, c) for lns, c in zip(descs_text, COLORS)]
# ── Layer 1: API Server ──
l1_t = layer_header("HTTP API Server", GREEN)
l1_d = [sub("FastAPI • OpenAI-Compatible /v1/chat/completions"),
sub("Streaming SSE • Async • Health/Stats Endpoints")]
l1 = intro(l1_t, l1_d, GREEN, sub("astrai/inference/server.py"))
l1.next_to(title, DOWN, buff=0.35)
layers.append(l1)
# ── Layer 2: Inference Engine ──
l2_t = layer_header("InferenceEngine", BLUE)
l2_d = [sub("generate() · generate_async() · generate_with_request()"),
sub("Batch mode · Streaming (Generator) · Thread-safe accumulator")]
l2 = intro(l2_t, l2_d, BLUE, sub("astrai/inference/engine.py"))
l2.next_to(l1, DOWN, buff=0.12)
layers.append(l2)
# ── Layer 3: Continuous Batching Scheduler ──
l3_t = layer_header("InferenceScheduler (Background Daemon)", YELLOW)
l3_d = [sub("Cleanup → Refill → Prefill → Decode · 4-phase loop"),
sub("Position-Grouped Decode · Bitmask O(1) Slot Allocation")]
l3 = intro(l3_t, l3_d, YELLOW, sub("astrai/inference/scheduler.py"))
l3.next_to(l2, DOWN, buff=0.12)
layers.append(l3)
# ── Layer 4: Prefix Cache + KV Cache ──
l4_t = layer_header("PrefixCacheManager + KV Cache", ORANGE)
l4_d = [sub("Radix Tree prefix matching · LRU eviction · Slot versioning"),
sub("GPU copy_() → Zero-Copy Reuse · k_cache / v_cache tensors")]
l4 = intro(l4_t, l4_d, ORANGE, sub("astrai/inference/scheduler.py"))
l4.next_to(l3, DOWN, buff=0.12)
layers.append(l4)
# ── Layer 5: Transformer Model ──
l5_t = layer_header("Transformer (1B params)", PURPLE)
l5_d = [sub("24× DecoderBlock · GQA 6:1 · RoPE · SwiGLU MLP"),
sub("Dim 1536 · Max Length 2048 · bfloat16 · 100K vocab")]
l5 = intro(l5_t, l5_d, PURPLE, sub("astrai/model/transformer.py"))
l5.next_to(l4, DOWN, buff=0.12)
layers.append(l5)
# ── Animate layer by layer ──
arrows = VGroup() arrows = VGroup()
for i, layer in enumerate(layers): for i, (layer, desc) in enumerate(zip(layers, descs)):
self.play(Create(layer), run_time=0.4) b = layer[0]
self.wait(1.0 if i < 2 else 0.8) self.play(Create(layer), run_time=0.35)
if i > 0: desc.next_to(b, RIGHT, buff=1.0)
prev = layers[i - 1][0] desc.align_to(b, UP)
curr = layer[0] self.play(Write(desc), run_time=0.3)
self.wait(2.0 if i == 0 else 1.8)
if i < len(layers) - 1:
self.play(FadeOut(desc))
nxt = layers[i + 1][0]
arrow = Arrow( arrow = Arrow(
prev.get_bottom(), curr.get_top(), b.get_bottom(), nxt.get_top(),
color=GRAY, buff=0.06, color=GRAY, buff=0.04,
max_tip_length_to_length_ratio=0.18, max_tip_length_to_length_ratio=0.18,
) )
self.play(Create(arrow), run_time=0.15) self.play(Create(arrow), run_time=0.12)
arrows.add(arrow) arrows.add(arrow)
else:
self.wait(0.5)
self.play(FadeOut(desc))
self.wait(0.6) # Show all boxes + arrows together briefly
self.wait(0.3)
# ── Highlight: the innovation layers ── # Highlight innovation layers
hl3 = SurroundingRectangle(layers[2], color=YELLOW, buff=0.1, stroke_width=2) hl2 = SurroundingRectangle(L2, color=BLUE, buff=0.1, stroke_width=2)
hl4 = SurroundingRectangle(layers[3], color=ORANGE, buff=0.1, stroke_width=2) hl3 = SurroundingRectangle(L3, color=ORANGE, buff=0.1, stroke_width=2)
hl_note = Text("Key Innovations: Continuous Batching + Prefix Cache", hl_note = Text("Key Innovations", font_size=20, color=GOLD)
font_size=18, color=GOLD) hl_note.next_to(VGroup(hl2, hl3), RIGHT, buff=1.5)
hl_note.next_to(VGroup(hl3, hl4), LEFT, buff=0.5) hl_note.align_to(hl2, UP)
self.play(Create(hl3), Create(hl4), Write(hl_note)) self.play(Create(hl2), Create(hl3), Write(hl_note))
self.wait(2.0) self.wait(2.0)
self.play(FadeOut(hl3), FadeOut(hl4), FadeOut(hl_note)) self.play(FadeOut(hl2), FadeOut(hl3), FadeOut(hl_note))
# ── Fade to CTA ──
self.play(FadeOut(VGroup(*layers)), FadeOut(arrows)) self.play(FadeOut(VGroup(*layers)), FadeOut(arrows))
cta = VGroup( cta = VGroup(