add project source files
This commit is contained in:
parent
94aadb3d8f
commit
c03abd31fe
|
|
@ -0,0 +1,10 @@
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
output/
|
||||||
|
.DS_Store
|
||||||
|
|
@ -0,0 +1,373 @@
|
||||||
|
# AstrAI 宣传视频制作指南
|
||||||
|
|
||||||
|
> 本文档为制作 AstrAI 宣传视频提供完整的技术参考、分镜建议和录制脚本。
|
||||||
|
> 目标时长:**2-3 分钟**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 目录
|
||||||
|
|
||||||
|
1. [项目定位与核心卖点](#1-项目定位与核心卖点)
|
||||||
|
2. [技术架构速览](#2-技术架构速览)
|
||||||
|
3. [分镜脚本](#3-分镜脚本)
|
||||||
|
4. [演示录制指南](#4-演示录制指南)
|
||||||
|
5. [动画场景说明](#5-动画场景说明)
|
||||||
|
6. [旁白文案草稿](#6-旁白文案草稿)
|
||||||
|
7. [素材清单](#7-素材清单)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 项目定位与核心卖点
|
||||||
|
|
||||||
|
**一句话定位:**
|
||||||
|
> 一个能在单张消费级 GPU 上训练和推理的 1B 参数中英双语语言模型框架。
|
||||||
|
|
||||||
|
**核心卖点(视频中需突出):**
|
||||||
|
|
||||||
|
| 卖点 | 说明 | 视觉表达 |
|
||||||
|
|------|------|---------|
|
||||||
|
| **单卡可跑** | 1B 参数,RTX 3090/4090 即可运行 | 巨大服务器集群 vs 单张显卡对比 |
|
||||||
|
| **连续批处理** | 动态合并请求,吞吐量 3x+ | 任务流经 Cleanup→Refill→Prefill→Decode 动画 |
|
||||||
|
| **前缀缓存零拷贝** | 相同前缀直接复用 KV,无需重算 | Radix Tree 生长动画 |
|
||||||
|
| **OpenAI 兼容 API** | 一行代码切换 | curl 命令对比 |
|
||||||
|
| **流式输出** | 逐 token 返回,低首延迟 | 终端逐字喷出效果 |
|
||||||
|
| **全过程开源** | 训练+推理+权重全部开源 | GitHub 页面展示 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 技术架构速览
|
||||||
|
|
||||||
|
### 整体架构
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────┐
|
||||||
|
│ FastAPI Server (OpenAI-Compatible API) │
|
||||||
|
├──────────────────────────────────────────────────┤
|
||||||
|
│ InferenceEngine (Streaming + Async + Batch) │
|
||||||
|
├──────────────────────────────────────────────────┤
|
||||||
|
│ Continuous Batching Scheduler │
|
||||||
|
│ ┌────────┐ ┌──────┐ ┌────────┐ ┌────────┐ │
|
||||||
|
│ │Cleanup │→ │Refill│→ │Prefill │→ │ Decode │ │
|
||||||
|
│ └────────┘ └──────┘ └────────┘ └────────┘ │
|
||||||
|
├──────────────────────────────────────────────────┤
|
||||||
|
│ Prefix Cache (Radix Tree) + KV Cache │
|
||||||
|
├──────────────────────────────────────────────────┤
|
||||||
|
│ Transformer (24层 GQA, RoPE, SwiGLU) │
|
||||||
|
└──────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 关键技术指标
|
||||||
|
|
||||||
|
| 指标 | 值 |
|
||||||
|
|------|------|
|
||||||
|
| 参数量 | ~1.0B |
|
||||||
|
| 词表大小 | 100,000(中英 BPE) |
|
||||||
|
| 层数 | 24 |
|
||||||
|
| 注意力头 | 24 Q-heads / 4 KV-heads(GQA) |
|
||||||
|
| 最大长度 | 2048 tokens |
|
||||||
|
| 精度 | bfloat16 |
|
||||||
|
| 最低显存 | ~6GB(推理)/~12GB(训练) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 分镜脚本
|
||||||
|
|
||||||
|
总时长 **2:30**,分为 6 个段落。
|
||||||
|
|
||||||
|
### Segment 1:Hook + 问题陈述(0:00 - 0:20)
|
||||||
|
|
||||||
|
| 镜头 | 画面 | 旁白 | 时长 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| 1.1 | 黑屏,逐字打出"大语言模型很强大" | "大语言模型很强大——" | 3s |
|
||||||
|
| 1.2 | 切到数据中心照片 / 巨大 GPU 集群 | "——但跑起来需要几十张 GPU,普通人根本碰不到。" | 5s |
|
||||||
|
| 1.3 | 画面分屏:左边集群,右边一张 RTX 4090 | "但如果我告诉你,只要一张显卡就够了呢?" | 5s |
|
||||||
|
| 1.4 | Logo 出现:**AstrAI**,下方副标题 "1B 参数单卡推理框架" | "AstrAI——单卡跑大模型。" | 7s |
|
||||||
|
|
||||||
|
**视觉素材**:数据中心图片(可免版权下载)、RTX 4090 产品图、Logo 动画
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Segment 2:模型架构速览(0:20 - 0:45)
|
||||||
|
|
||||||
|
| 镜头 | 画面 | 旁白 | 时长 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| 2.1 | Transformer 架构图逐层展开:embed → 24层 decoder → norm → head | "1B 参数,24 层 Transformer,100k 词表的中英 BPE 分词器。" | 8s |
|
||||||
|
| 2.2 | 高亮 GQA:24个 Q head 映射到 4个 KV head | "GQA 分组查询注意力——24 个查询头只对应 4 个 KV 头,KV 缓存直接减少 83%。" | 10s |
|
||||||
|
| 2.3 | RoPE 旋转变换可视化 | "RoPE 旋转位置编码,支持动态长度外推。" | 5s |
|
||||||
|
| 2.4 | fade 到模型 card:vocab=100k, dim=1536, layers=24, heads=24, kv_heads=4 | 静默 | 2s |
|
||||||
|
|
||||||
|
**视觉素材**:`architecture.py` 动画、模型参数 card
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Segment 3:连续批处理(0:45 - 1:20)
|
||||||
|
|
||||||
|
| 镜头 | 画面 | 旁白 | 时长 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| 3.1 | 3 个请求同时到达服务器 | "当多个请求同时到达时——" | 3s |
|
||||||
|
| 3.2 | 静态批处理对比:最长补齐,3个请求串行 → 总耗时 max_len × 3 | "传统做法是静态批处理,把请求补齐到相同长度,串行处理,GPU 利用率低下。" | 8s |
|
||||||
|
| 3.3 | 连续批处理动画:任务流入 Waiting Queue → Cleanup → Refill → Prefill → Decode | "AstrAI 采用连续批处理:任务动态进出,GPU 每一刻都在满负荷运转。" | 10s |
|
||||||
|
| 3.4 | 放大 Decode 阶段:同一位置的任务合并成一批 | "特别地,只有处于相同 KV 缓存位置的任务才一起解码,从根本上避免了 RoPE 位置错乱的问题。" | 8s |
|
||||||
|
| 3.5 | 吞吐对比柱状图:Static Batch vs Continuous Batching (3x+) | "实测吞吐量提升 3 倍以上。" | 6s |
|
||||||
|
|
||||||
|
**视觉素材**:`continuous_batching.py` 动画、对比图表
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Segment 4:前缀缓存(1:20 - 1:50)
|
||||||
|
|
||||||
|
| 镜头 | 画面 | 旁白 | 时长 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| 4.1 | 两个请求有相同 system prompt:"你是一个AI助手" | "如果两个请求有相同的前缀——比如相同的系统提示词——" | 5s |
|
||||||
|
| 4.2 | 普通做法:两个请求各自独立计算前 20 个 token | "普通框架会各自从头计算一遍,白白浪费算力。" | 5s |
|
||||||
|
| 4.3 | Radix Tree 生长动画:第一个请求插入,第二个请求匹配共享前缀 | "AstrAI 用一颗字典树缓存所有前缀的 KV——第二个请求直接命中。" | 8s |
|
||||||
|
| 4.4 | 高亮 Slot 复用:直接用原 slot 继续写,零拷贝 | "如果原始 slot 空闲,直接原地续写,连 GPU 内存拷贝都不需要。" | 7s |
|
||||||
|
| 4.5 | 首 token 延迟对比:有缓存 vs 无缓存(-50%) | "首 token 延迟降低一半以上。" | 5s |
|
||||||
|
|
||||||
|
**视觉素材**:`prefix_cache.py` 动画、延迟对比
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Segment 5:Demo 演示(1:50 - 2:15)
|
||||||
|
|
||||||
|
| 镜头 | 画面 | 旁白 | 时长 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| 5.1 | 侧录终端:启动 stream_chat.py,逐行输出对话 | "来实际看看效果。" | 10s |
|
||||||
|
| 5.2 | 多轮对话:中文问答,逐 token 喷出 | 静默 + 打字音效 | 8s |
|
||||||
|
| 5.3 | 切到 HTTP 模式:服务端 + curl 请求,流式返回 | "也提供 OpenAI 兼容的 HTTP API,一行 curl 就能调用。" | 7s |
|
||||||
|
|
||||||
|
**视觉素材**:终端录屏(OBS 录制)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Segment 6:收尾 + CTA(2:15 - 2:30)
|
||||||
|
|
||||||
|
| 镜头 | 画面 | 旁白 | 时长 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| 6.1 | 全栈流程回顾(缩略架构图) | "训练用 SEQ → SFT → DPO/GRPO,推理用连续批处理——" | 5s |
|
||||||
|
| 6.2 | GitHub 页面 + Star 引导 | "——全部开源。点个 Star,一起让大模型更普惠。" | 7s |
|
||||||
|
| 6.3 | Logo + URL + "Open Source • Single GPU" | 静默 | 3s |
|
||||||
|
|
||||||
|
**视觉素材**:GitHub 页面录屏、Logo 定版
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 演示录制指南
|
||||||
|
|
||||||
|
### 4.1 准备工作
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. 安装依赖
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
|
# 2. 下载模型(约 7GB)
|
||||||
|
python scripts/demo/download.py
|
||||||
|
|
||||||
|
# 3. 验证模型加载
|
||||||
|
python scripts/demo/generate_ar.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 录制场景 A:交互式对话
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 终端 1:启动交互式对话
|
||||||
|
python scripts/demo/stream_chat.py
|
||||||
|
|
||||||
|
# 预期交互
|
||||||
|
>> 你好?
|
||||||
|
AstrAI: 你好!有什么我可以帮你的吗?
|
||||||
|
>> 请用中文介绍一下你自己
|
||||||
|
AstrAI: ...(逐 token 输出)
|
||||||
|
>> 编一个关于人工智能的短故事
|
||||||
|
AstrAI: ...(逐 token 输出)
|
||||||
|
```
|
||||||
|
|
||||||
|
**录制重点**:
|
||||||
|
- 逐 token 流式输出效果(用 OBS 录制终端窗口)
|
||||||
|
- 多轮对话的记忆能力(跨轮上下文保持)
|
||||||
|
- 打字音效叠加
|
||||||
|
|
||||||
|
### 4.3 录制场景 B:HTTP 服务 + 并发
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 终端 1:启动服务器
|
||||||
|
python -m scripts.tools.server --port 8000 --device cuda
|
||||||
|
|
||||||
|
# 终端 2:发送请求(非流式)
|
||||||
|
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"messages":[{"role":"user","content":"Hello!"}],"stream":false}'
|
||||||
|
|
||||||
|
# 终端 3:流式请求
|
||||||
|
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"messages":[{"role":"user","content":"Write a poem"}],"stream":true}'
|
||||||
|
|
||||||
|
# 终端 4:并发压测(用 scripts/demo/generate_batch.py)
|
||||||
|
python scripts/demo/generate_batch.py
|
||||||
|
```
|
||||||
|
|
||||||
|
**录制重点**:
|
||||||
|
- 同时多个 curl 请求展示并发处理
|
||||||
|
- 服务端日志显示批处理合并
|
||||||
|
- `/stats` 端点展示实时统计
|
||||||
|
|
||||||
|
### 4.4 录制规格
|
||||||
|
|
||||||
|
| 参数 | 建议 |
|
||||||
|
|------|------|
|
||||||
|
| 分辨率 | 1920×1080 |
|
||||||
|
| 帧率 | 30fps |
|
||||||
|
| 终端 | Windows Terminal 或 iTerm2,深色主题 |
|
||||||
|
| 字号 | 16-18px,等宽字体(JetBrains Mono / Cascadia Code) |
|
||||||
|
| 录屏工具 | OBS Studio(免费) |
|
||||||
|
| 音频 | 旁白用 USB 麦克风,音效后期叠加 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 动画场景说明
|
||||||
|
|
||||||
|
位于 `promo/` 目录,使用 Manim 引擎。
|
||||||
|
|
||||||
|
### 安装 Manim
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# conda 环境内安装
|
||||||
|
pip install manim
|
||||||
|
|
||||||
|
# 验证
|
||||||
|
python -c "import manim; print(manim.__version__)"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 渲染命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 单独渲染一个场景
|
||||||
|
manim -qh promo/continuous_batching.py ContinuousBatching
|
||||||
|
|
||||||
|
# 全部场景渲染
|
||||||
|
python promo/render_all.py
|
||||||
|
|
||||||
|
# 快速草稿(480p,适合调试)
|
||||||
|
manim -ql promo/continuous_batching.py ContinuousBatching
|
||||||
|
```
|
||||||
|
|
||||||
|
输出文件为 `promo/output/videos/` 下的 `.mp4` 文件,可直接导入剪辑软件。
|
||||||
|
|
||||||
|
### 场景清单
|
||||||
|
|
||||||
|
| 文件 | 导出场景名 | 内容 | 建议时长 |
|
||||||
|
|------|-----------|------|---------|
|
||||||
|
| `transformer.py` | `Transformer` | 模型架构:Embed → GQA → SwiGLU → ×24 → LM Head | ~35s |
|
||||||
|
| `continuous_batching.py` | `ContinuousBatching` | 4 阶段流水线动画 + 吞吐对比 | ~30s |
|
||||||
|
| `prefix_cache.py` | `PrefixCache` | Radix Tree 生长 + 多分支前缀复用 | ~30s |
|
||||||
|
| `architecture.py` | `Architecture` | 全栈架构逐层展开 + 数据流 | ~25s |
|
||||||
|
|
||||||
|
### 自定义动画
|
||||||
|
|
||||||
|
如需修改动画内容:
|
||||||
|
- Manim 语法参考:https://docs.manim.community/
|
||||||
|
- 所有动画元素(颜色、位置、速度)在场景类中通过参数调整
|
||||||
|
- 中文字体渲染需额外配置:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在场景类开头添加
|
||||||
|
Text.set_default(font="Microsoft YaHei")
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 旁白文案草稿
|
||||||
|
|
||||||
|
### 中文版(完整 2:30)
|
||||||
|
|
||||||
|
```
|
||||||
|
[00:00] 大语言模型很强大——
|
||||||
|
[00:03] 但跑起来需要几十张 GPU,普通人根本碰不到。
|
||||||
|
[00:08] 但如果我告诉你,只要一张显卡就够了呢?
|
||||||
|
[00:13] AstrAI——单卡跑大模型。
|
||||||
|
|
||||||
|
[00:20] 1B 参数,24 层 Transformer,100k 词表的中英 BPE 分词器。
|
||||||
|
[00:28] GQA 分组查询注意力——24 个查询头只对应 4 个 KV 头,KV 缓存直接减少 83%。
|
||||||
|
[00:38] RoPE 旋转位置编码,支持动态长度外推。
|
||||||
|
|
||||||
|
[00:45] 当多个请求同时到达时——
|
||||||
|
[00:48] 传统做法是静态批处理,把请求补齐到相同长度串行处理,GPU 利用率低下。
|
||||||
|
[00:56] AstrAI 采用连续批处理:任务动态进出,GPU 每一刻都在满负荷运转。
|
||||||
|
[01:06] 只有处于相同 KV 缓存位置的任务才一起解码,从根本上避免 RoPE 位置错乱。
|
||||||
|
[01:14] 实测吞吐量提升 3 倍以上。
|
||||||
|
|
||||||
|
[01:20] 如果两个请求有相同的前缀,普通框架会各自从头计算。
|
||||||
|
[01:25] AstrAI 用一颗字典树缓存所有前缀的 KV——第二个请求直接命中。
|
||||||
|
[01:33] 如果原始 slot 空闲,直接原地续写,连 GPU 内存拷贝都不需要。
|
||||||
|
[01:40] 首 token 延迟降低一半以上。
|
||||||
|
|
||||||
|
[01:50] 来实际看看效果。
|
||||||
|
[01:52] (现场演示部分,自由发挥)
|
||||||
|
|
||||||
|
[02:15] 训练到推理,全流程开源,点个 Star,一起让大模型更普惠。
|
||||||
|
[02:25] AstrAI — Open Source, Single GPU.
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 素材清单
|
||||||
|
|
||||||
|
### 视频素材
|
||||||
|
|
||||||
|
| 素材 | 来源 | 状态 |
|
||||||
|
|------|------|------|
|
||||||
|
| 数据中心 / GPU 集群图片 | Pexels / Unsplash 免版权 | 需下载 |
|
||||||
|
| RTX 4090 产品图 | NVIDIA 官网 / 实物拍摄 | 需准备 |
|
||||||
|
| AstrAI Logo | `assets/images/logo.png` | ✅ 已有 |
|
||||||
|
| 终端录屏(对话) | OBS 录制 `scripts/demo/stream_chat.py` | 需录制 |
|
||||||
|
| 终端录屏(HTTP) | OBS 录制 curl + server | 需录制 |
|
||||||
|
| 终端录屏(并发) | OBS 录制 `generate_batch.py` | 需录制 |
|
||||||
|
| GitHub 页面 | 浏览器录屏 | 需录制 |
|
||||||
|
| Transformer 架构动画 | Manim 渲染 `transformer.py` | ✅ 已渲染 |
|
||||||
|
| 架构动画 | Manim 渲染 `architecture.py` | ✅ 已渲染 |
|
||||||
|
| 连续批处理动画 | Manim 渲染 `continuous_batching.py` | ✅ 已渲染 |
|
||||||
|
| 前缀缓存动画 | Manim 渲染 `prefix_cache.py` | ✅ 已渲染 |
|
||||||
|
|
||||||
|
### 音频素材
|
||||||
|
|
||||||
|
| 素材 | 建议 |
|
||||||
|
|------|------|
|
||||||
|
| 旁白 | USB 麦克风录制,男声或女声,中文普通话 |
|
||||||
|
| 背景音乐 | Epidemic Sound / YouTube Audio Library 搜索 "technology ambient" |
|
||||||
|
| 音效 | 打字音效(terminal keystrokes)、转场 swoosh、whoosh |
|
||||||
|
|
||||||
|
### 软件工具
|
||||||
|
|
||||||
|
| 用途 | 推荐工具 | 价格 |
|
||||||
|
|------|---------|------|
|
||||||
|
| 录屏 | OBS Studio | 免费 |
|
||||||
|
| 剪辑 | DaVinci Resolve | 免费 |
|
||||||
|
| 动画渲染 | Manim (`pip install manim`) | 免费 |
|
||||||
|
| 音频处理 | Audacity | 免费 |
|
||||||
|
| 字幕 | DaVinci Resolve 内建 / Aegisub | 免费 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 附录:关键文件索引
|
||||||
|
|
||||||
|
| 文件路径 | 说明 |
|
||||||
|
|---------|------|
|
||||||
|
| `README.md` | 项目主页 README,含快速开始 |
|
||||||
|
| `assets/docs/introduction.md` | 模型架构深度介绍 |
|
||||||
|
| `assets/docs/design.md` | 设计文档 + UML 类图 |
|
||||||
|
| `astrai/inference/scheduler.py` | 连续批处理调度器核心代码 |
|
||||||
|
| `astrai/inference/engine.py` | 推理引擎统一接口 |
|
||||||
|
| `astrai/inference/server.py` | FastAPI 服务器 |
|
||||||
|
| `astrai/model/transformer.py` | Transformer 模型 |
|
||||||
|
| `astrai/model/module.py` | GQA、MLA、MLP 等模块 |
|
||||||
|
| `scripts/demo/stream_chat.py` | 交互式对话演示 |
|
||||||
|
| `scripts/demo/generate_batch.py` | 批量生成演示 |
|
||||||
|
| `scripts/tools/server.py` | HTTP 服务启动脚本 |
|
||||||
|
| `scripts/tools/benchmark.py` | 性能基准测试 |
|
||||||
|
| `scripts/promo/README.md` | 动画渲染说明(已移至 promo/) |
|
||||||
|
| `promo/render_all.py` | 一键渲染所有动画 |
|
||||||
|
| `promo/continuous_batching.py` | 连续批处理 Manim 场景 |
|
||||||
|
| `promo/prefix_cache.py` | 前缀缓存 Manim 场景 |
|
||||||
|
| `promo/architecture.py` | 架构总览 Manim 场景 |
|
||||||
|
| `params/config.json` | 模型配置 |
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""AstrAI promo: Full architecture overview."""
|
||||||
|
|
||||||
|
from manim import *
|
||||||
|
|
||||||
|
|
||||||
|
class Architecture(Scene):
|
||||||
|
"""Animates the full AstrAI system stack layer by layer."""
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
title = Text("AstrAI Architecture", font_size=48, color=BLUE)
|
||||||
|
self.play(Write(title))
|
||||||
|
self.wait(0.2)
|
||||||
|
self.play(title.animate.to_edge(UP))
|
||||||
|
|
||||||
|
layers_data = [
|
||||||
|
(0.9, GREEN, "API Layer", ["FastAPI Server • OpenAI-Compatible API"]),
|
||||||
|
(0.9, BLUE, "Inference Engine", ["Streaming • Async • Batch Modes"]),
|
||||||
|
(1.6, YELLOW, "Continuous Batching Scheduler",
|
||||||
|
["Cleanup → Refill → Prefill → Decode",
|
||||||
|
"Position-Grouped Decode",
|
||||||
|
"Bitmask O(1) Slot Allocation"]),
|
||||||
|
(1.2, ORANGE, "Prefix Cache + KV Cache",
|
||||||
|
["Radix Tree • Slot Versioning",
|
||||||
|
"GPU copy_() → Zero-Copy Reuse"]),
|
||||||
|
(1.2, PURPLE, "Transformer Model (1B params)",
|
||||||
|
["24-layer GQA • RoPE • SwiGLU",
|
||||||
|
"bfloat16 • 100K vocab"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
layers = VGroup()
|
||||||
|
for height, color, label, subs in layers_data:
|
||||||
|
box = Rectangle(width=7.5, height=height, color=color, fill_opacity=0.1)
|
||||||
|
lbl = Text(label, font_size=18, color=color)
|
||||||
|
items = [lbl] + [Text(s, font_size=11, color=WHITE) for s in subs]
|
||||||
|
content = VGroup(*items)
|
||||||
|
content.arrange(DOWN, buff=0.22)
|
||||||
|
content.move_to(box.get_center())
|
||||||
|
layers.add(VGroup(box, content))
|
||||||
|
|
||||||
|
layers.arrange(DOWN, buff=0.18)
|
||||||
|
layers.next_to(title, DOWN, buff=0.3)
|
||||||
|
|
||||||
|
for i in range(len(layers)):
|
||||||
|
self.play(Create(layers[i]), run_time=0.35)
|
||||||
|
if i > 0:
|
||||||
|
# Use box-to-box for arrow endpoints (not content)
|
||||||
|
prev_box = layers[i - 1][0]
|
||||||
|
curr_box = layers[i][0]
|
||||||
|
arrow = Arrow(
|
||||||
|
prev_box.get_bottom(),
|
||||||
|
curr_box.get_top(),
|
||||||
|
color=GRAY,
|
||||||
|
buff=0.1,
|
||||||
|
max_tip_length_to_length_ratio=0.15,
|
||||||
|
)
|
||||||
|
self.play(Create(arrow), run_time=0.15)
|
||||||
|
|
||||||
|
self.wait(0.5)
|
||||||
|
|
||||||
|
hl = SurroundingRectangle(layers[3], color=GREEN, buff=0.12)
|
||||||
|
hl_note = Text("Zero-Copy Prefix Reuse", font_size=22, color=GREEN)
|
||||||
|
hl_note.next_to(hl, RIGHT, buff=0.8)
|
||||||
|
self.play(Create(hl), Write(hl_note))
|
||||||
|
self.wait(1.5)
|
||||||
|
self.play(FadeOut(hl), FadeOut(hl_note))
|
||||||
|
|
||||||
|
self.play(FadeOut(layers))
|
||||||
|
|
||||||
|
cta = VGroup(
|
||||||
|
Text("AstrAI", font_size=52, color=BLUE),
|
||||||
|
Text("Single GPU • Open Source • 1B params", font_size=24, color=GRAY),
|
||||||
|
Text("github.com/ViperEkura/AstrAI", font_size=20, color=YELLOW),
|
||||||
|
).arrange(DOWN, buff=0.35)
|
||||||
|
self.play(Write(cta))
|
||||||
|
self.wait(2)
|
||||||
|
self.play(FadeOut(cta), FadeOut(title))
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
"""AstrAI promo: Continuous Batching animation.
|
||||||
|
|
||||||
|
Shows how tasks flow through the 4-phase pipeline and get batched together.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from manim import *
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuousBatching(Scene):
|
||||||
|
"""Animates tasks flowing through the prefill->decode pipeline."""
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
# ── title ──
|
||||||
|
title = Text("Continuous Batching", font_size=48, color=BLUE)
|
||||||
|
self.play(Write(title))
|
||||||
|
self.wait(0.5)
|
||||||
|
self.play(title.animate.to_edge(UP).scale(0.6))
|
||||||
|
top_bar = Line(LEFT * 7, RIGHT * 7, color=GRAY).next_to(title, DOWN)
|
||||||
|
self.play(Create(top_bar))
|
||||||
|
|
||||||
|
# ── pipeline stages ──
|
||||||
|
stage_names = ["Waiting\nQueue", "Prefill", "Decode\n(Batched)", "Finished"]
|
||||||
|
stage_color = [GRAY, BLUE, YELLOW, GREEN]
|
||||||
|
|
||||||
|
stages = VGroup()
|
||||||
|
arrows = VGroup()
|
||||||
|
for i, (name, color) in enumerate(zip(stage_names, stage_color)):
|
||||||
|
box = Rectangle(height=1.5, width=2.5, color=color, fill_opacity=0.12)
|
||||||
|
lbl = Text(name, font_size=18, color=color)
|
||||||
|
grp = VGroup(box, lbl)
|
||||||
|
grp.shift(RIGHT * (i - 1.5) * 3.2 + DOWN * 0.5)
|
||||||
|
stages.add(grp)
|
||||||
|
self.play(Create(grp), run_time=0.35)
|
||||||
|
if i > 0:
|
||||||
|
a = Arrow(stages[i - 1].get_right(), stages[i].get_left(), color=GRAY)
|
||||||
|
arrows.add(a)
|
||||||
|
self.play(Create(a), run_time=0.2)
|
||||||
|
|
||||||
|
pipeline = VGroup(stages, arrows)
|
||||||
|
plabel = Text("4-Phase Generation Loop", font_size=16, color=GRAY).next_to(
|
||||||
|
pipeline, DOWN, buff=0.4
|
||||||
|
)
|
||||||
|
self.play(Write(plabel))
|
||||||
|
self.wait(0.5)
|
||||||
|
|
||||||
|
# ── spawn tasks ──
|
||||||
|
task_colors = [YELLOW, ORANGE, PINK, TEAL, GREEN]
|
||||||
|
tasks = VGroup()
|
||||||
|
box_center = stages[0].get_center()
|
||||||
|
for i, c in enumerate(task_colors):
|
||||||
|
dot = Dot(color=c, radius=0.12)
|
||||||
|
y_off = (i - 2) * 0.2
|
||||||
|
dot.move_to(box_center + RIGHT * y_off * 0.3)
|
||||||
|
lbl = Text(f"R{i+1}", font_size=10, color=c).next_to(dot, UP, buff=0.1)
|
||||||
|
tg = VGroup(dot, lbl)
|
||||||
|
tasks.add(tg)
|
||||||
|
self.play(FadeIn(tg, scale=0.5), run_time=0.12)
|
||||||
|
|
||||||
|
self.wait(0.3)
|
||||||
|
|
||||||
|
# ── animate through stages ──
|
||||||
|
for phase in range(1, 4):
|
||||||
|
target = stages[phase].get_center()
|
||||||
|
anims = [t.animate.move_to(target) for t in tasks]
|
||||||
|
self.play(*anims, run_time=0.5, rate_func=smooth)
|
||||||
|
self.wait(0.15)
|
||||||
|
|
||||||
|
# ── highlight decode batching ──
|
||||||
|
ring = SurroundingRectangle(stages[2], color=YELLOW, buff=0.12)
|
||||||
|
note = Text(
|
||||||
|
"Same-position batch decoding", font_size=16, color=YELLOW
|
||||||
|
).next_to(stages[2], DOWN, buff=0.5)
|
||||||
|
self.play(Create(ring), Write(note))
|
||||||
|
self.wait(1)
|
||||||
|
self.play(FadeOut(ring), FadeOut(note))
|
||||||
|
|
||||||
|
# ── throughput comparison (text) ──
|
||||||
|
self.play(
|
||||||
|
*[FadeOut(t) for t in tasks],
|
||||||
|
FadeOut(pipeline),
|
||||||
|
FadeOut(plabel),
|
||||||
|
FadeOut(top_bar),
|
||||||
|
)
|
||||||
|
|
||||||
|
compare = VGroup(
|
||||||
|
Text("Throughput Comparison", font_size=32, color=BLUE),
|
||||||
|
Text(
|
||||||
|
"Static Batch: 1.0× (baseline)",
|
||||||
|
font_size=24, color=RED,
|
||||||
|
),
|
||||||
|
Text(
|
||||||
|
"Continuous Batching: 3.4× (single GPU)",
|
||||||
|
font_size=24, color=GREEN,
|
||||||
|
),
|
||||||
|
).arrange(DOWN, buff=0.4, aligned_edge=LEFT)
|
||||||
|
self.play(Write(compare))
|
||||||
|
self.wait(2)
|
||||||
|
self.play(FadeOut(compare))
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""AstrAI promo: Prefix Cache animation (Radix tree with branches)."""
|
||||||
|
|
||||||
|
from manim import *
|
||||||
|
|
||||||
|
|
||||||
|
class PrefixCache(Scene):
|
||||||
|
"""Animates the radix-tree prefix cache with multiple distinct branches."""
|
||||||
|
|
||||||
|
def _add_node(self, parent_pos, label, color, dx, dy):
|
||||||
|
pos = parent_pos + np.array([dx, dy, 0])
|
||||||
|
dot = Dot(point=pos, color=color, radius=0.1)
|
||||||
|
txt = Text(label, font_size=13, color=color)
|
||||||
|
txt.next_to(dot, UP, buff=0.1)
|
||||||
|
grp = VGroup(dot, txt)
|
||||||
|
edge = Line(parent_pos, pos, color=GRAY, stroke_width=1.5)
|
||||||
|
return grp, edge, pos
|
||||||
|
|
||||||
|
def _add_leaf(self, parent_pos, color, tag):
|
||||||
|
leaf = Square(side_length=0.25, color=color, fill_opacity=0.4)
|
||||||
|
leaf.move_to(parent_pos + DOWN * 0.7)
|
||||||
|
edge = Line(parent_pos, leaf.get_top(), color=color, stroke_width=1.5)
|
||||||
|
lbl = Text(tag, font_size=10, color=color).next_to(leaf, DOWN, buff=0.1)
|
||||||
|
return VGroup(leaf, edge, lbl)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
title = Text("Prefix Cache", font_size=48, color=BLUE)
|
||||||
|
self.play(Write(title))
|
||||||
|
self.wait(0.2)
|
||||||
|
self.play(title.animate.to_edge(UP).scale(0.6))
|
||||||
|
|
||||||
|
# Root at top-left, tree stays visible throughout
|
||||||
|
root_pos = np.array([-4.5, 2.0, 0])
|
||||||
|
root = Circle(radius=0.25, color=BLUE, fill_opacity=0.2)
|
||||||
|
root.move_to(root_pos)
|
||||||
|
root_lbl = Text("root", font_size=10, color=GRAY).move_to(root)
|
||||||
|
root_grp = VGroup(root, root_lbl)
|
||||||
|
self.play(FadeIn(root_grp, scale=0.5), run_time=0.3)
|
||||||
|
|
||||||
|
# Labels accumulate on the right side
|
||||||
|
right_x = 3.5
|
||||||
|
label_y = 2.5
|
||||||
|
label_step = 0.5
|
||||||
|
|
||||||
|
def show_label(text, color):
|
||||||
|
nonlocal label_y
|
||||||
|
lbl = Text(text, font_size=14, color=color)
|
||||||
|
lbl.move_to([right_x, label_y, 0])
|
||||||
|
label_y -= label_step
|
||||||
|
self.play(Write(lbl))
|
||||||
|
return lbl
|
||||||
|
|
||||||
|
# ── R1: A → B → C ──
|
||||||
|
r1_lbl = show_label('R1: "A B C"', GREEN)
|
||||||
|
|
||||||
|
a_grp, a_edge, a_pos = self._add_node(root_pos, "A", GREEN, 0.6, -0.9)
|
||||||
|
self.play(Create(a_edge), FadeIn(a_grp, scale=0.5), run_time=0.2)
|
||||||
|
b_grp, b_edge, b_pos = self._add_node(a_pos, "B", GREEN, 0.6, -0.9)
|
||||||
|
self.play(Create(b_edge), FadeIn(b_grp, scale=0.5), run_time=0.2)
|
||||||
|
c_grp, c_edge, c_pos = self._add_node(b_pos, "C", GREEN, 0.6, -0.9)
|
||||||
|
self.play(Create(c_edge), FadeIn(c_grp, scale=0.5), run_time=0.2)
|
||||||
|
self.play(FadeIn(self._add_leaf(c_pos, GREEN, "slot 0"), scale=0.8), run_time=0.3)
|
||||||
|
self.wait(0.3)
|
||||||
|
|
||||||
|
# ── R2: shares A B, branches D E ──
|
||||||
|
r2_lbl = show_label('R2: "A B D E"', ORANGE)
|
||||||
|
|
||||||
|
for g in [a_grp, b_grp]:
|
||||||
|
flash = SurroundingRectangle(g, color=YELLOW, buff=0.12)
|
||||||
|
self.play(Create(flash), run_time=0.1)
|
||||||
|
self.play(FadeOut(flash), run_time=0.08)
|
||||||
|
|
||||||
|
d_grp, d_edge, d_pos = self._add_node(b_pos, "D", ORANGE, -0.6, -0.9)
|
||||||
|
self.play(Create(d_edge), FadeIn(d_grp, scale=0.5), run_time=0.2)
|
||||||
|
e_grp, e_edge, e_pos = self._add_node(d_pos, "E", ORANGE, -0.6, -0.9)
|
||||||
|
self.play(Create(e_edge), FadeIn(e_grp, scale=0.5), run_time=0.2)
|
||||||
|
self.play(FadeIn(self._add_leaf(e_pos, ORANGE, "slot 1"), scale=0.8), run_time=0.3)
|
||||||
|
self.wait(0.3)
|
||||||
|
|
||||||
|
# ── R3: shares A B, single F ──
|
||||||
|
r3_lbl = show_label('R3: "A B F"', PINK)
|
||||||
|
|
||||||
|
f_grp, f_edge, f_pos = self._add_node(b_pos, "F", PINK, 0.0, -1.2)
|
||||||
|
self.play(Create(f_edge), FadeIn(f_grp, scale=0.5), run_time=0.2)
|
||||||
|
self.play(FadeIn(self._add_leaf(f_pos, PINK, "slot 2"), scale=0.8), run_time=0.3)
|
||||||
|
self.wait(0.3)
|
||||||
|
|
||||||
|
# ── R4: new prefix from root ──
|
||||||
|
r4_lbl = show_label('R4: "X Y"', TEAL)
|
||||||
|
|
||||||
|
x_grp, x_edge, x_pos = self._add_node(root_pos, "X", TEAL, -1.0, -0.9)
|
||||||
|
self.play(Create(x_edge), FadeIn(x_grp, scale=0.5), run_time=0.2)
|
||||||
|
y_grp, y_edge, y_pos = self._add_node(x_pos, "Y", TEAL, -0.6, -0.9)
|
||||||
|
self.play(Create(y_edge), FadeIn(y_grp, scale=0.5), run_time=0.2)
|
||||||
|
self.play(FadeIn(self._add_leaf(y_pos, TEAL, "slot 3"), scale=0.8), run_time=0.3)
|
||||||
|
self.wait(0.5)
|
||||||
|
|
||||||
|
# ── highlight shared prefix (tree stays) ──
|
||||||
|
reuse_box = SurroundingRectangle(VGroup(a_grp, b_grp), color=YELLOW, buff=0.15)
|
||||||
|
reuse_note = Text(
|
||||||
|
'Prefix "A B" shared\nby 3 requests — 0 copy',
|
||||||
|
font_size=16,
|
||||||
|
color=YELLOW,
|
||||||
|
)
|
||||||
|
reuse_note.next_to(reuse_box, LEFT, buff=1.0)
|
||||||
|
self.play(Create(reuse_box), Write(reuse_note))
|
||||||
|
self.wait(2)
|
||||||
|
self.play(FadeOut(reuse_box), FadeOut(reuse_note))
|
||||||
|
|
||||||
|
# ── summary below tree (tree stays visible) ──
|
||||||
|
summary = VGroup(
|
||||||
|
Text("KV cache reuse across requests", font_size=26, color=GREEN),
|
||||||
|
Text("First-token latency: up to 50% reduction", font_size=18, color=GRAY),
|
||||||
|
).arrange(DOWN, buff=0.2)
|
||||||
|
summary.to_edge(DOWN, buff=0.5)
|
||||||
|
self.play(Write(summary))
|
||||||
|
self.wait(2)
|
||||||
|
self.play(FadeOut(summary), FadeOut(root_grp), FadeOut(title))
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""Render all promo scenes with Manim."""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
SCENES = [
|
||||||
|
("transformer.py", "Transformer"),
|
||||||
|
("architecture.py", "Architecture"),
|
||||||
|
("continuous_batching.py", "ContinuousBatching"),
|
||||||
|
("prefix_cache.py", "PrefixCache"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def render(file_name, scene_name, quality="-qh"):
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"manim",
|
||||||
|
f"promo/{file_name}",
|
||||||
|
scene_name,
|
||||||
|
quality,
|
||||||
|
"--media_dir",
|
||||||
|
"promo/output",
|
||||||
|
]
|
||||||
|
print(f"Rendering {scene_name}...")
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
|
print(f" Done → promo/output/{scene_name}.mp4")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
quality = "-qh" # 1080p; use -l for draft, -4k for ultra
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
quality = sys.argv[1]
|
||||||
|
for f, s in SCENES:
|
||||||
|
render(f, s, quality)
|
||||||
|
print("All scenes rendered.")
|
||||||
|
|
@ -0,0 +1,229 @@
|
||||||
|
"""AstrAI promo: Transformer GQA attention animation.
|
||||||
|
|
||||||
|
Shows the Grouped-Query Attention (GQA) mechanism with orthogonal data-flow lines:
|
||||||
|
Input → Q/K/V Projections → Repeat KV → SDPA → O Projection → Output
|
||||||
|
"""
|
||||||
|
|
||||||
|
from manim import *
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(Scene):
|
||||||
|
"""Animates the GQA attention mechanism with orthogonal connection lines."""
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
title = Text("Grouped-Query Attention (GQA)", font_size=42, color=BLUE)
|
||||||
|
title.to_edge(UP, buff=0.35)
|
||||||
|
self.play(Write(title))
|
||||||
|
|
||||||
|
# ── Helper: box ──
|
||||||
|
def mk(name, color, w=2.6, h=0.72, fs=10):
|
||||||
|
box = Rectangle(
|
||||||
|
width=w, height=h, color=color, fill_opacity=0.12, stroke_width=1.5
|
||||||
|
)
|
||||||
|
lbl = Text(name, font_size=fs, color=color)
|
||||||
|
return VGroup(box, lbl)
|
||||||
|
|
||||||
|
# ── Layout ──
|
||||||
|
inp = Text("x (hidden states)", font_size=15, color=GRAY)
|
||||||
|
inp.move_to(UP * 2.8)
|
||||||
|
|
||||||
|
y1 = 1.5
|
||||||
|
q_grp = mk("Q Projection\n1536 → 24×64", YELLOW)
|
||||||
|
k_grp = mk("K Projection\n1536 → 4×64", YELLOW)
|
||||||
|
v_grp = mk("V Projection\n1536 → 4×64", YELLOW)
|
||||||
|
q_grp.move_to(LEFT * 3.0 + UP * y1)
|
||||||
|
k_grp.move_to(UP * y1)
|
||||||
|
v_grp.move_to(RIGHT * 3.0 + UP * y1)
|
||||||
|
|
||||||
|
y2 = 0.0
|
||||||
|
repeat_grp = mk("Repeat KV\n4 heads → 24 heads", GREEN, 2.4, 0.68, 10)
|
||||||
|
repeat_grp.move_to(UP * y2)
|
||||||
|
|
||||||
|
y3 = -1.6
|
||||||
|
sdpa_grp = mk(
|
||||||
|
"Scaled Dot-Product\nAttention Q·Kᵀ/√d", BLUE, 2.8, 0.74, 10
|
||||||
|
)
|
||||||
|
sdpa_grp.move_to(UP * y3)
|
||||||
|
|
||||||
|
y4 = -3.0
|
||||||
|
o_grp = mk("O Projection\n1536 → 1536", PURPLE, 2.2, 0.68, 10)
|
||||||
|
o_grp.move_to(UP * y4)
|
||||||
|
|
||||||
|
out = Text("x' (hidden states)", font_size=15, color=GRAY)
|
||||||
|
out.next_to(o_grp, DOWN, buff=0.4)
|
||||||
|
|
||||||
|
# ── Animate boxes ──
|
||||||
|
self.play(Write(inp))
|
||||||
|
all_boxes = [q_grp, k_grp, v_grp, repeat_grp, sdpa_grp, o_grp]
|
||||||
|
for g in all_boxes:
|
||||||
|
self.play(FadeIn(g, shift=UP * 0.1), run_time=0.2)
|
||||||
|
|
||||||
|
# ── Input trunk → branch → Q/K/V (enter from directly above) ──
|
||||||
|
trunk_bottom = np.array([0, q_grp.get_top()[1] + 0.35, 0])
|
||||||
|
trunk = Line(inp.get_bottom(), trunk_bottom, color=GRAY, stroke_width=1.5)
|
||||||
|
self.play(Create(trunk), run_time=0.15)
|
||||||
|
|
||||||
|
branch_left = Line(
|
||||||
|
np.array([q_grp.get_top()[0], trunk_bottom[1], 0]),
|
||||||
|
np.array([k_grp.get_top()[0], trunk_bottom[1], 0]),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
branch_right = Line(
|
||||||
|
np.array([k_grp.get_top()[0], trunk_bottom[1], 0]),
|
||||||
|
np.array([v_grp.get_top()[0], trunk_bottom[1], 0]),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
self.play(Create(branch_left), Create(branch_right), run_time=0.2)
|
||||||
|
|
||||||
|
drop_q = Line(
|
||||||
|
np.array([q_grp.get_top()[0], trunk_bottom[1], 0]),
|
||||||
|
q_grp.get_top(),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
drop_k = Line(
|
||||||
|
np.array([k_grp.get_top()[0], trunk_bottom[1], 0]),
|
||||||
|
k_grp.get_top(),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
drop_v = Line(
|
||||||
|
np.array([v_grp.get_top()[0], trunk_bottom[1], 0]),
|
||||||
|
v_grp.get_top(),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
for ln in [drop_q, drop_k, drop_v]:
|
||||||
|
self.play(Create(ln), run_time=0.12)
|
||||||
|
|
||||||
|
input_lines = VGroup(trunk, branch_left, branch_right, drop_q, drop_k, drop_v)
|
||||||
|
|
||||||
|
# ── K/V → Repeat KV (trunk-branch, enter from above) ──
|
||||||
|
kv_junc_y = repeat_grp.get_top()[1] + 0.3
|
||||||
|
drop_k2 = Line(
|
||||||
|
k_grp.get_bottom(),
|
||||||
|
np.array([k_grp.get_bottom()[0], kv_junc_y, 0]),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
drop_v2 = Line(
|
||||||
|
v_grp.get_bottom(),
|
||||||
|
np.array([v_grp.get_bottom()[0], kv_junc_y, 0]),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
kv_branch = Line(
|
||||||
|
np.array([v_grp.get_bottom()[0], kv_junc_y, 0]),
|
||||||
|
np.array([k_grp.get_bottom()[0], kv_junc_y, 0]),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
kv_trunk = Line(
|
||||||
|
np.array([k_grp.get_bottom()[0], kv_junc_y, 0]),
|
||||||
|
repeat_grp.get_top(),
|
||||||
|
color=GRAY, stroke_width=1.5,
|
||||||
|
)
|
||||||
|
kv_lines = VGroup(drop_k2, drop_v2, kv_branch, kv_trunk)
|
||||||
|
self.play(Create(kv_lines), run_time=0.3)
|
||||||
|
|
||||||
|
# ── Q → SDPA (bypasses Repeat KV, from above) ──
|
||||||
|
qs_junc_y = sdpa_grp.get_top()[1] + 0.3
|
||||||
|
line_qs = VMobject(color=GRAY, stroke_width=1.5)
|
||||||
|
line_qs.set_points_as_corners([
|
||||||
|
q_grp.get_bottom(),
|
||||||
|
np.array([q_grp.get_bottom()[0], qs_junc_y, 0]),
|
||||||
|
np.array([sdpa_grp.get_top()[0], qs_junc_y, 0]),
|
||||||
|
sdpa_grp.get_top(),
|
||||||
|
])
|
||||||
|
self.play(Create(line_qs), run_time=0.15)
|
||||||
|
|
||||||
|
line_rs = orth_line(repeat_grp.get_bottom(), sdpa_grp.get_top(), GRAY)
|
||||||
|
self.play(Create(line_rs), run_time=0.15)
|
||||||
|
|
||||||
|
line_so = orth_line(sdpa_grp.get_bottom(), o_grp.get_top(), GRAY)
|
||||||
|
self.play(Create(line_so), run_time=0.15)
|
||||||
|
|
||||||
|
line_oo = orth_line(o_grp.get_bottom(), out.get_top(), GRAY)
|
||||||
|
self.play(Create(line_oo), run_time=0.15)
|
||||||
|
self.play(Write(out))
|
||||||
|
|
||||||
|
self.wait(0.4)
|
||||||
|
|
||||||
|
all_lines = VGroup(
|
||||||
|
input_lines, kv_lines, line_qs,
|
||||||
|
line_rs, line_so, line_oo,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── RoPE highlight ──
|
||||||
|
rope_q = SurroundingRectangle(q_grp, color=TEAL, buff=0.12)
|
||||||
|
rope_k = SurroundingRectangle(k_grp, color=TEAL, buff=0.12)
|
||||||
|
rope_t = Text(
|
||||||
|
"RoPE: rotary position encoding\napplied to Q and K",
|
||||||
|
font_size=13, color=TEAL,
|
||||||
|
)
|
||||||
|
rope_t.next_to(VGroup(rope_q, rope_k), UP, buff=0.25)
|
||||||
|
self.play(Create(rope_q), Create(rope_k), Write(rope_t))
|
||||||
|
self.wait(1.5)
|
||||||
|
self.play(FadeOut(rope_q), FadeOut(rope_k), FadeOut(rope_t))
|
||||||
|
|
||||||
|
# ── GQA ratio highlight ──
|
||||||
|
gqa_h = SurroundingRectangle(
|
||||||
|
VGroup(q_grp, k_grp, v_grp), color=YELLOW, buff=0.2
|
||||||
|
)
|
||||||
|
gqa_t = Text(
|
||||||
|
"GQA 6:1 — 24 Q-heads → 4 KV-heads\nKV cache reduced by 83%",
|
||||||
|
font_size=13, color=YELLOW,
|
||||||
|
)
|
||||||
|
gqa_t.next_to(gqa_h, RIGHT, buff=0.5)
|
||||||
|
self.play(Create(gqa_h), Write(gqa_t))
|
||||||
|
self.wait(1.8)
|
||||||
|
|
||||||
|
# ── Repeat KV highlight ──
|
||||||
|
kv_h = SurroundingRectangle(
|
||||||
|
VGroup(k_grp, v_grp), color=GREEN, buff=0.12
|
||||||
|
)
|
||||||
|
kv_t = Text(
|
||||||
|
"repeat_kv(): broadcast\n4 heads → 24 heads",
|
||||||
|
font_size=12, color=GREEN,
|
||||||
|
)
|
||||||
|
kv_t.next_to(kv_h, RIGHT, buff=0.5)
|
||||||
|
self.play(Create(kv_h), Write(kv_t))
|
||||||
|
self.wait(1.5)
|
||||||
|
|
||||||
|
# ── Fade all ──
|
||||||
|
self.play(
|
||||||
|
*[FadeOut(g) for g in all_boxes],
|
||||||
|
FadeOut(all_lines),
|
||||||
|
FadeOut(kv_h), FadeOut(kv_t),
|
||||||
|
FadeOut(gqa_h), FadeOut(gqa_t),
|
||||||
|
FadeOut(inp), FadeOut(out), FadeOut(title),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Specs card ──
|
||||||
|
st = Text("Model Specifications", font_size=36, color=BLUE)
|
||||||
|
st.to_edge(UP, buff=0.5)
|
||||||
|
rows_data = [
|
||||||
|
("Parameters", "~1.0B"),
|
||||||
|
("Layers", "24 × DecoderBlock"),
|
||||||
|
("Hidden Dim", "1536"),
|
||||||
|
("Q Heads / KV Heads", "24 / 4 (GQA, 6:1)"),
|
||||||
|
("Head Dim", "64"),
|
||||||
|
("FFN Dim", "4608 (SwiGLU)"),
|
||||||
|
("Max Length", "2048"),
|
||||||
|
("Precision", "bfloat16"),
|
||||||
|
]
|
||||||
|
table = VGroup()
|
||||||
|
for label, value in rows_data:
|
||||||
|
row = VGroup(
|
||||||
|
Text(label + ":", font_size=15, color=GRAY),
|
||||||
|
Text(value, font_size=15, color=WHITE),
|
||||||
|
).arrange(RIGHT, buff=0.4, aligned_edge=LEFT)
|
||||||
|
table.add(row)
|
||||||
|
table.arrange(DOWN, buff=0.1, aligned_edge=LEFT)
|
||||||
|
table.next_to(st, DOWN, buff=0.4)
|
||||||
|
self.play(Write(st), Write(table))
|
||||||
|
self.wait(2)
|
||||||
|
self.play(FadeOut(st), FadeOut(table))
|
||||||
|
|
||||||
|
|
||||||
|
def orth_line(start, end, color=GRAY):
|
||||||
|
"""Create an L-shaped orthogonal line from start to end."""
|
||||||
|
mid = np.array([start[0], end[1], 0])
|
||||||
|
path = VMobject(color=color, stroke_width=1.5)
|
||||||
|
path.set_points_as_corners([start, mid, end])
|
||||||
|
return path
|
||||||
Loading…
Reference in New Issue