- train.py: add --batch_size and --grpo_clip_eps; remove 3 duplicate --group_size definitions
- generate.py: rename --model_dir to --param_path to match the README
- automodel.py: add a strict parameter to from_pretrained (default True)
- parallel/setup.py: fix an out-of-range device_ids index
- train_callback.py: move scheduler.step() to on_step_end
- test_train_strategy.py: add the missing optimizer.step() call in the tests
- engine.py: non-streaming mode now loops until all tasks complete; add remove_task cleanup
- scheduler.py: add a _pages_freed flag to Task to prevent double release
- trainer.py: clamp accumulation_steps to 1 when it is 0
- tokenizer.py: add a _tokenizer is None check to save_pretrained
- benchmark.py: fix the outdated ModelConfig import path
- inference/__init__.py: fix a stale docstring
A lightweight Transformer training & inference framework
Features
- 🚀 High Performance: Optimized for both training and inference with efficient parallelization.
- 🔧 Flexible: Supports seq, sft, dpo, and grpo training with customizable model architectures.
- 💡 Easy to Use: Simple API with comprehensive examples and demos.
- 📦 Lightweight: Minimal dependencies, easy to deploy.
- 🔬 Research‑Friendly: Modular design, easy to experiment with new ideas.
- 🤗 HuggingFace Integration: Compatible with HuggingFace models and datasets.
- 🔌 Dual API Compatibility: Supports both OpenAI and Anthropic chat completion APIs out of the box.
Quick Start
Installation
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .
For development dependencies:
pip install -e ".[dev]"
Train a Model
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
| Parameter | Description | Default |
|---|---|---|
| --train_type | Training type (seq, sft, dpo, grpo) | required |
| --data_root_path | Dataset root directory | required |
| --param_path | Model / checkpoint path | required |
| --n_epoch | Training epochs | 1 |
| --batch_size | Batch size | 1 |
| --accumulation_steps | Gradient accumulation steps | 1 |
| --warmup_steps | LR warmup steps | 1000 |
| --max_lr | Peak learning rate (cosine decay) | 3e-4 |
| --max_grad_norm | Max gradient norm for clipping | 1.0 |
| --adamw_beta1 | AdamW beta1 | 0.9 |
| --adamw_beta2 | AdamW beta2 | 0.95 |
| --adamw_weight_decay | AdamW weight decay | 0.01 |
| --random_seed | Random seed | 3407 |
| --num_workers | DataLoader workers | 4 |
| --window_size | Max input sequence length | auto |
| --stride | Sequence stride | auto |
| --label_smoothing | Label smoothing for cross entropy | 0.1 |
| --dpo_beta | DPO beta | 0.1 |
| --grpo_clip_eps | GRPO clip epsilon | 0.2 |
| --grpo_kl_coef | GRPO KL penalty coefficient | 0.01 |
| --group_size | GRPO group size | 4 |
| --grpo_sync_interval | GRPO ref model sync interval (steps) | 200 |
| --ckpt_interval | Checkpoint interval (iters) | 5000 |
| --ckpt_dir | Checkpoint directory | checkpoint |
| --start_epoch | Start epoch (for resume) | 0 |
| --start_batch | Start batch (for resume) | 0 |
| --nprocs | Number of GPUs | 1 |
| --device_type | Device type | cuda |
See the Parameter Guide for the full reference.
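To build intuition for how --warmup_steps and --max_lr interact, here is a minimal sketch of a linear-warmup plus cosine-decay schedule. It is an illustration only: the function name lr_at_step, the total_steps argument, and the min_lr floor are assumptions made for this example, and the exact schedule AstrAI implements may differ.

```python
import math


def lr_at_step(step, total_steps, max_lr=3e-4, warmup_steps=1000, min_lr=0.0):
    """Learning rate at a given optimizer step (hypothetical helper).

    Assumed shape: linear warmup up to max_lr over warmup_steps,
    followed by cosine decay down to min_lr at total_steps.
    """
    if step < warmup_steps:
        # Linear ramp: reaches max_lr on the last warmup step.
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    # Cosine factor goes from 1 (start of decay) to 0 (end of decay).
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine
```

With the defaults from the table, the rate climbs for the first 1000 steps, peaks at 3e-4, and then decays smoothly for the rest of training.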
Generate Text
python scripts/tools/generate.py --param_path=/path/to/param_path
Docker
Build and run with Docker (recommended for GPU environments):
# Build image
docker build -t astrai:latest .
# Run with GPU support
docker run --gpus all -it astrai:latest
# Run with specific GPUs
docker run --gpus '"device=0,1"' -it astrai:latest
# Run inference server
docker run --gpus all -p 8000:8000 astrai:latest \
python -m scripts.tools.server --port 8000 --device cuda
# Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker Compose (GPU, default)
docker compose up -d
# Docker Compose (CPU only)
docker compose --profile cpu up -d
Note: --gpus all is required for CUDA support. Without it, torch.cuda.is_available() will return False.
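A quick way to confirm that the container actually sees a GPU is to check from Python before launching anything. The sketch below uses a hypothetical helper, pick_device, that falls back to "cpu" when PyTorch is missing or CUDA is unavailable. Whether every AstrAI script accepts "cpu" as a device value is an assumption here.

```python
def pick_device(preferred="cuda"):
    """Return "cuda" only if PyTorch can actually use it, else "cpu".

    Hypothetical helper for illustration; not part of AstrAI.
    """
    try:
        import torch  # imported lazily so the check also works without PyTorch
    except ImportError:
        return "cpu"
    if preferred == "cuda" and torch.cuda.is_available():
        return "cuda"
    return "cpu"


print(pick_device())
```

If this prints cpu inside a container that should have GPU access, the --gpus flag is the first thing to check.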
Start HTTP Server
Start the inference server, which exposes OpenAI- and Anthropic-compatible HTTP APIs:
python -m scripts.tools.server --port 8000 --device cuda
Make requests:
# OpenAI-compatible
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
# OpenAI-compatible streaming
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Tell a story"}],
"stream": true,
"max_tokens": 500
}'
# Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
# Anthropic-compatible streaming with stop sequences
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "Write a story"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["The end"]
}'
# Health check
curl http://localhost:8000/health
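The same requests can be issued from Python using only the standard library. The sketch below mirrors the curl examples above. The helper names are hypothetical, and the streaming parser assumes the server emits server-sent events as `data: <json>` lines, with OpenAI-style streams ending in `data: [DONE]`. That framing is an assumption based on the APIs the server is said to be compatible with, so verify it against the actual responses.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000"  # same host and port as the curl examples


def _post(path, payload, base_url=BASE_URL):
    """Build (but do not send) a JSON POST request."""
    return urllib.request.Request(
        base_url + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def openai_chat_request(prompt, max_tokens=512, stream=False):
    """Request for the OpenAI-compatible /v1/chat/completions endpoint."""
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    if stream:
        payload["stream"] = True
    return _post("/v1/chat/completions", payload)


def anthropic_message_request(prompt, max_tokens=512, system=None,
                              stop_sequences=None, stream=False):
    """Request for the Anthropic-compatible /v1/messages endpoint."""
    payload = {
        "model": "astrai",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    if system is not None:
        payload["system"] = system
    if stop_sequences:
        payload["stop_sequences"] = list(stop_sequences)
    if stream:
        payload["stream"] = True
    return _post("/v1/messages", payload)


def iter_sse_data(lines):
    """Yield decoded JSON payloads from 'data: ...' lines (assumed SSE framing)."""
    for raw in lines:
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        line = line.strip()
        if not line.startswith("data:"):
            continue  # skip blank lines, comments, and 'event:' lines
        data = line[len("data:"):].strip()
        if data == "[DONE]":  # assumed OpenAI-style end-of-stream marker
            break
        yield json.loads(data)


def send(request):
    """Send a non-streaming request and return the decoded JSON body."""
    with urllib.request.urlopen(request) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the server running, `send(openai_chat_request("Hello"))` returns the parsed response. For streaming, open the request with `urllib.request.urlopen` and pass the response object to `iter_sse_data`, which reads it line by line.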
Demo
Check out the demos in the scripts/demo/ folder:
# Download pre‑processed data (required before running demos)
python scripts/demo/download.py
# Interactive streaming chat
python scripts/demo/stream_chat.py
# Batch generation
python scripts/demo/generate_batch.py
# Auto‑regressive generation
python scripts/demo/generate_ar.py
Watch a video walkthrough on bilibili.
Documentation
| Document | Description |
|---|---|
| Parameter Guide | Training & inference parameters |
| Design Document | Framework architecture & module design |
| Data Flow | Data processing pipeline details |
| Model Introduction | Model architecture & technical details |
Contributing
We welcome contributions! Please see our Contributing Guidelines for details.
- Fork the repository.
- Create a feature branch.
- Commit your changes.
- Open a Pull Request.
For major changes, please open an issue first to discuss what you would like to change.
Community
- GitHub Issues: Issue Tracker
- Discussions: GitHub Discussions
- HuggingFace: Model Hub
License
This project is licensed under the GPL-3.0 License.