Logo
English中文

轻量级 Transformer 训练与推理框架

python license release stars forks

English中文问题追踪讨论区HuggingFace

## 📖 目录 - [特性](#特性) - [快速开始](#快速开始) - [文档](#文档) - [贡献](#贡献) - [社区](#社区) - [许可证](#许可证) --- ## 中文 ### 特性 - 🚀 **高性能**: 训练与推理双向优化,高效并行。 - 🔧 **灵活**: 支持 seq/sft/dpo/grpo 多种训练方式,可定制模型架构。 - 💡 **易用**: 简洁的 API 与丰富的示例、演示。 - 📦 **轻量**: 依赖少,部署简单。 - 🔬 **研究友好**: 模块化设计,便于实验新想法。 - 🤗 **HuggingFace 集成**: 兼容 HuggingFace 模型与数据集。 - 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。 ### 快速开始 #### 安装 ```bash git clone https://github.com/ViperEkura/AstrAI.git cd AstrAI pip install -e . ``` 安装开发依赖: ```bash pip install -e ".[dev]" ``` #### 训练模型 ```bash python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model ``` | 参数 | 说明 | 默认值 | |------|------|--------| | `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo`) | 必填 | | `--data_root_path` | 数据集根目录 | 必填 | | `--param_path` | 模型参数或断点路径 | 必填 | | `--n_epoch` | 训练轮数 | 1 | | `--batch_size` | 批次大小 | 1 | | `--accumulation_steps` | 梯度累积步数 | 1 | | `--warmup_steps` | 预热步数 | 1000 | | `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 | | `--max_grad_norm` | 梯度裁剪最大值 | 1.0 | | `--adamw_beta1` | AdamW beta1 | 0.9 | | `--adamw_beta2` | AdamW beta2 | 0.95 | | `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 | | `--random_seed` | 随机种子 | 3407 | | `--num_workers` | 数据加载线程数 | 4 | | `--window_size` | 最大输入序列长度 | auto | | `--stride` | 序列步长 | auto | | `--label_smoothing` | 交叉熵标签平滑 | 0.1 | | `--dpo_beta` | DPO beta | 0.1 | | `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 | | `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 | | `--group_size` | GRPO 组大小 | 4 | | `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 | | `--ckpt_interval` | 检查点间隔(迭代步) | 5000 | | `--ckpt_dir` | 检查点保存目录 | checkpoint | | `--start_epoch` | 起始轮次(用于断点续训) | 0 | | `--start_batch` | 起始批次(用于断点续训) | 0 | | `--nprocs` | GPU 数量 | 1 | | `--device_type` | 设备类型 | cuda | 完整参数列表见[参数说明](./params.md#training-parameters)。 #### 文本生成 ```bash python scripts/tools/generate.py --param_path=/path/to/param_path ``` #### Docker 使用 Docker 构建和运行(推荐用于 GPU 环境): ```bash # 构建镜像 docker build -t astrai:latest . # 启用 GPU 运行 docker run --gpus all -it astrai:latest # 指定特定 GPU docker run --gpus '"device=0,1"' -it astrai:latest # 运行推理服务 docker run --gpus all -p 8000:8000 astrai:latest \ python -m scripts.tools.server --port 8000 --device cuda # 挂载数据卷 docker run --gpus all -v /path/to/data:/data -it astrai:latest # Docker Compose(GPU,默认) docker compose up -d # Docker Compose(仅 CPU) docker compose --profile cpu up -d ``` > **注意**: 必须使用 `--gpus all` 才能启用 CUDA 支持,否则 `torch.cuda.is_available()` 将返回 `False`。 #### 启动 HTTP 服务 启动推理服务器,支持 OpenAI 和 Anthropic 兼容的 HTTP API: ```bash python -m scripts.tools.server --port 8000 --device cuda ``` 发起请求: ```bash # OpenAI 兼容 curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "messages": [{"role": "user", "content": "你好"}], "max_tokens": 512 }' # OpenAI 兼容流式 curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "messages": [{"role": "user", "content": "讲个故事"}], "stream": true, "max_tokens": 500 }' # Anthropic 兼容 curl -X POST http://localhost:8000/v1/messages \ -H "Content-Type: application/json" \ -d '{ "model": "astrai", "system": "你是一个乐于助人的助手。", "messages": [{"role": "user", "content": "你好"}], "max_tokens": 512 }' # Anthropic 兼容流式并设置停止序列 curl -X POST http://localhost:8000/v1/messages \ -H "Content-Type: application/json" \ -d '{ "model": "astrai", "messages": [{"role": "user", "content": "写个故事"}], "max_tokens": 500, "stream": true, "stop_sequences": ["结束"] }' # 健康检查 curl http://localhost:8000/health ``` #### 演示 查看 `scripts/demo/` 文件夹中的演示: ```bash # 下载预处理数据(运行演示前必需) python scripts/demo/download.py # 交互式流式聊天 python scripts/demo/stream_chat.py # 批量生成 python scripts/demo/generate_batch.py # 自回归生成 python scripts/demo/generate_ar.py ``` 观看 [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) 上的视频演示。 ### 文档 | 文档 | 说明 | |------|------| | [参数说明](./params.md) | 训练与推理参数配置 | | [设计文档](./design.md) | 系统架构与模块设计 | | [数据流程](./dataflow.md) | 数据处理管道详解 | | [模型介绍](./introduction.md) | 模型架构与技术细节 | ### 贡献 我们欢迎贡献!请参阅[贡献指南](../../CONTRIBUTING.md)了解详情。 1. Fork 本仓库。 2. 创建功能分支。 3. 提交更改。 4. 发起 Pull Request。 重大更改请先开 issue 讨论。 ### 社区 - **GitHub Issues**: [问题追踪](https://github.com/ViperEkura/AstrAI/issues) - **Discussions**: [GitHub 讨论区](https://github.com/ViperEkura/AstrAI/discussions) - **HuggingFace**: [模型中心](https://huggingface.co/ViperEk) ### 许可证 本项目采用 [GPL-3.0 许可证](../../LICENSE)。 ---
专为高性能与易用性设计的轻量级 Transformer 框架。