7.2 KiB
7.2 KiB
📖 目录
中文
特性
- 🚀 高性能: 训练与推理双向优化,高效并行。
- 🔧 灵活: 支持 seq/sft/dpo/grpo 多种训练方式,可定制模型架构。
- 💡 易用: 简洁的 API 与丰富的示例、演示。
- 📦 轻量: 依赖少,部署简单。
- 🔬 研究友好: 模块化设计,便于实验新想法。
- 🤗 HuggingFace 风格 API: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。
- 🔌 双 API 兼容: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。
快速上手
端到端演示,只需 5 步:
1. 安装
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .
# pip install -e ".[dev]" # 可选:开发依赖(pytest, ruff)
2. 下载模型
python scripts/demo/download.py # 下载 1B 检查点到 params/
3. 预处理数据
创建 pretrain.json(seq 策略的预处理配置):
{
"version": 1,
"input": {"sections": [{"field": "text", "action": "train"}]},
"preprocessing": {"max_seq_len": 2048},
"output": {"storage_format": "bin"}
}
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c pretrain.json
4. 训练
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
5. 启动服务并调用
# 终端 1:启动服务
python scripts/tools/server.py --param_path ./params --device cuda
# 终端 2:发起请求
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"你好"}],"max_tokens":512}'
演示
查看 scripts/demo/ 文件夹中的演示:
# 下载模型权重(运行演示前必需)
python scripts/demo/download.py # model → params/
# 交互式流式聊天(多轮对话,保持历史记录)
python scripts/demo/stream_chat.py
# 在 >> 后输入消息,输入 !exit 退出
# 批量生成(5 条硬编码提示词,非流式)
python scripts/demo/generate_batch.py
# 单条提示词自回归流式生成
python scripts/demo/generate_ar.py
所有生成演示默认使用 temperature=0.8、top_p=0.95、top_k=50、max_tokens=2048,需要 params/ 目录包含模型权重(请先运行 download.py)。
观看 bilibili 上的视频演示。
更多选项请参考文档。
文本生成
从 JSONL 文件批量生成:
python scripts/tools/generate.py \
--param_path ./params \
--input_json_file input.jsonl \
--output_json_file output.jsonl
Docker
使用 Docker 构建和运行(推荐用于 GPU 环境):
# 构建镜像
docker build -t astrai:latest .
# 启用 GPU 运行
docker run --gpus all -it astrai:latest
# 运行推理服务
docker run --gpus all -p 8000:8000 astrai:latest \
python -m scripts.tools.server --port 8000 --device cuda
# 挂载数据卷
docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker Compose(GPU,默认)
docker compose up -d
# Docker Compose(仅 CPU)
docker compose --profile cpu up -d
注意: 必须使用
--gpus all才能启用 CUDA 支持,否则torch.cuda.is_available()将返回False。
HTTP API 示例
除快速上手流程外,更多请求示例:
# OpenAI 兼容流式
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"讲个故事"}],"stream":true,"max_tokens":500}'
# Anthropic 兼容
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{"model":"astrai","system":"你是一个乐于助人的助手。","messages":[{"role":"user","content":"你好"}],"max_tokens":512}'
# Anthropic 兼容流式并设置停止序列
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{"model":"astrai","messages":[{"role":"user","content":"写个故事"}],"max_tokens":500,"stream":true,"stop_sequences":["结束"]}'
# 健康检查
curl http://localhost:8000/health
SSE 流式格式、错误码和统计端点详见推理文档。
文档
| 文档 | 说明 |
|---|---|
| CLI 参考 | 所有 CLI 工具参数(训练、服务、生成、预处理) |
| 架构文档 | 系统架构、类图与设计模式 |
| 训练文档 | 训练循环、策略与公式 |
| 推理文档 | KVCache、连续批处理、采样与 HTTP API |
| 数据流程 | 数据管道、存储后端与数据集架构 |
| 数据预处理 | 声明式 JSON 驱动数据预处理 |
贡献
我们欢迎贡献!请参阅贡献指南了解详情。
- Fork 本仓库。
- 创建功能分支。
- 提交更改。
- 发起 Pull Request。
重大更改请先开 issue 讨论。
社区
- GitHub Issues: 问题追踪
- Discussions: GitHub 讨论区
- HuggingFace: 模型中心
许可证
本项目采用 GPL-3.0 许可证。
专为高性能与易用性设计的轻量级 Transformer 框架。