refactor: architecture — boxes left, descriptions right, 4-layer layout
This commit is contained in:
parent
4d96a84fc5
commit
496f964979
188
architecture.py
188
architecture.py
|
|
@ -1,117 +1,127 @@
|
||||||
"""AstrAI promo: Full architecture overview — layer by layer introduction."""
|
"""AstrAI promo: 4-layer architecture — boxes left, explanations right."""
|
||||||
|
|
||||||
from manim import *
|
from manim import *
|
||||||
|
|
||||||
|
|
||||||
class Architecture(Scene):
|
class Architecture(Scene):
|
||||||
"""Reveals AstrAI's 5-layer inference stack, introducing each layer."""
|
"""Boxes on left, description text on right for each layer."""
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
title = Text("AstrAI Architecture", font_size=44, color=BLUE)
|
title = Text("AstrAI Architecture", font_size=42, color=BLUE)
|
||||||
|
title.to_edge(UP, buff=0.25)
|
||||||
self.play(Write(title))
|
self.play(Write(title))
|
||||||
self.play(title.animate.to_edge(UP, buff=0.3))
|
|
||||||
|
|
||||||
W = 8.0
|
W, BH = 5.2, 1.15
|
||||||
|
BX = -3.6
|
||||||
|
TX = 3.6
|
||||||
|
|
||||||
def box(h=1.05, color=GRAY, fill=0.08):
|
def make_box(header, color, bits, src):
|
||||||
return Rectangle(width=W, height=h, color=color, fill_opacity=fill, stroke_width=1.5)
|
b = Rectangle(width=W, height=BH, color=color, fill_opacity=0.1, stroke_width=1.5)
|
||||||
|
h = Text(header, font_size=16, color=color, weight=BOLD)
|
||||||
|
items = [h]
|
||||||
|
for line in bits:
|
||||||
|
items.append(Text(line, font_size=10, color=WHITE))
|
||||||
|
items.append(Text(src, font_size=9, color=GRAY))
|
||||||
|
c = VGroup(*items).arrange(DOWN, buff=0.04)
|
||||||
|
c.move_to(b.get_center())
|
||||||
|
return VGroup(b, c)
|
||||||
|
|
||||||
def layer_header(text, color):
|
L1 = make_box("HTTP API Server", GREEN,
|
||||||
return Text(text, font_size=20, color=color, weight=BOLD)
|
["FastAPI · OpenAI-Compatible",
|
||||||
|
"/v1/chat/completions · SSE streaming"],
|
||||||
|
"astrai/inference/server.py")
|
||||||
|
|
||||||
def sub(text):
|
L2 = make_box("Inference Engine", BLUE,
|
||||||
return Text(text, font_size=12, color=WHITE)
|
["generate() · batch mode · streaming",
|
||||||
|
"4-phase daemon: Cleanup → Refill → Prefill → Decode",
|
||||||
|
"Position-grouped decode · Bitmask O(1) slots"],
|
||||||
|
"astrai/inference/engine.py · scheduler.py")
|
||||||
|
|
||||||
def intro(title, detail, color, oneline=None):
|
L3 = make_box("Prefix Cache + KV Cache", ORANGE,
|
||||||
"""Animate a layer: box + title → details → brief pause."""
|
["Radix Tree prefix matching · LRU eviction",
|
||||||
b = box(color=color, fill=0.1)
|
"Slot versioning · GPU copy_() zero-copy reuse"],
|
||||||
content = VGroup(title)
|
"astrai/inference/scheduler.py")
|
||||||
if oneline:
|
|
||||||
content.add(oneline)
|
L4 = make_box("Transformer Model", PURPLE,
|
||||||
if detail:
|
["24× DecoderBlock · GQA 6:1 · RoPE",
|
||||||
items = [title]
|
"SwiGLU MLP · Dim 1536 · bfloat16"],
|
||||||
if oneline:
|
"astrai/model/transformer.py")
|
||||||
items.append(oneline)
|
|
||||||
items.extend(detail)
|
layers = VGroup(L1, L2, L3, L4)
|
||||||
content = VGroup(*items)
|
layers.arrange(DOWN, buff=0.08)
|
||||||
else:
|
layers.move_to([BX, 0, 0])
|
||||||
content = VGroup(title) if not oneline else VGroup(title, oneline)
|
layers.next_to(title, DOWN, buff=0.25)
|
||||||
content.arrange(DOWN, buff=0.15)
|
|
||||||
content.move_to(b.get_center())
|
# Description panels (right side)
|
||||||
grp = VGroup(b, content)
|
descs_text = [
|
||||||
|
["HTTP API Server",
|
||||||
|
"Receives chat requests via",
|
||||||
|
"OpenAI-compatible endpoints.",
|
||||||
|
"Streams generated tokens back",
|
||||||
|
"through Server-Sent Events."],
|
||||||
|
["Inference Engine",
|
||||||
|
"Orchestrates the full generation",
|
||||||
|
"pipeline with a background daemon.",
|
||||||
|
"4-phase loop: Cleanup tasks,",
|
||||||
|
"Refill batch, Prefill prompts,",
|
||||||
|
"Decode tokens one by one."],
|
||||||
|
["Prefix Cache + KV Cache",
|
||||||
|
"Caches key-value states using",
|
||||||
|
"a Radix Tree for O(n) prefix lookup.",
|
||||||
|
"Reuses matched prefixes via GPU",
|
||||||
|
"memcpy — zero recomputation."],
|
||||||
|
["Transformer Model (1B params)",
|
||||||
|
"Decoder-only Transformer with",
|
||||||
|
"Grouped-Query Attention (GQA 6:1).",
|
||||||
|
"RoPE rotary encoding, SwiGLU",
|
||||||
|
"activation, 100K vocabulary."],
|
||||||
|
]
|
||||||
|
|
||||||
|
def make_desc(lines, color):
|
||||||
|
els = [Text(lines[0], font_size=20, color=color, weight=BOLD)]
|
||||||
|
for ln in lines[1:]:
|
||||||
|
els.append(Text(ln, font_size=14, color=WHITE))
|
||||||
|
grp = VGroup(*els).arrange(DOWN, buff=0.1, aligned_edge=LEFT)
|
||||||
return grp
|
return grp
|
||||||
|
|
||||||
layers = []
|
COLORS = [GREEN, BLUE, ORANGE, PURPLE]
|
||||||
|
descs = [make_desc(lns, c) for lns, c in zip(descs_text, COLORS)]
|
||||||
|
|
||||||
# ── Layer 1: API Server ──
|
|
||||||
l1_t = layer_header("HTTP API Server", GREEN)
|
|
||||||
l1_d = [sub("FastAPI • OpenAI-Compatible /v1/chat/completions"),
|
|
||||||
sub("Streaming SSE • Async • Health/Stats Endpoints")]
|
|
||||||
l1 = intro(l1_t, l1_d, GREEN, sub("astrai/inference/server.py"))
|
|
||||||
l1.next_to(title, DOWN, buff=0.35)
|
|
||||||
layers.append(l1)
|
|
||||||
|
|
||||||
# ── Layer 2: Inference Engine ──
|
|
||||||
l2_t = layer_header("InferenceEngine", BLUE)
|
|
||||||
l2_d = [sub("generate() · generate_async() · generate_with_request()"),
|
|
||||||
sub("Batch mode · Streaming (Generator) · Thread-safe accumulator")]
|
|
||||||
l2 = intro(l2_t, l2_d, BLUE, sub("astrai/inference/engine.py"))
|
|
||||||
l2.next_to(l1, DOWN, buff=0.12)
|
|
||||||
layers.append(l2)
|
|
||||||
|
|
||||||
# ── Layer 3: Continuous Batching Scheduler ──
|
|
||||||
l3_t = layer_header("InferenceScheduler (Background Daemon)", YELLOW)
|
|
||||||
l3_d = [sub("Cleanup → Refill → Prefill → Decode · 4-phase loop"),
|
|
||||||
sub("Position-Grouped Decode · Bitmask O(1) Slot Allocation")]
|
|
||||||
l3 = intro(l3_t, l3_d, YELLOW, sub("astrai/inference/scheduler.py"))
|
|
||||||
l3.next_to(l2, DOWN, buff=0.12)
|
|
||||||
layers.append(l3)
|
|
||||||
|
|
||||||
# ── Layer 4: Prefix Cache + KV Cache ──
|
|
||||||
l4_t = layer_header("PrefixCacheManager + KV Cache", ORANGE)
|
|
||||||
l4_d = [sub("Radix Tree prefix matching · LRU eviction · Slot versioning"),
|
|
||||||
sub("GPU copy_() → Zero-Copy Reuse · k_cache / v_cache tensors")]
|
|
||||||
l4 = intro(l4_t, l4_d, ORANGE, sub("astrai/inference/scheduler.py"))
|
|
||||||
l4.next_to(l3, DOWN, buff=0.12)
|
|
||||||
layers.append(l4)
|
|
||||||
|
|
||||||
# ── Layer 5: Transformer Model ──
|
|
||||||
l5_t = layer_header("Transformer (1B params)", PURPLE)
|
|
||||||
l5_d = [sub("24× DecoderBlock · GQA 6:1 · RoPE · SwiGLU MLP"),
|
|
||||||
sub("Dim 1536 · Max Length 2048 · bfloat16 · 100K vocab")]
|
|
||||||
l5 = intro(l5_t, l5_d, PURPLE, sub("astrai/model/transformer.py"))
|
|
||||||
l5.next_to(l4, DOWN, buff=0.12)
|
|
||||||
layers.append(l5)
|
|
||||||
|
|
||||||
# ── Animate layer by layer ──
|
|
||||||
arrows = VGroup()
|
arrows = VGroup()
|
||||||
for i, layer in enumerate(layers):
|
for i, (layer, desc) in enumerate(zip(layers, descs)):
|
||||||
self.play(Create(layer), run_time=0.4)
|
b = layer[0]
|
||||||
self.wait(1.0 if i < 2 else 0.8)
|
self.play(Create(layer), run_time=0.35)
|
||||||
if i > 0:
|
desc.next_to(b, RIGHT, buff=1.0)
|
||||||
prev = layers[i - 1][0]
|
desc.align_to(b, UP)
|
||||||
curr = layer[0]
|
self.play(Write(desc), run_time=0.3)
|
||||||
|
self.wait(2.0 if i == 0 else 1.8)
|
||||||
|
if i < len(layers) - 1:
|
||||||
|
self.play(FadeOut(desc))
|
||||||
|
nxt = layers[i + 1][0]
|
||||||
arrow = Arrow(
|
arrow = Arrow(
|
||||||
prev.get_bottom(), curr.get_top(),
|
b.get_bottom(), nxt.get_top(),
|
||||||
color=GRAY, buff=0.06,
|
color=GRAY, buff=0.04,
|
||||||
max_tip_length_to_length_ratio=0.18,
|
max_tip_length_to_length_ratio=0.18,
|
||||||
)
|
)
|
||||||
self.play(Create(arrow), run_time=0.15)
|
self.play(Create(arrow), run_time=0.12)
|
||||||
arrows.add(arrow)
|
arrows.add(arrow)
|
||||||
|
else:
|
||||||
|
self.wait(0.5)
|
||||||
|
self.play(FadeOut(desc))
|
||||||
|
|
||||||
self.wait(0.6)
|
# Show all boxes + arrows together briefly
|
||||||
|
self.wait(0.3)
|
||||||
|
|
||||||
# ── Highlight: the innovation layers ──
|
# Highlight innovation layers
|
||||||
hl3 = SurroundingRectangle(layers[2], color=YELLOW, buff=0.1, stroke_width=2)
|
hl2 = SurroundingRectangle(L2, color=BLUE, buff=0.1, stroke_width=2)
|
||||||
hl4 = SurroundingRectangle(layers[3], color=ORANGE, buff=0.1, stroke_width=2)
|
hl3 = SurroundingRectangle(L3, color=ORANGE, buff=0.1, stroke_width=2)
|
||||||
hl_note = Text("Key Innovations: Continuous Batching + Prefix Cache",
|
hl_note = Text("Key Innovations", font_size=20, color=GOLD)
|
||||||
font_size=18, color=GOLD)
|
hl_note.next_to(VGroup(hl2, hl3), RIGHT, buff=1.5)
|
||||||
hl_note.next_to(VGroup(hl3, hl4), LEFT, buff=0.5)
|
hl_note.align_to(hl2, UP)
|
||||||
self.play(Create(hl3), Create(hl4), Write(hl_note))
|
self.play(Create(hl2), Create(hl3), Write(hl_note))
|
||||||
self.wait(2.0)
|
self.wait(2.0)
|
||||||
self.play(FadeOut(hl3), FadeOut(hl4), FadeOut(hl_note))
|
self.play(FadeOut(hl2), FadeOut(hl3), FadeOut(hl_note))
|
||||||
|
|
||||||
# ── Fade to CTA ──
|
|
||||||
self.play(FadeOut(VGroup(*layers)), FadeOut(arrows))
|
self.play(FadeOut(VGroup(*layers)), FadeOut(arrows))
|
||||||
|
|
||||||
cta = VGroup(
|
cta = VGroup(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue