
A lightweight Transformer training & inference framework



## 📖 Table of Contents

- [Features](#features)
- [Quick Start](#quick-start)
- [Documentation](#documentation)
- [Contributing](#contributing)
- [Community](#community)
- [License](#license)

---

## English

### Features

- 🚀 **High Performance**: Optimized for both training and inference with efficient parallelization.
- 🔧 **Flexible**: Supports seq/sft/dpo/grpo training and customizable model architectures.
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
- 🔬 **Research-Friendly**: Modular design that makes it easy to experiment with new ideas.
- 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets.
- 🔌 **Dual API Compatibility**: The inference server exposes both OpenAI-compatible (`/v1/chat/completions`) and Anthropic-compatible (`/v1/messages`) chat endpoints out of the box.

### Quick Start

#### Installation

```bash
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .
```

To also install the development dependencies:

```bash
pip install -e ".[dev]"
```

#### Train a Model

```bash
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
```

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 4 |
| `--window_size` | Max input sequence length | auto |
| `--stride` | Sequence stride | auto |
| `--label_smoothing` | Label smoothing for cross-entropy | 0.1 |
| `--dpo_beta` | DPO beta | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
| `--group_size` | GRPO group size | 4 |
| `--grpo_sync_interval` | GRPO reference-model sync interval (steps) | 200 |
| `--ckpt_interval` | Checkpoint interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--start_epoch` | Start epoch (for resuming) | 0 |
| `--start_batch` | Start batch (for resuming) | 0 |
| `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type | cuda |

See the [Parameter Guide](./assets/docs/params.md#training-parameters) for the full reference.

#### Generate Text

```bash
python scripts/tools/generate.py --param_path=/path/to/model
```

#### Docker

Build and run with Docker (recommended for GPU environments):

```bash
# Build image
docker build -t astrai:latest .

# Run with GPU support
docker run --gpus all -it astrai:latest

# Run with specific GPUs
docker run --gpus '"device=0,1"' -it astrai:latest

# Run inference server
docker run --gpus all -p 8000:8000 astrai:latest \
  python -m scripts.tools.server --port 8000 --device cuda

# Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest

# Docker Compose (GPU, default)
docker compose up -d

# Docker Compose (CPU only)
docker compose --profile cpu up -d
```

> **Note**: A `--gpus` flag (e.g. `--gpus all`) is required for CUDA support. Without it, `torch.cuda.is_available()` returns `False` inside the container.
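
Several options in the *Train a Model* table interact: `--warmup_steps` and `--max_lr` shape the learning-rate curve, while `--batch_size`, `--accumulation_steps`, and `--nprocs` together set how many sequences feed each optimizer step. The sketch below is illustrative only and is not AstrAI's implementation. It assumes linear warmup from 0 to `max_lr`, cosine decay to 0 over a hypothetical `total_steps`, and that `--batch_size` is per GPU.

```python
import math


def lr_at(step: int, max_lr: float = 3e-4, warmup_steps: int = 1000,
          total_steps: int = 100_000) -> float:
    """Hypothetical warmup + cosine schedule (assumed shape, not AstrAI's code).

    Linear warmup from 0 to max_lr over warmup_steps, then cosine decay
    from max_lr down to 0 at total_steps (total_steps is an assumption).
    """
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * progress))


def effective_batch(batch_size: int = 1, accumulation_steps: int = 1,
                    nprocs: int = 1) -> int:
    """Sequences per optimizer step, assuming batch_size is per GPU."""
    return batch_size * accumulation_steps * nprocs
```

For example, `effective_batch(batch_size=4, accumulation_steps=8, nprocs=2)` gives 64 sequences per optimizer step; multiplying by `--window_size` gives an upper bound on tokens per step.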
#### Start HTTP Server

Start the inference server, which exposes OpenAI-compatible and Anthropic-compatible HTTP APIs:

```bash
python -m scripts.tools.server --port 8000 --device cuda
```

Example requests:

```bash
# OpenAI-compatible
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 512
  }'

# OpenAI-compatible streaming
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "Tell a story"}],
    "stream": true,
    "max_tokens": 500
  }'

# Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "astrai",
    "system": "You are a helpful assistant.",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 512
  }'

# Anthropic-compatible streaming with stop sequences
curl -X POST http://localhost:8000/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "astrai",
    "messages": [{"role": "user", "content": "Write a story"}],
    "max_tokens": 500,
    "stream": true,
    "stop_sequences": ["The end"]
  }'

# Health check
curl http://localhost:8000/health
```

#### Demo

Check out the demos in the `scripts/demo/` folder:

```bash
# Download pre-processed data (required before running demos)
python scripts/demo/download.py

# Interactive streaming chat
python scripts/demo/stream_chat.py

# Batch generation
python scripts/demo/generate_batch.py

# Autoregressive generation
python scripts/demo/generate_ar.py
```

Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd).
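
For scripted access to the HTTP server, a minimal standard-library client might look like the sketch below. It mirrors the OpenAI-compatible curl examples; `BASE_URL` and the helper names are hypothetical. The response shapes are assumptions borrowed from the OpenAI format and have not been checked against AstrAI's server code: `choices[0].message.content` for plain replies, and for streaming, `data:` server-sent events carrying `choices[0].delta.content`, terminated by `data: [DONE]`.

```python
import json
import urllib.request
from typing import Iterable, Iterator

# Hypothetical base URL; assumes the server was started with --port 8000.
BASE_URL = "http://localhost:8000"


def build_chat_request(prompt: str, max_tokens: int = 512, stream: bool = False) -> dict:
    """Request body mirroring the OpenAI-compatible curl examples."""
    body = {"messages": [{"role": "user", "content": prompt}], "max_tokens": max_tokens}
    if stream:
        body["stream"] = True
    return body


def iter_sse_text(lines: Iterable[bytes]) -> Iterator[str]:
    """Yield text deltas from OpenAI-style SSE lines.

    Assumed format: 'data: {json}' chunks with choices[0].delta.content,
    ending with 'data: [DONE]'.
    """
    for raw in lines:
        line = raw.decode("utf-8").strip()
        if not line.startswith("data:"):
            continue  # skip blank separator lines and SSE comments
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        delta = json.loads(payload)["choices"][0].get("delta", {})
        if delta.get("content"):
            yield delta["content"]


def chat(prompt: str, stream: bool = False) -> str:
    """POST to /v1/chat/completions and return the full reply text."""
    req = urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=json.dumps(build_chat_request(prompt, stream=stream)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        if stream:
            # The HTTP response object iterates line by line.
            return "".join(iter_sse_text(resp))
        # Assumed OpenAI response shape for non-streaming replies.
        return json.loads(resp.read())["choices"][0]["message"]["content"]
```

With the server running, `print(chat("Hello"))` sends the same request as the first curl example, and `chat("Tell a story", stream=True)` reassembles a streamed reply.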
### Documentation

| Document | Description |
|----------|-------------|
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
| [Design Document](./assets/docs/design.md) | Framework architecture & module design |
| [Data Flow](./assets/docs/dataflow.md) | Data processing pipeline details |
| [Model Introduction](./assets/docs/introduction.md) | Model architecture & technical details |

### Contributing

We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.

1. Fork the repository.
2. Create a feature branch.
3. Commit your changes.
4. Open a Pull Request.

For major changes, please open an issue first to discuss what you would like to change.

### Community

- **GitHub Issues**: [Issue Tracker](https://github.com/ViperEkura/AstrAI/issues)
- **Discussions**: [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions)
- **HuggingFace**: [Model Hub](https://huggingface.co/ViperEk)

### License

This project is licensed under the [GPL-3.0 License](LICENSE).

---
A lightweight Transformer framework designed for both high performance and ease of use.