
A lightweight Transformer training & inference framework

Features

  • 🚀 High Performance: Optimized for both training and inference with efficient parallelization.
  • 🔧 Flexible: Supports seq/sft/dpo/grpo training and customizable model architectures.
  • 💡 Easy to Use: Simple API with comprehensive examples and demos.
  • 📦 Lightweight: Minimal dependencies, easy to deploy.
  • 🔬 Research-Friendly: Modular design that makes it easy to experiment with new ideas.
  • 🤗 HuggingFace Integration: Compatible with HuggingFace models and datasets.

Quick Start

Installation

git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .

For development dependencies:

pip install -e ".[dev]"

Train a Model

python scripts/tools/train.py \
  --train_type=seq \
  --data_root_path=/path/to/dataset \
  --param_path=/path/to/model \
  --n_epoch=3 \
  --batch_size=4 \
  --accumulation_steps=8 \
  --max_lr=3e-4 \
  --warmup_steps=2000 \
  --ckpt_interval=5000 \
  --ckpt_dir=./checkpoints
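
With gradient accumulation, the optimizer normally steps once every `accumulation_steps` micro-batches, so the command above updates the weights on more samples than `batch_size` alone suggests. The sketch below shows that arithmetic. It assumes standard accumulation semantics and that each of the `nprocs` processes consumes its own batch (the usual data-parallel setup); neither is confirmed from the AstrAI source, and both helper names are hypothetical.

```python
def effective_batch_size(batch_size: int, accumulation_steps: int, nprocs: int = 1) -> int:
    """Samples contributing to one optimizer step.

    Assumption: gradients are accumulated over `accumulation_steps`
    micro-batches, and each of `nprocs` processes sees a separate batch.
    """
    return batch_size * accumulation_steps * nprocs


def optimizer_steps(n_iters: int, accumulation_steps: int) -> int:
    """Number of full optimizer updates after `n_iters` micro-batches."""
    return n_iters // accumulation_steps


if __name__ == "__main__":
    # Values from the example command: --batch_size=4 --accumulation_steps=8
    print(effective_batch_size(4, 8))
    print(optimizer_steps(5000, 8))
```

Under these assumptions, the example command accumulates 4 × 8 = 32 samples per update on a single GPU.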

Generate Text

python scripts/tools/generate.py --param_path=/path/to/param_path

Training Parameters

| Parameter | Description | Default |
|---|---|---|
| --train_type | Training type (seq, sft, dpo) | required |
| --data_root_path | Dataset root directory | required |
| --param_path | Model / checkpoint path | required |
| --n_epoch | Training epochs | 1 |
| --batch_size | Batch size | 1 |
| --accumulation_steps | Gradient accumulation steps | 1 |
| --max_lr | Peak learning rate (cosine decay) | 3e-4 |
| --warmup_steps | LR warmup steps | 1000 |
| --ckpt_interval | Checkpoint interval (iterations) | 5000 |
| --ckpt_dir | Checkpoint directory | checkpoint |
| --num_workers | DataLoader workers | 4 |
| --nprocs | Number of GPUs | 1 |

Full reference at Parameter Guide.
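
To make the interaction between `--max_lr` and `--warmup_steps` concrete, here is a minimal sketch of a warmup-plus-cosine-decay schedule. It is an illustration, not AstrAI's scheduler: the linear warmup shape, the `total_steps` argument, and a floor of `min_lr = 0` are all assumptions, and `lr_at` is a hypothetical helper name.

```python
import math


def lr_at(step: int, max_lr: float = 3e-4, warmup_steps: int = 1000,
          total_steps: int = 10_000, min_lr: float = 0.0) -> float:
    """Learning rate at a 0-indexed step.

    Assumed shape: linear warmup to `max_lr` over `warmup_steps`,
    then cosine decay from `max_lr` down to `min_lr` at `total_steps`.
    """
    if step < warmup_steps:
        # Linear ramp; the last warmup step reaches exactly max_lr.
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for s in (0, 500, 999, 1000, 5500, 10_000):
        print(s, f"{lr_at(s):.2e}")
```

In this sketch the rate peaks at `max_lr` when warmup ends and falls smoothly afterward, which is why `--max_lr` is described as a peak rather than a constant rate.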

Docker

Build and run with Docker (recommended for GPU environments):

# Build image
docker build -t astrai:latest .

# Run with GPU support
docker run --gpus all -it astrai:latest

# Run with specific GPUs
docker run --gpus '"device=0,1"' -it astrai:latest

# Run inference server
docker run --gpus all -p 8000:8000 astrai:latest \
  python -m scripts.tools.server --port 8000 --device cuda

# Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest

Note: The --gpus flag (for example --gpus all) is required for CUDA support inside the container. Without it, torch.cuda.is_available() returns False.
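
A quick way to confirm that the container can see the GPUs is to query PyTorch directly. The sketch below uses only `torch.cuda.is_available()`, `torch.cuda.device_count()`, and `torch.cuda.get_device_name()`; the function name `cuda_report` is hypothetical, and the script degrades gracefully if `torch` is not importable.

```python
import importlib.util


def cuda_report() -> str:
    """Describe the CUDA devices visible to PyTorch in this environment."""
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"
    import torch

    if not torch.cuda.is_available():
        return "CUDA not available (was the container started with --gpus?)"
    # List every visible device by index and name.
    names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
    return f"{len(names)} CUDA device(s): " + ", ".join(names)


if __name__ == "__main__":
    print(cuda_report())
```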

Start HTTP Server

Start the inference server, which exposes an OpenAI-compatible HTTP API:

python -m scripts.tools.server --port 8000 --device cuda

Make requests:

# Chat API (OpenAI compatible)
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 512
  }'

# Streaming response
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "Tell a story"}],
    "stream": true,
    "max_tokens": 500
  }'

# Health check
curl http://localhost:8000/health
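
The same requests can be issued from Python with only the standard library. This is a sketch rather than an official client. It assumes the server returns the OpenAI response schema (`choices[0].message.content` for regular responses, and `data: {...}` server-sent-event lines with `choices[0].delta.content` ending in `data: [DONE]` for streaming). All function names here are hypothetical.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000"  # matches the --port 8000 example above


def build_chat_request(prompt, max_tokens=512, stream=False, base_url=BASE_URL):
    """Build a POST request mirroring the curl examples."""
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    if stream:
        payload["stream"] = True
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt, max_tokens=512):
    """Non-streaming call; assumes an OpenAI-style response body."""
    with urllib.request.urlopen(build_chat_request(prompt, max_tokens)) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


def parse_sse_line(line):
    """Extract delta text from one 'data: {...}' line, or return None."""
    line = line.strip()
    if not line.startswith("data:"):
        return None  # blank keep-alive or non-data line
    data = line[len("data:"):].strip()
    if data == "[DONE]":
        return None  # assumed OpenAI-style end-of-stream marker
    chunk = json.loads(data)
    return chunk["choices"][0].get("delta", {}).get("content")


def stream_chat(prompt, max_tokens=500):
    """Yield text fragments as the server streams them."""
    req = build_chat_request(prompt, max_tokens, stream=True)
    with urllib.request.urlopen(req) as resp:
        for raw in resp:  # the response object iterates line by line
            text = parse_sse_line(raw.decode("utf-8"))
            if text:
                yield text
```

With the server running, `print(chat("Hello"))` prints the reply, and `for piece in stream_chat("Tell a story"): print(piece, end="", flush=True)` prints tokens as they arrive.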

Demo

Check out the demos in the scripts/demo/ folder:

# Download preprocessed data (required before running demos)
python scripts/demo/download.py

# Interactive streaming chat
python scripts/demo/stream_chat.py

# Batch generation
python scripts/demo/generate_batch.py

# Autoregressive generation
python scripts/demo/generate_ar.py

Watch a video walkthrough on Bilibili.

Documentation

| Document | Description |
|---|---|
| Parameter Guide | Training & inference parameters |
| Design Document | Framework architecture & module design |
| Data Flow | Data processing pipeline details |
| Model Introduction | Model architecture & technical details |

Contributing

We welcome contributions! Please see our Contributing Guidelines for details.

  1. Fork the repository.
  2. Create a feature branch.
  3. Commit your changes.
  4. Open a Pull Request.

For major changes, please open an issue first to discuss what you would like to change.

License

This project is licensed under the GPL-3.0 License.


A lightweight Transformer framework designed for both high performance and ease of use.