7.6 KiB
A lightweight Transformer training & inference framework
📖 Table of Contents
English
Features
- 🚀 High Performance: Optimized for both training and inference with efficient parallelization.
- 🔧 Flexible: Support for seq/sft/dpo/grpo training, customizable model architectures.
- 💡 Easy to Use: Simple API with comprehensive examples and demos.
- 📦 Lightweight: Minimal dependencies, easy to deploy.
- 🔬 Research‑Friendly: Modular design, easy to experiment with new ideas.
- 🤗 HuggingFace-Style API: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
- 🔌 Dual API Compatibility: Supports both OpenAI and Anthropic chat completion APIs out of the box.
Getting Started
End-to-end walkthrough in 5 steps:
1. Install
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .
# pip install -e ".[dev]" # optional: dev dependencies (pytest, ruff)
2. Download model
python scripts/demo/download.py # downloads 1B checkpoint to params/
3. Preprocess data
Create pretrain.json (preprocessing config for seq strategy):
{
"version": 1,
"input": {"sections": [{"field": "text", "action": "train"}]},
"preprocessing": {"max_seq_len": 2048},
"output": {"storage_format": "bin"}
}
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c pretrain.json
4. Train
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
5. Serve & query
# Terminal 1: start server
python scripts/tools/server.py --param_path ./params --device cuda
# Terminal 2: query
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
Demo
Check out the demos in the scripts/demo/ folder:
# Download model weights (required before running demos)
python scripts/demo/download.py # model → params/
# Interactive streaming chat (multi-turn, maintains history)
python scripts/demo/stream_chat.py
# Type your message after >>, type !exit to quit
# Batch generation (5 hardcoded prompts, non-streaming)
python scripts/demo/generate_batch.py
# Single-prompt autoregressive streaming
python scripts/demo/generate_ar.py
All generation demos use temperature=0.8, top_p=0.95, top_k=50, max_tokens=2048 by default and require params/ to contain model weights (run download.py first).
Watch a video walkthrough on bilibili.
See Documentation for full references beyond the examples above.
Text Generation
Batch generation from a JSONL file:
python scripts/tools/generate.py \
--param_path ./params \
--input_json_file input.jsonl \
--output_json_file output.jsonl
Docker
Build and run with Docker (recommended for GPU environments):
# Build image
docker build -t astrai:latest .
# Run with GPU support
docker run --gpus all -it astrai:latest
# Run inference server
docker run --gpus all -p 8000:8000 astrai:latest \
python -m scripts.tools.server --port 8000 --device cuda
# Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker Compose (GPU, default)
docker compose up -d
# Docker Compose (CPU only)
docker compose --profile cpu up -d
Note:
--gpus allis required for CUDA support. Without it,torch.cuda.is_available()will returnFalse.
HTTP API Examples
Additional request examples beyond the Getting Started flow:
# OpenAI-compatible streaming
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"Tell a story"}],"stream":true,"max_tokens":500}'
# Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{"model":"astrai","system":"You are a helpful assistant.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
# Anthropic-compatible streaming with stop sequences
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{"model":"astrai","messages":[{"role":"user","content":"Write a story"}],"max_tokens":500,"stream":true,"stop_sequences":["The end"]}'
# Health check
curl http://localhost:8000/health
See Inference Guide for SSE streaming format, error codes, and stats endpoint.
Documentation
| Document | Description |
|---|---|
| CLI Reference | Parameters for all CLI tools (train, server, generate, preprocess) |
| Architecture | System architecture, class diagram & design patterns |
| Training | Training loop, strategies & formulas |
| Inference | KVCache, continuous batching, sampling & HTTP API |
| Data Flow | Data pipeline, storage backends & dataset architecture |
| Preprocessing | Declarative JSON-driven data preprocessing |
Contributing
We welcome contributions! Please see our Contributing Guidelines for details.
- Fork the repository.
- Create a feature branch.
- Commit your changes.
- Open a Pull Request.
For major changes, please open an issue first to discuss what you would like to change.
Community
- GitHub Issues: Issue Tracker
- Discussions: GitHub Discussions
- HuggingFace: Model Hub
License
This project is licensed under the GPL-3.0 License.