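- refactor: Replace fixed slots with a paged KV cache (PagedCache + CacheView); remove PrefixCache
- refactor: Rewrite the inference engine control logic; fix core continuous-batching defects and thread-safety issues
- refactor: Move KV cache slot handling down into the attention layer; remove _remap_kv / _writeback_kv
- refactor: Unify the sampling path on SamplingPipeline batch tensors; remove apply_sampling_strategies
- refactor: Reorganize the inference module's import structure using design patterns (cache and sampling are now standalone)
- feat: Prefix caching in the inference engine (KV cache reuse)
- feat: OpenAI-compatible chat completion API (streaming, non-streaming, and usage)
- feat: Anthropic-compatible /v1/messages API; remove the legacy /generate endpoint
- feat: GRPO CLI integration + on-policy; parameterize top_k in the OpenAI API
- feat: Checkpoint supports generic extra extension data
- feat: One-command deployment with Docker Compose (GPU and CPU modes)
- feat: Additional GRPO training parameters; batch-processing training parameter table
- fix: Scheduler latency optimization: remove the 5 ms sleep; fix lost refill tasks
- fix: Missing/duplicate CLI arguments, out-of-range device_ids, inconsistent generate parameter names
- fix: Wrong truncation direction for long conversations; keep the newest tokens rather than the oldest
- fix: remove_task did not release the KV cache slot, causing a deadlock in the second conversation round
- fix: KV cache slot index misalignment, missing version check, attention mask
- fix: Scheduler out-of-bounds bug; correct the SchedulerCallback callback stage
- perf: _Result now uses Condition.wait_for to eliminate CPU busy-waiting in non-streaming mode
- perf: Preallocate tensors for each decode step; build input_ids in one pass instead of element-by-element assignment
- refactor: Remove the device_ids parameter; rely on CUDA_VISIBLE_DEVICES instead
- docs: Update documentation to match the paged KV cache and other refactors
- docs: Fix multiple documentation errors; add training parameter descriptions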
||
|---|---|---|
| .github | ||
| assets | ||
| astrai | ||
| scripts | ||
| tests | ||
| .dockerignore | ||
| .gitattributes | ||
| .gitignore | ||
| CONTRIBUTING.md | ||
| Dockerfile | ||
| LICENSE | ||
| README.md | ||
| docker-compose.yml | ||
| pyproject.toml | ||
README.md
A lightweight Transformer training & inference framework
📖 Table of Contents
English
Features
- 🚀 High Performance: Optimized for both training and inference with efficient parallelization.
- 🔧 Flexible: Supports seq/sft/dpo/grpo training and customizable model architectures.
- 💡 Easy to Use: Simple API with comprehensive examples and demos.
- 📦 Lightweight: Minimal dependencies, easy to deploy.
- 🔬 Research‑Friendly: Modular design, easy to experiment with new ideas.
- 🤗 HuggingFace-Style API: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
- 🔌 Dual API Compatibility: Supports both OpenAI and Anthropic chat completion APIs out of the box.
Quick Start
Installation
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .
For development dependencies:
pip install -e ".[dev]"
Train a Model
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
    --train_type seq \
    --data_root_path /path/to/dataset \
    --param_path /path/to/model \
    --batch_size 4 \
    --accumulation_steps 8 \
    --max_lr 3e-4 \
    --warmup_steps 1000 \
    --n_epoch 1
See the Parameter Guide for the full list of training parameters.
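The example command above combines per-step batching with gradient accumulation on four visible GPUs. As a rough sketch of how these flags interact, the number of samples behind each optimizer update can be estimated as shown below. This assumes that `--batch_size` is counted per process and that one data-parallel process runs per visible GPU. Neither assumption is confirmed here, so check the Parameter Guide. The helper name `effective_batch_size` is hypothetical and not part of AstrAI.

```python
def effective_batch_size(batch_size: int, accumulation_steps: int, num_devices: int = 1) -> int:
    """Estimate how many samples contribute to one optimizer step.

    Assumption: batch_size is per device and gradients are averaged
    across devices (plain data parallelism).
    """
    if min(batch_size, accumulation_steps, num_devices) < 1:
        raise ValueError("all arguments must be positive integers")
    return batch_size * accumulation_steps * num_devices


# Values from the example command: 4 GPUs, --batch_size 4, --accumulation_steps 8
print(effective_batch_size(4, 8, num_devices=4))  # 128
```

If the learning rate was tuned for a different effective batch size, this estimate is a useful starting point when adjusting `--max_lr`.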
Generate Text
python scripts/tools/generate.py \
    --param_path /path/to/model \
    --input_json_file /path/to/input.json \
    --output_json_file /path/to/output.json
Docker
Build and run with Docker (recommended for GPU environments):
# Build image
docker build -t astrai:latest .
# Run with GPU support
docker run --gpus all -it astrai:latest
# Run with specific GPUs
docker run --gpus '"device=0,1"' -it astrai:latest
# Run inference server
docker run --gpus all -p 8000:8000 astrai:latest \
    python -m scripts.tools.server --port 8000 --device cuda
# Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker Compose (GPU, default)
docker compose up -d
# Docker Compose (CPU only)
docker compose --profile cpu up -d
Note: --gpus all is required for CUDA support. Without it, torch.cuda.is_available() will return False.
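To confirm from inside the container whether a GPU is actually visible, a small check along these lines may help. It is a minimal sketch: `pick_device` is a hypothetical helper, not part of AstrAI, and it falls back to CPU when PyTorch cannot be imported.

```python
def pick_device() -> str:
    """Return "cuda" when PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch  # imported lazily so the check also runs without PyTorch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(pick_device())
```

If this prints cpu in a container that should have GPU access, the --gpus flag is the first thing to check.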
Start HTTP Server
Start the inference server, which exposes OpenAI-compatible and Anthropic-compatible HTTP APIs:
python -m scripts.tools.server --port 8000 --device cuda
Make requests:
# OpenAI-compatible
curl -X POST http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 512
    }'
# OpenAI-compatible streaming
curl -X POST http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "messages": [{"role": "user", "content": "Tell a story"}],
        "stream": true,
        "max_tokens": 500
    }'
# Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
    -H "Content-Type: application/json" \
    -d '{
        "model": "astrai",
        "system": "You are a helpful assistant.",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 512
    }'
# Anthropic-compatible streaming with stop sequences
curl -X POST http://localhost:8000/v1/messages \
    -H "Content-Type: application/json" \
    -d '{
        "model": "astrai",
        "messages": [{"role": "user", "content": "Write a story"}],
        "max_tokens": 500,
        "stream": true,
        "stop_sequences": ["The end"]
    }'
# Health check
curl http://localhost:8000/health
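The OpenAI-compatible requests above can also be issued from Python using only the standard library. The sketch below mirrors the curl request bodies. The streaming parser assumes the server emits OpenAI-style server-sent events (`data: {json}` lines carrying `choices[0].delta.content`, ended by `data: [DONE]`), which this README does not spell out. All function names are hypothetical and not part of AstrAI.

```python
import json
import urllib.request
from typing import Iterable, Iterator

BASE_URL = "http://localhost:8000"  # matches the --port 8000 example above


def build_chat_request(prompt: str, max_tokens: int = 512, stream: bool = False) -> urllib.request.Request:
    """Build a POST request for /v1/chat/completions (same body as the curl example)."""
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    if stream:
        payload["stream"] = True
    return urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def iter_stream_text(lines: Iterable[bytes]) -> Iterator[str]:
    """Yield text deltas from server-sent-event lines.

    Assumption: chunks follow OpenAI's streaming format.
    """
    for raw in lines:
        line = raw.decode("utf-8").strip()
        if not line.startswith("data:"):
            continue  # skip blank separator lines and other SSE fields
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        choices = json.loads(data).get("choices") or []
        if not choices:
            continue
        text = choices[0].get("delta", {}).get("content")
        if text:
            yield text


def chat(prompt: str) -> dict:
    """Send a non-streaming request and return the decoded JSON response."""
    with urllib.request.urlopen(build_chat_request(prompt)) as resp:
        return json.load(resp)


def chat_stream(prompt: str) -> Iterator[str]:
    """Send a streaming request and yield text pieces as they arrive."""
    with urllib.request.urlopen(build_chat_request(prompt, stream=True)) as resp:
        yield from iter_stream_text(resp)
```

With the server running, `for piece in chat_stream("Tell a story"): print(piece, end="", flush=True)` would print the reply incrementally. Keeping the request builder and the line parser separate from the network calls makes both easy to test without a live server.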
Demo
Check out the demos in the scripts/demo/ folder:
# Download pre‑processed data (required before running demos)
python scripts/demo/download.py
# Interactive streaming chat
python scripts/demo/stream_chat.py
# Batch generation
python scripts/demo/generate_batch.py
# Auto‑regressive generation
python scripts/demo/generate_ar.py
Watch a video walkthrough on bilibili.
Documentation
| Document | Description |
|---|---|
| Parameter Guide | Training & inference parameters |
| Design Document | Framework architecture & module design |
| Data Flow | Data processing pipeline details |
| Model Introduction | Model architecture & technical details |
Contributing
We welcome contributions! Please see our Contributing Guidelines for details.
- Fork the repository.
- Create a feature branch.
- Commit your changes.
- Open a Pull Request.
For major changes, please open an issue first to discuss what you would like to change.
Community
- GitHub Issues: Issue Tracker
- Discussions: GitHub Discussions
- HuggingFace: Model Hub
License
This project is licensed under the GPL-3.0 License.
A lightweight Transformer framework designed for both high performance and ease of use.