Compare commits

..

No commits in common. "main" and "v1.3.3" have entirely different histories.
main ... v1.3.3

111 changed files with 4427 additions and 12488 deletions

View File

@ -2,7 +2,7 @@
name: Bug report name: Bug report
about: Create a report to help us improve about: Create a report to help us improve
title: "[BUG]" title: "[BUG]"
labels: bug labels: enhancement
assignees: '' assignees: ''
--- ---

View File

@ -16,9 +16,9 @@ Please delete options that are not relevant.
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
## Checklist: ## Checklist:
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`) - [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check --fix .`)
- [ ] I have performed a self-review of my own code - [ ] I have performed a self-review of my own code
- [ ] Code is self-documenting (no unnecessary comments) - [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation - [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings - [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added tests that prove my fix is effective or that my feature works

1
.gitignore vendored
View File

@ -15,7 +15,6 @@
!/.gitattributes !/.gitattributes
!/.dockerignore !/.dockerignore
!/Dockerfile !/Dockerfile
!/docker-compose.yml
!/assets/** !/assets/**
!/CONTRIBUTING.md !/CONTRIBUTING.md
!/LICENSE !/LICENSE

View File

@ -1,100 +1,68 @@
# Contributing to AstrAI # Contributing to AstrAI
Thank you for your interest in contributing! This document provides step-by-step guidelines. Thank you for your interest in contributing to AstrAI! This document provides guidelines and steps for contributing.
## Quick Start ## How to Contribute
### Reporting Issues
If you encounter a bug or have a feature request, please open an issue on GitHub. Include as much detail as possible:
- A clear description of the problem or request.
- Steps to reproduce (for bugs).
- Your environment (Python version, OS, etc.).
### Submitting Changes
1. **Fork** the repository.
2. **Clone** your fork:
```bash ```bash
git clone https://github.com/your-username/AstrAI.git git clone https://github.com/your-username/AstrAI.git
cd AstrAI cd AstrAI
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
``` ```
3. **Create a feature branch**:
```bash
git checkout -b feature/your-feature-name
```
4. **Make your changes**. Follow the code style guidelines below.
5. **Commit your changes** with a descriptive commit message:
```bash
git commit -m "Add: brief description of the change"
```
6. **Push** to your fork:
```bash
git push origin feature/your-feature-name
```
7. **Open a Pull Request** (PR) against the `main` branch of the upstream repository.
## Before You Commit ## Code Style
Run the following checks **in order** — CI will reject if any fail. AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
### 1. Format
- Run Ruff to format and lint:
```bash ```bash
ruff format . ruff format .
ruff check --fix .
``` ```
- The project uses **double quotes** for strings and **4space indentation** (as configured in `pyproject.toml`).
> **Note**: `ruff format` may rename parameters (e.g. `mask``attn_mask`). ## Testing
> Always review the diff after formatting.
### 2. Import sorting If you add or modify functionality, please include appropriate tests.
- Run the test suite with:
```bash ```bash
ruff check . --select I pytest
``` ```
- Ensure all tests pass before submitting your PR.
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
```bash
ruff check . --select I --fix .
ruff format . # re-format after fix
```
### 3. Run tests
```bash
python -u -m pytest tests/ -v
```
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
### 4. (Optional) Full pre-commit check
If you have Git Bash available:
```bash
bash scripts/pre_commit.sh
```
This runs format check, import sort check, and tests in one go.
## Commit Style
```
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
- bullet point body (each ~60 chars)
```
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
- **Subject line** ends with no period.
- **Body** uses bullet points starting with `-`.
- No `(scope)` parentheses.
## Common Issues
| Problem | Cause | Fix |
|---------|-------|-----|
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
## Submitting Changes
1. Fork the repo.
2. Create a feature branch: `git checkout -b feat/my-feature`
3. Make changes following the steps above.
4. Commit with the commit style above.
5. Push: `git push origin feat/my-feature`
6. Open a Pull Request against `main`.
## Code Review ## Code Review
- All PRs are reviewed. We may request changes. All submissions will be reviewed. We may request changes or discuss alternatives. Please be responsive to feedback.
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
- Ensure all tests pass.
## License ## License
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE). By contributing, you agree that your contributions will be licensed under the same [GPL-3.0 License](LICENSE) that covers the project.
--- ---
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue. If you have any questions, feel free to ask in the [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
Happy contributing!

View File

@ -1,7 +1,7 @@
# AstrAI Dockerfile - Multi-stage Build (Optimized) # AstrAI Dockerfile - Multi-stage Build (Optimized)
# Build stage - use base image with minimal build tools # Build stage - use base image with minimal build tools
FROM ubuntu:24.04 AS builder FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS builder
WORKDIR /app WORKDIR /app
@ -18,7 +18,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN python3.12 -m venv --copies /opt/venv RUN python3.12 -m venv --copies /opt/venv
ENV PATH="/opt/venv/bin:$PATH" ENV PATH="/opt/venv/bin:$PATH"
# Copy source code and install (deps read from pyproject.toml) # Copy source code and install dependencies
COPY astrai/ ./astrai/ COPY astrai/ ./astrai/
COPY pyproject.toml . COPY pyproject.toml .
RUN pip install --no-cache-dir --upgrade pip \ RUN pip install --no-cache-dir --upgrade pip \
@ -26,14 +26,13 @@ RUN pip install --no-cache-dir --upgrade pip \
--extra-index-url https://download.pytorch.org/whl/cu126 --extra-index-url https://download.pytorch.org/whl/cu126
# Production stage # Production stage
FROM ubuntu:24.04 AS production FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS production
WORKDIR /app WORKDIR /app
# Install Python 3.12 runtime and healthcheck dependency # Install Python 3.12 runtime
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.12 \ python3.12 \
curl \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy virtual environment from builder # Copy virtual environment from builder

View File

@ -27,6 +27,9 @@
## 📖 Table of Contents ## 📖 Table of Contents
<details open>
<summary><b>English</b></summary>
- [Features](#features) - [Features](#features)
- [Quick Start](#quick-start) - [Quick Start](#quick-start)
- [Documentation](#documentation) - [Documentation](#documentation)
@ -34,6 +37,8 @@
- [Community](#community) - [Community](#community)
- [License](#license) - [License](#license)
</details>
--- ---
<a id="english"></a> <a id="english"></a>
@ -46,8 +51,7 @@
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos. - 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
- 📦 **Lightweight**: Minimal dependencies, easy to deploy. - 📦 **Lightweight**: Minimal dependencies, easy to deploy.
- 🔬 **ResearchFriendly**: Modular design, easy to experiment with new ideas. - 🔬 **ResearchFriendly**: Modular design, easy to experiment with new ideas.
- 🤗 **HuggingFace-Style API**: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading. - 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets.
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
### Quick Start ### Quick Start
@ -65,52 +69,19 @@ For development dependencies:
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
#### Download Pre-trained Model
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
```bash
python scripts/demo/download.py
```
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
#### Train a Model #### Train a Model
```bash ```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/param_path
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
``` ```
Full reference at [Parameter Guide](assets/docs/params.md).
#### Generate Text #### Generate Text
```bash ```bash
python scripts/tools/generate.py \ python scripts/tools/generate.py --param_path=/path/to/param_path
--param_path /path/to/model \
--input_json_file /path/to/input.jsonl \
--output_json_file /path/to/output.jsonl
``` ```
#### Docker #### Docker
@ -133,19 +104,13 @@ docker run --gpus all -p 8000:8000 astrai:latest \
# Run with volume mount for data # Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker Compose (GPU, default)
docker compose up -d
# Docker Compose (CPU only)
docker compose --profile cpu up -d
``` ```
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`. > **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
#### Start HTTP Server #### Start HTTP Server
Start the inference server with OpenAI and Anthropic-compatible HTTP API: Start the inference server with OpenAI-compatible HTTP API:
```bash ```bash
python -m scripts.tools.server --port 8000 --device cuda python -m scripts.tools.server --port 8000 --device cuda
@ -154,7 +119,7 @@ python -m scripts.tools.server --port 8000 --device cuda
Make requests: Make requests:
```bash ```bash
# OpenAI-compatible # Chat API (OpenAI compatible)
curl -X POST http://localhost:8000/v1/chat/completions \ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
@ -162,7 +127,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
"max_tokens": 512 "max_tokens": 512
}' }'
# OpenAI-compatible streaming # Streaming response
curl -X POST http://localhost:8000/v1/chat/completions \ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
@ -171,27 +136,6 @@ curl -X POST http://localhost:8000/v1/chat/completions \
"max_tokens": 500 "max_tokens": 500
}' }'
# Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
# Anthropic-compatible streaming with stop sequences
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "Write a story"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["The end"]
}'
# Health check # Health check
curl http://localhost:8000/health curl http://localhost:8000/health
``` ```
@ -214,18 +158,16 @@ python scripts/demo/generate_batch.py
python scripts/demo/generate_ar.py python scripts/demo/generate_ar.py
``` ```
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6). Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd).
### Documentation ### Documentation
| Document | Description | | Document | Description |
|----------|-------------| |----------|-------------|
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters | | [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
| [Architecture](./assets/docs/architecture.md) | System architecture, class diagram & design patterns | | [Design Document](./assets/docs/design.md) | Framework architecture & module design |
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas | | [Data Flow](./assets/docs/dataflow.md) | Data processing pipeline details |
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API | | [Model Introduction](./assets/docs/introduction.md) | Model architecture & technical details |
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
### Contributing ### Contributing

View File

@ -52,8 +52,7 @@
- 💡 **易用**: 简洁的 API 与丰富的示例、演示。 - 💡 **易用**: 简洁的 API 与丰富的示例、演示。
- 📦 **轻量**: 依赖少,部署简单。 - 📦 **轻量**: 依赖少,部署简单。
- 🔬 **研究友好**: 模块化设计,便于实验新想法。 - 🔬 **研究友好**: 模块化设计,便于实验新想法。
- 🤗 **HuggingFace 风格 API**: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。 - 🤗 **HuggingFace 集成**: 兼容 HuggingFace 模型与数据集。
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API开箱即用。
### 快速开始 ### 快速开始
@ -71,52 +70,19 @@ pip install -e .
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
#### 下载预训练模型
下载预训练模型权重1B 双语检查点)到 `params/` 目录:
```bash
python scripts/demo/download.py
```
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`
#### 训练模型 #### 训练模型
```bash ```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/param_path
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
``` ```
完整参数列表见[参数说明](./params.md)。
#### 文本生成 #### 文本生成
```bash ```bash
python scripts/tools/generate.py \ python scripts/tools/generate.py --param_path=/path/to/param_path
--param_path /path/to/model \
--input_json_file /path/to/input.jsonl \
--output_json_file /path/to/output.jsonl
``` ```
#### Docker #### Docker
@ -139,19 +105,13 @@ docker run --gpus all -p 8000:8000 astrai:latest \
# 挂载数据卷 # 挂载数据卷
docker run --gpus all -v /path/to/data:/data -it astrai:latest docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker ComposeGPU默认
docker compose up -d
# Docker Compose仅 CPU
docker compose --profile cpu up -d
``` ```
> **注意**: 必须使用 `--gpus all` 才能启用 CUDA 支持,否则 `torch.cuda.is_available()` 将返回 `False` > **注意**: 必须使用 `--gpus all` 才能启用 CUDA 支持,否则 `torch.cuda.is_available()` 将返回 `False`
#### 启动 HTTP 服务 #### 启动 HTTP 服务
启动推理服务器,支持 OpenAI 和 Anthropic 兼容的 HTTP API 启动推理服务器,支持 OpenAI 兼容的 HTTP API
```bash ```bash
python -m scripts.tools.server --port 8000 --device cuda python -m scripts.tools.server --port 8000 --device cuda
@ -160,7 +120,7 @@ python -m scripts.tools.server --port 8000 --device cuda
发起请求: 发起请求:
```bash ```bash
# OpenAI 兼容 # Chat APIOpenAI 兼容
curl -X POST http://localhost:8000/v1/chat/completions \ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
@ -168,7 +128,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
"max_tokens": 512 "max_tokens": 512
}' }'
# OpenAI 兼容流式 # 流式响应
curl -X POST http://localhost:8000/v1/chat/completions \ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
@ -177,27 +137,6 @@ curl -X POST http://localhost:8000/v1/chat/completions \
"max_tokens": 500 "max_tokens": 500
}' }'
# Anthropic 兼容
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "你是一个乐于助人的助手。",
"messages": [{"role": "user", "content": "你好"}],
"max_tokens": 512
}'
# Anthropic 兼容流式并设置停止序列
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "写个故事"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["结束"]
}'
# 健康检查 # 健康检查
curl http://localhost:8000/health curl http://localhost:8000/health
``` ```
@ -220,18 +159,16 @@ python scripts/demo/generate_batch.py
python scripts/demo/generate_ar.py python scripts/demo/generate_ar.py
``` ```
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。 观看 [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) 上的视频演示。
### 文档 ### 文档
| 文档 | 说明 | | 文档 | 说明 |
|------|------| |------|------|
| [参数说明](./params.md) | 训练与推理参数配置 | | [参数说明](./params.md) | 训练与推理参数配置 |
| [架构文档](./architecture.md) | 系统架构、类图与设计模式 | | [设计文档](./design.md) | 系统架构与模块设计 |
| [训练文档](./training.md) | 训练循环、策略与公式 | | [数据流程](./dataflow.md) | 数据处理管道详解 |
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API | | [模型介绍](./introduction.md) | 模型架构与技术细节 |
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
### 贡献 ### 贡献

File diff suppressed because it is too large Load Diff

View File

@ -1,64 +1,269 @@
# Data Flow # AstrAI Data Flow Documentation
This document describes the data pipeline: from raw text to model input tensors. This document describes the data flow of the AstrAI project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
## Overview ## Overview
``` AstrAI adopts a modular design with the following main components:
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
## Data Flow Diagram
```mermaid
flowchart LR
subgraph A[Data Preparation]
direction TB
A1[Raw Text] --> A2[AutoTokenizer]
A2 --> A3[Serialize to .h5 files]
A3 --> A4[BaseDataset]
A4 --> A5[ResumableDistributedSampler]
A5 --> A6[PyTorch DataLoader]
end
subgraph B[Training]
direction TB
B1[Batch Data] --> B2[TrainContextBuilder]
B2 --> B3[TrainContext]
B3 --> B4[BaseStrategy]
B4 --> B5[Transformer]
B5 --> B6[Compute Loss]
B6 --> B7[Backward]
B7 --> B8[Optimizer]
B8 --> B9[LRScheduler]
B9 --> B10[CheckpointCallback]
end
subgraph C[Inference]
direction TB
C1[Checkpoint] --> C2[AutoModel]
C2 --> C3[Transformer + Tokenizer]
C3 --> C4[GenerationRequest + apply_chat_template]
C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies]
C7 --> C8[Transformer Forward]
C8 --> C9[KV Cache + Prefix Cache]
C9 --> C10{End Condition?}
C10 -->|No| C8
C10 -->|Yes| C11[Output Text]
end
A --> B
B --> C
``` ```
## Data Preparation ## Detailed Module Descriptions
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups. ### 1. Dataset Module
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: #### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
``` #### 1.2 Dataset (`dataset.py`)
StoreFactory.create("h5") → H5Store - **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
StoreFactory.create("bin") → MmapStore - **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
``` - **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively. #### 1.3 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options
## Data Keys by Training Type ### 2. Model Module
| Type | Storage Keys | #### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|------|-------------| - **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
| `seq` | `sequence` (→ input_ids, target_ids via offset-by-1) | - **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
| `sft` | `sequence`, `loss_mask` | - Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
| `dpo` | `chosen`, `rejected`, `chosen_mask`, `rejected_mask` | - Supports weight tying (`tie_weight=True`) to reduce parameter count
| `grpo` | `prompts`, `responses`, `masks`, `rewards` | - Uses Rotary Position Embedding (RoPE) to inject position information
- Supports loading from safetensors format with automatic model type detection from `config.json`
## Dataset Architecture #### 2.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation
- **`MLA`**: Multi-Latent Attention implementation (like Qwen2-VL)
- **`MLP`**: Feed-forward network with SiLU activation and gated mechanism
- **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
``` ### 3. Training Module
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
→ BaseDataset.load(load_path, storage_type=None)
→ detect_format(load_path)
→ StoreFactory.create(storage_type)
→ Store.load(load_path)
→ H5Store._normalize() / MmapStore._normalize()
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx)
→ get_index(idx) → [begin, end)
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
```
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`). #### 3.1 Training Context (`train_context.py`)
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`. #### 3.2 Trainer (`trainer.py`)
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
- Training steps include:
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
## Sampler #### 3.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
- **`DPOStrategy`**: Direct Preference Optimization
- **`GRPOStrategy`**: Group Relative Policy Optimization
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration
`ResumableDistributedSampler` supports checkpoint-aware distributed sampling: #### 3.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer
- Tracks `start_epoch` / `start_iter` for resume #### 3.5 Callbacks (`train_callback.py`)
- Shuffle via `torch.Generator(seed + epoch)` - **`TrainCallback`**: Protocol interface for trainer callbacks
- Per-replica index slicing for DDP - **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress
- **`MetricLoggerCallback`**: Logs training metrics to JSON files
- **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler
## DataLoader ### 4. Factory Module
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. #### 4.1 Registry and BaseFactory (`factory.py`)
- **`Registry`**: Flexible registry for component classes with category and priority support
- **`BaseFactory`**: Generic factory class for component registration and creation
- Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering
> Document Update Time: 2026-05-30 ### 5. Parallel Module
#### 5.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
- **`get_rank`**: Returns current process rank in distributed group
- **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment
#### 5.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
### 6. Inference Module
#### 6.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility
#### 6.2 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- **`RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
#### 6.3 Server (`server.py`)
- FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints
- Supports both streaming and non-streaming responses
### 7. Tokenizer Module
#### 7.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 7.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts
## Training Data Flow - Detailed Steps
1. **Data Preparation**
- Raw text is converted to token ID sequences through AutoTokenizer
- Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
- Files can contain multiple segments, each segment corresponds to a tensor
2. **Dataset Loading**
- `BaseDataset`'s `load` method calls `load_h5`, obtaining `segments` dictionary
- Create `MultiSegmentFetcher` to manage data for multiple keys
- Calculate total sample count, and determine start/end indices for each sample based on window size and stride
3. **Sampling and Batch Loading**
- `ResumableDistributedSampler` generates index sequence based on current epoch and iteration position
- PyTorch `DataLoader` uses sampler to get indices, calls dataset's `__getitem__` to get actual data
- Batch data shape is `[batch_size, window_size]` (or varies according to specific dataset type)
4. **Strategy Forward and Loss Calculation**
- Batch data is passed to strategy (such as `SEQStrategy`)
- Strategy internally calls `Transformer` model, obtaining logits
- Calculate cross-entropy loss (or DPO loss, etc.) according to task type
- Return loss tensor
5. **Backpropagation and Optimization**
- Loss is normalized by dividing by accumulation steps, then `loss.backward()` is executed
- After accumulating `accumulation_steps` batches, optimizer `step()` and `zero_grad()` are executed
- Learning rate scheduler updates learning rate after each step
6. **Checkpoint Saving**
- `CheckpointCallback` saves checkpoints at set intervals
- Checkpoints contain model state dict, current epoch, iteration, and other metadata
- Saved in safetensors format, ensuring safety and efficiency
## Inference Data Flow - Detailed Steps
1. **Model Loading**
- Load `Transformer` model from checkpoint via `AutoModel.from_pretrained()`
- Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
2. **Prompt Construction and Encoding**
- User messages (list of dict with role and content) are converted to ChatML format string through `apply_chat_template` method in tokenizer
- Tokenizer encodes prompt string to token ID sequence `input_ids`
- For batch generation, use `pad_sequence` for padding
3. **Autoregressive Generation Loop**
- Initialize KV cache (optional) and prefix cache
- Loop until generating `max_len` tokens or encountering stop token:
- Input current `input_ids` (or cached new token) to model, obtain `logits`
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
- Sample next token ID from the processed distribution
- Append new token to `input_ids`, while updating KV cache
- For streaming generation, yield each token to caller immediately
4. **Decoding and Output**
- Decode generated token ID sequence to text through tokenizer
- Remove special tokens, return plain text response
## Checkpoint and Serialization
- **Training Checkpoint**: Saves model parameters, optimizer state, scheduler state, current epoch and iteration
- **Model Parameters**: Supports safetensors format, automatically handles special logic like weight tying during loading
- **Dataset Serialization**: HDF5 format supports efficient random access and shared memory, suitable for large-scale pre-training data
## Summary
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-04-09

694
assets/docs/design.md Normal file
View File

@ -0,0 +1,694 @@
## 1. Why I Created This Project
There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible for ordinary developers. I thought: **Can we create a model that is both useful and can run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that achieves complete privatization while maintaining some level of intelligence.
Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, and the training code is open source!
## 2. System Architecture
```mermaid
classDiagram
namespace config {
class ModelConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+bool tie_weight
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+load(config_path) ModelConfig
+save(config_path)
}
class TrainConfig {
+nn.Module model
+str strategy
+Dataset dataset
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_size
+int accumulation_steps
+float max_grad_norm
+int start_epoch
+int start_batch
+str ckpt_dir
+int ckpt_interval
+int random_seed
+int num_workers
+int prefetch_factor
+bool pin_memory
+int nprocs
+str backend
+str master_addr
+str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+List[int] device_ids
+str device_type
+dict extra_kwargs
+validate()
}
}
namespace dataset {
class BaseDataset {
+int window_size
+int stride
+MultiSegmentFetcher fetcher
+load(load_path)
+__getitem__(index)
+__len__()
}
class SEQDataset {
+__getitem__(index) Dict
}
class SFTDataset {
+__getitem__(index) Dict
}
class DPODataset {
+__getitem__(index) Dict
}
class GRPODataset {
+__getitem__(index) Dict
}
class BaseSegmentFetcher {
+List~Tensor~ segments
+List~int~ cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
}
class ResumableDistributedSampler {
+int start_epoch
+int start_iter
}
class DatasetFactory {
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
}
class Checkpoint {
+dict state_dict
+int epoch
+int iteration
+save(save_dir)
+load(save_dir) Checkpoint
}
}
namespace model {
class AutoModel {
+ModelConfig config
+Dict _registry
+register(model_type) decorator
+get_model_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
}
class Transformer {
+ModelConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, persistent_key_values, start_pos) Dict
+load_state_dict(state_dict)
+state_dict()
}
class DecoderBlock {
+GQA attention
+RMSNorm input_norm
+MLP mlp
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor
}
class GQA {
+int n_heads
+int n_kv_heads
+int head_dim
+Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
}
class MLA {
+int n_heads
+int n_kv_heads
+int head_dim
+Linear q_a_proj, q_b_proj, q_c_proj
+Linear kv_a_proj, kv_b_proj, kv_c_proj
+Linear o_proj
+RMSNorm q_norm, k_norm
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
}
class MLP {
+Linear up, gate, down
+forward(x) Tensor
}
class RMSNorm {
+Parameter weight
+float norm_eps
+forward(x) Tensor
}
class Linear {
+Parameter weight
+Parameter bias
+forward(x) Tensor
}
class RotaryEmbedding {
+int dim
+int max_len
+float base
+forward(x, start_pos) Tuple~Tensor, Tensor~
}
class Embedding {
+Parameter weight
+forward(x) Tensor
}
}
namespace tokenize {
class AutoTokenizer {
+List~str~ stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int
+encode(tokens, out_ids, add_special_tokens) List~int~
+decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union~str, List[int]~
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
+save_pretrained(save_path)
}
class ChatTemplate {
+String template_str
+render(messages, add_generation_prompt) str
+from_string(template) ChatTemplate
}
}
namespace factory {
class Registry {
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List~str~
}
class BaseFactory {
+Registry _registry
+register(name, category, priority) decorator
+create(name, *args, **kwargs) T
+list_registered() list
}
}
namespace trainer {
class Trainer {
+TrainConfig train_config
+List~TrainCallback~ callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List~TrainCallback~
}
class TrainContext {
+nn.Module model
+BaseStrategy strategy
+DataLoader dataloader
+Optimizer optimizer
+LRScheduler scheduler
+Checkpoint checkpoint
+int epoch
+int iteration
+float loss
+int world_size
+int rank
}
class TrainContextBuilder {
+TrainConfig config
+with_checkpoint(checkpoint) TrainContextBuilder
+with_dataloader() TrainContextBuilder
+with_strategy() TrainContextBuilder
+build() TrainContext
}
class BaseStrategy {
+nn.Module model
+str device
+compute_loss(batch) Tensor
}
class StrategyFactory {
+Registry _registry
+register(name) decorator
+create(model, train_type, device, **kwargs) BaseStrategy
}
class SEQStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class SFTStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class DPOStrategy {
+nn.Module ref_model
+float beta
+str reduction
+compute_loss(batch) Tensor
}
class GRPOStrategy {
+nn.Module ref_model
+float clip_eps
+float kl_coef
+int group_size
+compute_loss(batch) Tensor
}
class BaseScheduler {
+get_lr() List~float~
+step()
}
class SchedulerFactory {
+Registry _registry
+register(name) decorator
+create(optimizer, schedule_type, **kwargs) BaseScheduler
}
class CosineScheduler {
+int warmup_steps
+int lr_decay_steps
+float min_rate
}
class SGDRScheduler {
+int warmup_steps
+int cycle_length
+float min_rate
+int t_mult
}
class TrainCallback {
+on_train_begin(context)
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_step_begin(context)
}
class SchedulerCallback {
+on_train_begin(context)
+on_batch_end(context)
}
class CheckpointCallback {
+str save_dir
+int interval
+_save_checkpoint(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
}
class ProgressBarCallback {
+int num_epoch
+on_epoch_begin(context)
+on_batch_end(context)
+on_epoch_end(context)
}
class MetricLoggerCallback {
+str log_dir
+int save_interval
+on_batch_end(context)
+on_train_end(context)
}
class CallbackFactory {
+Registry _registry
+register(name) decorator
+create(name, **kwargs) TrainCallback
}
}
namespace inference {
class InferenceEngine {
+nn.Module model
+AutoTokenizer tokenizer
+InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+int max_prefix_len
+int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+get_stats() Dict
+shutdown()
}
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+ModelConfig config
+Tuple kv_cache
+Tensor seq_mask
+PrefixCacheManager prefix_cache
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
+stop()
+get_stats() Dict
}
class PrefixCacheManager {
+RadixNode root
+int max_capacity
+List lru
+insert(token_ids, slot)
+find_longest_prefix(token_ids) Tuple[int, int]
+release(token_ids)
}
class RadixNode {
+Dict children
+int hash
+int slot
+int ref_count
+float last_access
+List token_sequence
}
class Task {
+str task_id
+List prompt_ids
+int max_tokens
+float temperature
+float top_p
+int top_k
+TaskStatus status
+List output_ids
+int input_tokens
+int output_tokens
+int slot
+Callable stream_callback
+is_finished(stop_ids) bool
}
class TaskStatus {
+str PENDING
+str RUNNING
+str FINISHED
+str ABORTED
}
class Server {
+start()
+predict(request)
}
class GenerationRequest {
+int top_k
+float top_p
+float temperature
+int max_len
+List~Dict~ messages
+stream bool
}
class _Result {
+List~str~ tokens
+List~str~ results
+List~bool~ done_flags
+append(token, idx)
+get_results() List~str~
}
class ChatMessage {
+str role
+str content
}
class ChatCompletionRequest {
+List~ChatMessage~ messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional~str~ system_prompt
}
class CompletionResponse {
+str id
+str object
+int created
+str model
+List~Dict~ choices
}
}
namespace parallel {
class ParallelSetup {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
}
class ParallelModel {
+dist.ProcessGroup process_group
+int rank
+int world_size
}
class ColumnParallelLinear {
+forward(x) Tensor
}
class RowParallelLinear {
+forward(x) Tensor
}
}
%% Relationships
TrainConfig --> ModelConfig : uses
TrainConfig --> BaseDataset : uses
TrainConfig --> StrategyFactory : selects
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- SchedulerCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses
InferenceScheduler --> Task : manages
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> GenerationRequest : uses
Server --> InferenceEngine : uses
Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates
BaseSegmentFetcher --> MultiSegmentFetcher : used by
MultiSegmentFetcher --> BaseDataset : used by
AutoModel <|-- Transformer
AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses
Transformer --> RotaryEmbedding : uses
Transformer --> Embedding : uses
DecoderBlock --> GQA : uses
DecoderBlock --> MLA : uses
DecoderBlock --> MLP : uses
DecoderBlock --> RMSNorm : uses
TrainContextBuilder --> ResumableDistributedSampler : creates
ResumableDistributedSampler --> BaseDataset : samples
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
AutoModel ..> AutoTokenizer : loads with
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
```
### Module Overview
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
### Design Patterns
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
### Core Relationships
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, supports continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
## 3. Training Process
The common training process for large language models (LLM) typically includes three stages: **Pre-training (SEQ)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (DPO/GRPO)**. This system is designed to support seamless end-to-end flow, achieving efficient switching and state management of different training stages through modular strategies.
### Core Formulas
**Pre-training (SEQ):**
$$
L_{\text{PT}} = - \sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
**SFT:**
$$
L_{\text{SFT}} = - \sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$
**DPO:**
$$
L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
$$
**GRPO:**
GRPO (Group Relative Policy Optimization) computes advantages from multiple responses to the same prompt, then optimizes using a PPO-style clipped objective:
$$
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
$$
Where $r_i$ is the reward for the $i$-th response, $\mu$ and $\sigma$ are the mean and standard deviation of group rewards.
$$
L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
$$
In this implementation, an off-policy approach is used ($\pi_\theta = \pi_{\text{ref}}$), and the policy loss simplifies to:
$$
L_{\text{policy}} = -\mathbb{E}[A]
$$
The KL divergence term uses mean squared error approximation:
$$
L_{KL} = \lambda \cdot \mathbb{E} \left[ (\log \pi_\theta - \log \pi_{\text{ref}})^2 \right]
$$
The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-04-09

View File

@ -1,152 +0,0 @@
# Inference
## KV Cache
At decode time, only the last query token matters. All previous K/V are cached to avoid recomputation:
$$
o_n = \sum_j \text{softmax}\left(\frac{q_n k_j}{\sqrt{d_k}}\right) v_j
$$
RoPE is applied **before** KV cache write, not after — otherwise position encoding drift occurs.
## KVCache System
Six classes (plus two helpers) working together:
```
KVCache (facade)
├── PagePool orchestrates page allocation + prefix matching
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
├── TaskTable maps task_id → page_table + cached token count
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
```
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
## Continuous Batching
`InferenceScheduler` runs a daemon thread with a 4-phase loop:
```
1. Cleanup → Remove finished tasks, free KV pages
2. Refill → Pop from waiting_queue, task_alloc pages, activate
3. Prefill → Group by (prompt_len, start_pos), run full forward
4. Decode → Pick largest same-position group, single-token forward
```
## Sampling (Strategy Pattern)
```
BaseSamplingStrategy (ABC)
├── TemperatureStrategy
├── TopKStrategy
├── TopPStrategy
└── SamplingPipeline
```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Strategy Pattern)
```python
class ProtocolHandler: # concrete orchestrator
def __init__(self, request, engine, builder): ...
async def handle(self):
prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops)
else: return await self._handle_non_stream(agen, ctx, stops)
```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
`OpenAIResponseBuilder``/v1/chat/completions`, `AnthropicResponseBuilder``/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
## Engine & GenerateResult
```
InferenceEngine
├── generate(prompt, stream, ...) → str | List[str] | Generator
├── generate_with_request(req) → same
├── generate_async(prompt, ...) → AsyncGenerator
├── get_stats() → Dict
└── shutdown()
```
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
## HTTP API
```
POST /v1/chat/completions OpenAI
POST /v1/messages Anthropic
GET /health {"status":"ok","model_loaded":true}
GET /stats scheduler statistics
```
### OpenAI
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
Response:
```json
{
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1717000000,
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
}
```
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
### Anthropic
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
Supports `stop_sequences` and streaming via `event: content_block_delta`.
### GenerationRequest Parameters
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) |
| `top_k` | int | 50 | Top-k count |
| `top_p` | float | 1.0 | Nucleus threshold |
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
| `max_tokens` | Optional[int] | None | Max generation length |
| `stream` | bool | False | Stream output |
## Engine API
```python
# Non-streaming
engine.generate("Hello", stream=False) # -> str
engine.generate(["A", "B"], stream=False) # -> List[str]
# Streaming
engine.generate("Hello", stream=True) # -> Generator[str]
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
# Async
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
print(token)
```
> Document Update Time: 2026-05-30

299
assets/docs/introduction.md Normal file
View File

@ -0,0 +1,299 @@
## Model Introduction
### 1. Model Architecture
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
The model now uses the **AutoModel** base class for flexible loading and saving:
```python
from astrai.model import AutoModel
# Load model from checkpoint
model = AutoModel.from_pretrained("path/to/model")
# Save model to new directory
model.save_pretrained("path/to/save")
```
The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types.
```mermaid
flowchart TB
subgraph Layers["Transformer Layers"]
direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer 32]
D --> E[RMSNorm]
E --> F[Linear]
F --> G[SoftMax]
end
subgraph TransformerBlock["Transformer Block"]
direction TB
H[x] --> I[RMSNorm]
I --> J[Linear → Q/K/V]
J --> K[Q]
J --> L[K]
J --> M[V]
K --> N[RoPE]
L --> O[RoPE]
N --> P["Q @ K^T / sqrt(d)"]
O --> P
P --> Q[Masked SoftMax]
Q --> R[S @ V]
M --> R
R --> S[Linear]
S --> T[+]
H --> T
T --> U[RMSNorm]
U --> V[Linear]
V --> W[SiLU]
V --> X[×]
W --> X
X --> Y[Linear]
Y --> Z[+]
T --> Z
Z --> AA[x']
end
classDef main fill:#e6f3ff,stroke:#0066cc;
classDef block fill:#fff2e6,stroke:#cc6600;
class Layers main;
class TransformerBlock block;
```
What is an autoregressive model? After splitting a sentence into tokens, the model predicts the probability distribution of the next token. This means the model calculates the probability of the next possible token and its corresponding probability based on the given context (the sequence of tokens that have already appeared).
#### 1. Autoregression
In autoregressive modeling, when a sentence is tokenized into a sequence of tokens, the model learns to predict what comes next. Given a sequence of tokens as input, the model calculates a probability distribution over all possible next tokens. This distribution tells us how likely each potential next token is, given the current context.
For instance, if the input sequence contains tokens representing a question, the model might predict that certain response tokens have higher probabilities than others. The sampling process then selects one token from this distribution—controlled by parameters like top_k, top_p, and temperature—to serve as the next token in the sequence.
Once a token is selected, it is appended to the input sequence, and the model repeats this process. The updated sequence is then fed back into the model to predict the next token. This iterative process continues until either a special end-of-sequence token is generated, or the maximum sequence length is reached. These control tokens are essential because without them, the model would continue generating tokens indefinitely, eventually exhausting available memory.
#### 2. Causal Mask
Transformers use attention mechanism. The input shape is generally [bsz, seq_len], and the output is [bsz, seq_len, n_dim]. To predict the next token, the model's input and output must be offset by one position. The target predicted by the model must be offset by one position, and during training we also use the offset-by-one method:
```
sequence : [[1, 2, 3, 4, 5, 6]]
input_ids: [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```
The attention score calculation formula is:
$$ s_{ij} = softmax(\frac{q_i^Tk_j}{\sqrt{d_k}}) $$
$$ s_{ij} := s_{ij} + mask_{ij} $$
Here, the attention score represents the degree to which the model attends to the similarity between two tokens.
For decoder-only structure models, to prevent the model from "stealing" information from future positions, a mask needs to be added during attention calculation. We need to apply a mask before attention score calculation. This mask is typically a lower triangular matrix, and for a sequence of length n, its shape is [n, n]. Below is an example of how to create such a causal mask matrix for a sequence of length 5:
```
[[0, -inf, -inf, -inf, -inf],
[0, 0, -inf, -inf, -inf],
[0, 0, 0, -inf, -inf],
[0, 0, 0, 0, -inf],
[0, 0, 0, 0, 0]]
```
In this matrix, 0 represents positions that can be attended to, while -inf represents positions that should be masked (i.e., should not be attended to). Because this matrix ensures that after the softmax, the parts of the attention scores where $j > i$ change from `inf` to 0, meaning the model cannot see future information.
#### 3. Rotary Position Embedding
Rotary Position Embedding (RoPE) is a position encoding method designed to solve the problem of lacking direct modeling of sequence position information in Transformer models. Unlike traditional position encodings (such as sine and cosine function position encodings), RoPE embeds position information directly into the Query (Q) and Key (K) vectors, allowing the model to more naturally handle relative position relationships in sequences.
$$ q_i = R_i W_q x_i $$
$$ k_j = R_j W_k x_j $$
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
The $R_{i-j}$ controls the attenuation of attention for different tokens at different relative distances. When the absolute value of $i - j$ is larger, the degree of attenuation is stronger. This approach allows the model to learn relative position relationships, enabling the model to scale and adapt to longer sequences.
## KV Cache Implementation
According to the attention calculation formula:
$$
\begin{align*}
o_i &= \sum_j s_{ij} v_{j} \newline
s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
\end{align*}
$$
Since the model is an autoregressive model, we only need to calculate for the last part of the sequence, meaning the index $i$ is fixed as the last element of the sequence, and we compute $o_{n}$:
$$
\begin{align*}
o_n &= \sum_j s_{j}v_{j} \newline
s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
\end{align*}
$$
If we expand the expression:
$$
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
$$
In the above expression, only k and v have length indices, while $q$ does not. Therefore, during the calculation process, the input of $q$ is fixed as the last token from the previous input, while $k$ and $v$ need to be cached for parts of different lengths. Also, when caching, note that position encoding calculation should be performed before KV cache computation, otherwise there will be position encoding calculation errors.
### 4. AutoModel Loading
The project now uses the **AutoModel** base class for flexible model loading and saving:
```python
from astrai.model import AutoModel
# Load model from checkpoint
model = AutoModel.from_pretrained("path/to/model")
# Save model to new directory
model.save_pretrained("path/to/save")
```
The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types. The `from_pretrained` method automatically loads the `config.json` to determine the model type and uses safetensors format for weights.
### 5. Continuous Batching Inference
The inference engine supports **continuous batching** for efficient batch processing:
```python
from astrai.inference import InferenceEngine, GenerationRequest
# Create inference engine with continuous batching
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=8,
max_seq_len=4096,
)
# Use GenerationRequest with messages format
request = GenerationRequest(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=1024,
stream=True,
)
# Generate with streaming
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
```
The continuous batching feature allows dynamic batch composition where new requests can join at any time and completed requests are released immediately.
## HTTP API Usage
The inference server provides HTTP endpoints for remote inference. Start the server first:
```bash
python -m scripts.tools.server --port 8000
```
### OpenAI-Compatible Endpoint
The server provides an OpenAI-compatible chat completion endpoint at `/v1/chat/completions`:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
],
"temperature": 0.8,
"max_tokens": 2048,
"stream": false
}'
```
**Request Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `messages` | List[dict] | Required | Chat messages with role and content |
| `temperature` | float | 0.8 | Sampling temperature (0.0-2.0) |
| `top_p` | float | 0.95 | Nucleus sampling threshold |
| `top_k` | int | 50 | Top-k sampling parameter |
| `max_tokens` | int | 2048 | Maximum tokens to generate |
| `stream` | bool | false | Enable streaming response |
| `system_prompt` | str | None | System prompt override |
**Response (non-streaming):**
```json
{
"id": "chatcmpl-1234567890",
"object": "chat.completion",
"created": 1234567890,
"model": "astrai",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Hello! I'm doing well..."},
"finish_reason": "stop"
}
]
}
```
### Streaming Response
Enable streaming for real-time token-by-token output:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Write a story"}],
"stream": true,
"max_tokens": 500
}'
```
The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
### Simple Generation Endpoint
For basic text generation without chat format:
```bash
curl -X POST "http://localhost:8000/generate?query=Hello&max_len=1000" \
-H "Content-Type: application/json"
```
Or with conversation history:
```bash
curl -X POST "http://localhost:8000/generate" \
-H "Content-Type: application/json" \
-d '{
"query": "What is AI?",
"history": [["Hello", "Hi there!"], ["How are you?", "I'm doing well"]],
"temperature": 0.8,
"max_len": 2048
}'
```
### Health Check
Monitor server and model status:
```bash
curl http://localhost:8000/health
# {"status": "ok", "model_loaded": true, "engine_ready": true}
curl http://localhost:8000/stats
# {"requests_total": 10, "tokens_generated": 5000, ...}
```
> Document Update Time: 2026-04-09

View File

@ -4,97 +4,138 @@
### Basic Parameters ### Basic Parameters
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required | | `--train_type` | Training type (seq, sft, dpo, grpo) | required |
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
| `--data_root_path` | Dataset root directory | required | | `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required | | `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 | | `--n_epoch` | Total training epochs | 1 |
| `--batch_per_device` | Batch size per device | 1 | | `--batch_size` | Batch size | 4 |
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 | | `--accumulation_steps` | Gradient accumulation steps | 1 |
### Learning Rate Scheduling ### Learning Rate Scheduling
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 | | `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 | | `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 | | `--max_grad_norm` | Maximum gradient norm | 1.0 |
### Optimizer (AdamW) ### Checkpoint
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--resume_dir` | Resume training from specified path | - |
### Optimizer Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--adamw_beta1` | AdamW beta1 | 0.9 | | `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 | | `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 | | `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading ### Data Loading
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--window_size` | Max input sequence length | model config `max_len` | | `--random_seed` | Random seed | 3407 |
| `--stride` | Stride for sliding window over sequences | None | | `--num_workers` | DataLoader workers | 0 |
| `--random_seed` | Random seed for reproducibility | 3407 | | `--prefetch_factor` | Prefetch factor for dataloader | None |
| `--num_workers` | DataLoader worker processes | 4 | | `--pin_memory` | Enable pin_memory | False |
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) | | `--no_pin_memory` | Disable pin_memory | - |
### Checkpoint & Resume
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
### Distributed Training ### Distributed Training
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--nprocs` | Number of GPUs / processes | 1 | | `--nprocs` | Number of GPUs | 1 |
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none | | `--device_type` | Device type (cuda/cpu) | cuda |
| `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
### Strategy-specific ### Other Parameters
| Parameter | Description | Default | Used by | | Parameter | Description | Default Value |
|-----------|-------------|---------|---------| |-----------|-------------|---------------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` | | `--window_size` | Maximum input sequence length | model config max_len |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` | | `--stride` | Input sequence stride | - |
| `--group_size` | GRPO group size | 4 | `grpo` | | `--dpo_beta` | DPO beta value | 0.1 |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` | | `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` | | `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` | | `--grpo_group_size` | GRPO group size | 4 |
| `--label_smoothing` | Label smoothing parameter | 0.1 |
### Usage Example | `--start_epoch` | Starting epoch | 0 |
| `--start_batch` | Starting batch | 0 |
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
--- ---
> Document Update Time: 2026-05-24 ## Generation Parameters
### GenerationRequest Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `messages` | List of message dictionaries (role, content) | required |
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_len` | Maximum generation length | 1024 |
| `stream` | Whether to stream output | False |
### Usage Example
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import Tokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = Tokenizer("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
)
# Build request with messages format
request = GenerationRequest(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=1024,
)
# Generate (streaming)
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
# Or use simple generate interface
result = engine.generate(
prompt="Hello",
stream=False,
max_tokens=1024,
temperature=0.8,
top_p=0.95,
top_k=50,
)
```
### Generation Modes
| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-04-09

View File

@ -1,346 +0,0 @@
# Preprocessing Pipeline
Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output).
## Philosophy
| Component | Responsibility |
|-----------|---------------|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
A single config file captures the entire pipeline, reusable and version-controllable.
## Config Structure
```json
{
"input": {}, // sections (single) or sources (multi)
"mask": {}, // role → "train" | "mask"
"mask_default": "mask",
"preprocessing": {},
"output": {}
}
```
### Section Fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `field` | str | -- | JSONL key to read |
| `action` | str | -- | `"train"` / `"mask"` / `"$role"` |
| `template` | bool | `false` | Apply `chat_template` per message |
| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode |
### Source Fields (multi-output mode)
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sections` | list[dict] | -- | Same as single-output section list |
| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element |
| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask |
---
## Quick Start
### SFT Chat
Input JSONL:
```json
{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
```
Config:
```json
{
"input": {
"sections": [
{"field": "messages", "action": "$role", "template": true}
]
},
"mask": {
"system": "mask",
"user": "mask",
"assistant": "train"
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
},
"output": {
"storage_format": "bin",
"dtype": {"loss_mask": "bool"}
}
}
```
Output keys: `sequence` (int32), `loss_mask` (bool)
### SFT Instruction
Input JSONL:
```json
{"prompt": "Translate to French: Hello", "response": "Bonjour"}
```
Config:
```json
{
"input": {
"sections": [
{"field": "prompt", "action": "mask", "add_special_tokens": true},
{"field": "response", "action": "train"}
]
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
}
}
```
Output keys: `sequence`, `loss_mask`
### Pretrain
Input JSONL:
```json
{"text": "Artificial Intelligence is a field of computer science..."}
```
Config:
```json
{
"input": {
"sections": [
{"field": "text", "action": "train"}
]
},
"preprocessing": {
"max_seq_len": 8192,
"min_chars": 100
}
}
```
Output keys: `sequence` (no `loss_mask` — all tokens trained)
### DPO
Input JSONL:
```json
{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}
```
Config:
```json
{
"input": {
"sources": {
"chosen": {
"sections": [
{"field": "chosen", "action": "$role", "template": true}
]
},
"rejected": {
"sections": [
{"field": "rejected", "action": "$role", "template": true}
]
}
}
},
"mask": {
"user": "mask",
"assistant": "train"
},
"mask_default": "mask"
}
```
Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask`
### GRPO
Input JSONL:
```json
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]}
```
Config:
```json
{
"input": {
"sources": {
"prompts": {
"sections": [
{"field": "prompt", "action": "mask", "template": true}
]
},
"responses": {
"sections": [
{"field": "responses", "action": "train"}
],
"list_field": true,
"mask_key": "masks"
},
"rewards": {
"sections": [
{"field": "rewards", "action": "value"}
]
}
}
},
"mask": {
"user": "mask",
"assistant": "train"
},
"mask_default": "mask"
}
```
Output keys: `prompts`, `responses`, `masks`, `rewards` (float32)
- `action: "value"` — extract raw values from JSONL without tokenisation
- `list_field: true` — tokenise each list element independently, then concatenate
- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`)
---
## Configuration Reference
### `input`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sections` | list[dict] or null | `null` | Section specs for single-output mode |
| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) |
When `sources` is set, `sections` is ignored.
### `mask`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `mask` | dict | `{}` | `{role: "train" \| "mask"}` |
| `mask_default` | str | `"mask"` | Default action for unlisted roles |
### `preprocessing`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `max_seq_len` | int | `2048` | Truncate sequences to this length |
| `min_chars` | int | `50` | Skip text-mode items shorter than this |
| `max_chars` | int | `2000000` | Skip text-mode items longer than this |
| `max_items` | int or null | `null` | Stop after N documents |
### `output`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `domain_key` | str or null | `null` | JSONL key for domain grouping |
| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` |
| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens |
| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) |
---
## Mask Algorithm
### Template mode (`template: true`)
For each message in the field's array:
1. Prepend BOS token (masked)
2. Render through `chat_template` for that single message
3. Encode rendered text
4. Apply mask rule for the message's role
### Non-template mode
Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`.
### Text config detection
When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained.
---
## Output Layout
### Single-Shard (`bin`)
```
output/
__default__/
meta.json
sequence.bin
loss_mask.bin
wiki/
meta.json
sequence.bin
loss_mask.bin
```
### Multi-Shard (`bin`)
When `max_tokens_per_shard` is exceeded:
```
output/
__default__/
shard_0000/
meta.json
sequence.bin
loss_mask.bin
shard_0001/
meta.json
sequence.bin
loss_mask.bin
```
`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`.
---
## CLI
```bash
# SFT
python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json
# DPO
python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params
# GRPO
python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json
```
---
## Python API
```python
from astrai.preprocessing.pipeline import Pipeline
from astrai.config.preprocess_config import PipelineConfig
config = PipelineConfig.from_json("sft.json")
Pipeline(
config,
["data_part1.jsonl", "data_part2.jsonl"],
output_dir="output/",
tokenizer_path="params",
).run()
```
> Document Update Time: 2026-06-03

View File

@ -1,201 +0,0 @@
# Training
### Autoregression
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
### Causal Mask
```
sequence : [[1, 2, 3, 4, 5, 6]]
input_ids: [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```
Lower-triangular mask prevents attending to future positions:
```
[[0, -inf, -inf, -inf, -inf],
[0, 0, -inf, -inf, -inf],
[0, 0, 0, -inf, -inf],
[0, 0, 0, 0, -inf],
[0, 0, 0, 0, 0]]
```
### Rotary Position Embedding (RoPE)
RoPE embeds position into Q/K vectors via complex rotation:
$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_{i-j} W_k x_j $$
The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` multiplies Q/K as complex numbers.
## Training Loop
Two-level loop: **epoch****batch**. Optimizer step fires every `grad_accum_steps` batches.
```
on_train_begin
model.train()
on_epoch_begin
for batch in dataloader:
on_batch_begin
with executor.accumulate(model):
loss = strategy.compute_loss(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
context.iteration += 1
on_batch_end
if executor.sync_gradients:
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
if scheduler:
scheduler.step()
on_epoch_end
on_train_end
```
### Callback Lifecycle
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
| `on_batch_begin` | Every batch | — |
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
## Strategies
### SEQ (Pre-training)
Next-token cross-entropy with optional label smoothing:
$$
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
### SFT (Supervised Fine-Tuning)
Masked cross-entropy (`ignore_index=-100`) over response tokens:
$$
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
### DPO (Direct Preference Optimization)
Frozen reference model, preference margin via log-ratio:
$$
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
$$
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
### GRPO (Group Relative Policy Optimization)
On-policy PPO with group-normalized advantages:
$$
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
$$
$$
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
$$
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
Keys: `prompts`, `responses`, `masks`, `rewards`.
## LR Schedulers
| Type | Class | Description |
|------|-------|-------------|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
## Gradient Checkpointing
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
```python
from astrai.model.components.decoder_block import DecoderBlock
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
```
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
## Checkpoint
```
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern)
```python
context = (
TrainContextBuilder(config)
.with_resume_dir(resume_dir)
.build()
)
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
```
- Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
## Training CLI
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-30

View File

@ -1,9 +1,8 @@
__version__ = "1.3.7" __version__ = "1.3.3"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (
AutoRegressiveLMConfig, ModelConfig,
EncoderConfig,
TrainConfig, TrainConfig,
) )
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
@ -12,14 +11,13 @@ from astrai.inference import (
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
from astrai.model import AutoModel, AutoRegressiveLM from astrai.model import AutoModel, Transformer
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [ __all__ = [
"AutoRegressiveLM", "Transformer",
"AutoRegressiveLMConfig", "ModelConfig",
"EncoderConfig",
"TrainConfig", "TrainConfig",
"DatasetFactory", "DatasetFactory",
"AutoTokenizer", "AutoTokenizer",

View File

@ -1,25 +1,8 @@
from astrai.config.model_config import ( from astrai.config.model_config import ModelConfig
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
)
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
__all__ = [ __all__ = [
"BaseModelConfig", # Model configuration
"AutoRegressiveLMConfig", "ModelConfig",
"EncoderConfig",
"ConfigFactory",
"TrainConfig", "TrainConfig",
"InputConfig",
"OutputConfig",
"PipelineConfig",
"ProcessingConfig",
] ]

View File

@ -1,98 +0,0 @@
import json
from dataclasses import MISSING, dataclass, fields
from pathlib import Path
from typing import Any, Dict, Optional, Self, Union, get_type_hints
@dataclass
class BaseConfig:
def to_dict(self) -> Dict[str, Any]:
d = {}
for fld in fields(self):
v = getattr(self, fld.name)
if isinstance(v, (str, int, float, bool)):
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, (dict, list, tuple)):
try:
val = list(v) if isinstance(v, tuple) else v
json.dumps(val)
d[fld.name] = val
except (TypeError, ValueError):
pass
elif isinstance(v, BaseConfig):
d[fld.name] = v.to_dict()
elif hasattr(v, "__dataclass_fields__"):
sub = {}
for f in fields(v):
a = getattr(v, f.name)
sub[f.name] = list(a) if isinstance(a, tuple) else a
d[fld.name] = sub
return d
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
hints = get_type_hints(cls)
inst = cls.__new__(cls)
for fld in fields(cls):
if fld.name in d:
v = d[fld.name]
target = cls._unwrap_optional(hints.get(fld.name))
if target is not None:
try:
v = cls._coerce(v, target)
except (TypeError, ValueError):
pass
object.__setattr__(inst, fld.name, v)
elif fld.default is not MISSING:
object.__setattr__(inst, fld.name, fld.default)
elif fld.default_factory is not MISSING:
object.__setattr__(inst, fld.name, fld.default_factory())
else:
object.__setattr__(inst, fld.name, None)
return inst
@staticmethod
def _unwrap_optional(tp) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
return target_type.from_dict(value)
raise TypeError
@classmethod
def from_json(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_json(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -1,92 +1,42 @@
import json import json
from dataclasses import dataclass from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional, Self from typing import Optional, Self
from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory
class ConfigFactory(BaseFactory[BaseConfig]):
"""Factory that dispatches config classes by ``model_type``."""
@classmethod
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
model_type = raw.get("model_type") or "autoregressive_lm"
config_cls = cls.get_component_class(model_type)
return config_cls.from_dict(raw)
@dataclass @dataclass
class BaseModelConfig(BaseConfig): class ModelConfig:
"""Base config with ``model_type`` dispatch and file I/O.""" # basic config
model_type: Optional[str] = None model_type: Optional[str] = None
@classmethod
def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f)
return cls.from_dict(raw)
def to_file(self, config_path: str):
d = self.to_dict()
config_dict = {k: v for k, v in d.items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@dataclass
@ConfigFactory.register("autoregressive_lm")
class AutoRegressiveLMConfig(BaseModelConfig):
"""Configuration for autoregressive language model."""
vocab_size: Optional[int] = None vocab_size: Optional[int] = None
dim: Optional[int] = None dim: Optional[int] = None
n_layers: Optional[int] = None n_layers: Optional[int] = None
norm_eps: Optional[float] = None norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
attn_type: str = "gqa" # GQA
n_heads: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
kv_lora_rank: Optional[int] = None def load(self, config_path: str) -> Self:
qk_nope_head_dim: Optional[int] = None config = {}
qk_rope_head_dim: Optional[int] = None with open(config_path, "r") as f:
config.update(json.load(f))
ffn_type: str = "mlp" for key, value in config.items():
n_routed_experts: Optional[int] = None if hasattr(self, key):
n_shared_experts: Optional[int] = None setattr(self, key, value)
n_activated_experts: Optional[int] = None
topk_method: Optional[str] = None
return self
@dataclass def save(self, config_path: str):
@ConfigFactory.register("embedding") config_dict = {k: v for k, v in asdict(self).items() if v is not None}
class EncoderConfig(BaseModelConfig): with open(config_path, "w") as f:
"""Configuration for embedding encoder model.""" json.dump(config_dict, f, indent=4)
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
max_len: Optional[int] = None
rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
pooling_type: Optional[str] = None
normalize_embeddings: Optional[bool] = None

View File

@ -1,109 +0,0 @@
"""Pipeline configuration for JSONL preprocessing.
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
modes, both driven declaratively through ``input.sections`` or
``input.sources``.
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from astrai.config.base import BaseConfig
@dataclass
class InputConfig(BaseConfig):
"""Declarative input mapping.
Single-output mode (backward-compatible)::
{"input": {"sections": [{"field": "messages", ...}]}}
Multi-output mode (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [{"field": "chosen", ...}]},
"rejected": {"sections": [{"field": "rejected", ...}]},
}}}
"""
sections: Optional[List[Dict]] = None
sources: Optional[Dict[str, Dict]] = None
@dataclass
class ProcessingConfig(BaseConfig):
"""Processing configuration.
Parameters
----------
max_seq_len : int
Maximum sequence length (default: 2048).
min_chars : int
Minimum number of characters to keep (default: 50).
max_chars : int
Maximum number of characters to keep (default: 2_000_000).
max_items : Optional[int]
Maximum number of items to process (default: None, unlimited).
packing_strategy : str
How to pack sequences into a contiguous stream.
- ``"simple"``: sequential concatenation (default, backward compatible).
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
max_packed_len : int
Maximum length of a packed bin. Sequences longer than this are
truncated or split depending on ``packing_strategy`` (default: 8192).
truncation_mode : str
How to truncate sequences longer than ``max_packed_len``.
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
"""
max_seq_len: int = 2048
min_chars: int = 50
max_chars: int = 2_000_000
max_items: Optional[int] = None
packing_strategy: str = "simple"
max_packed_len: int = 8192
truncation_mode: str = "keep_start"
@dataclass
class OutputConfig(BaseConfig):
"""Output configuration.
Parameters
----------
domain_key : Optional[str]
Domain key for the output store (default: None).
storage_format : str
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
max_tokens_per_shard : int
Maximum tokens per shard before splitting (default: 100_000_000).
dtype : Dict[str, str]
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
position_ids_mode : Optional[str]
How to compute position_ids in packed sequences.
- ``None`` / ``"none"``: do not generate (backward compatible).
- ``"doc_reset"``: reset to 0 at each document boundary.
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
"""
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
position_ids_mode: Optional[str] = None
@dataclass
class PipelineConfig(BaseConfig):
version: int = 1
input: InputConfig = field(default_factory=InputConfig)
mask: Dict[str, str] = field(default_factory=dict)
mask_default: str = "mask"
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
output: OutputConfig = field(default_factory=OutputConfig)

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field
from typing import Callable, List, Optional from typing import Callable, List, Optional
import torch.nn as nn import torch.nn as nn
@ -6,44 +6,27 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
from astrai.model.components.lora import LoRAConfig
def required(**kw):
return {"required": True, **kw}
@dataclass @dataclass
class TrainConfig(BaseConfig): class TrainConfig:
# basic setting # basic setting
model_fn: Callable[[], nn.Module] = field( model: nn.Module = field(default=None, metadata={"help": "Model for training."})
default=None, metadata=required(help="Model factory for training.") strategy: str = field(default=None, metadata={"help": "Training strategy."})
) dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(
default=None, metadata=required(help="Dataset for training.")
)
optimizer_fn: Callable[[nn.Module], Optimizer] = field( optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, metadata=required(help="Optimizer factory for training.") default=None, metadata={"help": "Optimizer factory for training."}
) )
scheduler_fn: Callable[[Optimizer], LRScheduler] = field( scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, metadata=required(help="Scheduler factory for training.") default=None, metadata={"help": "Scheduler factory for training."}
) )
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."}) n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_per_device: int = field( batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
default=4, metadata={"help": "Batch size per device."} accumulation_steps: int = field(
)
grad_accum_steps: int = field(
default=1, metadata={"help": "Number of iterations between steps."} default=1, metadata={"help": "Number of iterations between steps."}
) )
max_grad_norm: float = field( max_grad_norm: float = field(
default=1.0, metadata={"help": "Maximum gradient norm."} default=1.0, metadata={"help": "Maximum gradient norm."}
) )
gradient_checkpointing_modules: list = field(
default_factory=list,
metadata={"help": "Module types to enable activation checkpointing for."},
)
# checkpoint setting # checkpoint setting
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."}) start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
@ -57,25 +40,6 @@ class TrainConfig(BaseConfig):
default=5000, metadata={"help": "Number of iterations between checkpoints."} default=5000, metadata={"help": "Number of iterations between checkpoints."}
) )
# lora setting
lora: Optional[LoRAConfig] = field(
default=None,
metadata={"help": "LoRA config. None means full fine-tuning."},
)
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
)
log_interval: int = field(
default=100,
metadata={"help": "Number of batch iterations between metric logs."},
)
metrics: List[str] = field(
default_factory=lambda: ["loss", "lr"],
metadata={"help": "Metrics to record during training."},
)
# dataloader setting # dataloader setting
random_seed: int = field(default=3407, metadata={"help": "Random seed."}) random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field( num_workers: int = field(
@ -102,37 +66,20 @@ class TrainConfig(BaseConfig):
master_port: str = field( master_port: str = field(
default="29500", metadata={"help": "Master port for distributed training."} default="29500", metadata={"help": "Master port for distributed training."}
) )
parallel_mode: str = field( parallel_wrapper: Optional[Callable] = field(
default="none", default=None, metadata={"help": "Parallel function for training."}
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
) )
start_method: str = field( state_dict_fn: Optional[Callable] = field(
default="spawn", default=None, metadata={"help": "Parallel function for state dict saving."}
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
) )
# others # others
device_ids: Optional[List[int]] = field(
default=None, metadata={"help": "Device ids for distributed training."}
)
device_type: str = field( device_type: str = field(
default="cuda", metadata={"help": "Device type for distributed training."} default="cuda", metadata={"help": "Device type for distributed training."}
) )
val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."}
)
val_split: Optional[float] = field(
default=None,
metadata={
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
},
)
val_step: int = field(
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},
)
executor_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
)
extra_kwargs: dict = field( extra_kwargs: dict = field(
default_factory=dict, metadata={"help": "Other arguments."} default_factory=dict, metadata={"help": "Other arguments."}
) )
@ -141,6 +88,14 @@ class TrainConfig(BaseConfig):
self.validate() self.validate()
def validate(self): def validate(self):
for fld in fields(self): required_fields = [
if fld.metadata.get("required") and getattr(self, fld.name) is None: "model",
raise ValueError(f"TrainConfig.{fld.name} is required but got None.") "strategy",
"dataset",
"optimizer_fn",
"scheduler_fn",
]
for field_name in required_fields:
if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.")

View File

@ -1,31 +1,19 @@
from astrai.dataset.dataset import ( from astrai.dataset.dataset import (
BaseDataset, BaseDataset,
BaseSegmentFetcher,
DatasetFactory, DatasetFactory,
MultiSegmentFetcher,
) )
from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
H5Store,
MmapStore,
Store,
StoreFactory,
detect_format,
load_bin,
load_h5,
save_bin,
save_h5,
)
__all__ = [ __all__ = [
# Base classes
"BaseDataset", "BaseDataset",
# Factory
"DatasetFactory", "DatasetFactory",
"Store", # Fetchers
"StoreFactory", "BaseSegmentFetcher",
"H5Store", "MultiSegmentFetcher",
"MmapStore", # Sampler
"detect_format",
"save_h5",
"load_h5",
"save_bin",
"load_bin",
"ResumableDistributedSampler", "ResumableDistributedSampler",
] ]

View File

@ -1,86 +1,140 @@
"""Dataset implementations with factory pattern for training.""" """Dataset implementations with factory pattern for training."""
import bisect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Dict, List, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.dataset.storage import (
Store,
StoreFactory,
detect_format,
)
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.serialization import load_h5
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx).
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
Returns:
Concatenated tensor of data in the specified range
"""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
# Find segment boundaries for the range
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
data = self.segments[i][start:end]
result_segments.append(data)
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys.
Each key corresponds to a different type of data (e.g., "sequence", "mask").
"""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys.
Args:
begin_idx: Starting index
end_idx: Ending index
keys: Single key or list of keys to fetch
Returns:
Dictionary of tensors if multiple keys, single tensor if one key
"""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types. """Abstract base class for all dataset types.
Implements common functionality for window-based data fetching. Implements common functionality for window-based data fetching.
Uses a storage abstraction for format-agnostic data loading.
""" """
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__() super().__init__()
self.segments = {}
self.window_size = window_size self.window_size = window_size
self.stride = stride self.stride = stride
self.storage: Optional[Store] = None self.total_samples = None
self.fetcher: Optional[MultiSegmentFetcher] = None
@property def load(self, load_path: str):
def required_keys(self) -> List[str]: """Load dataset from HDF5 file.
"""Return required storage keys for this dataset type.
Subclasses should override to specify expected keys.
"""
return []
def _validate_keys(self):
if not self.required_keys:
return
actual_keys = set(self.storage.keys)
missing = [k for k in self.required_keys if k not in actual_keys]
if missing:
raise KeyError(
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
f"Missing: {missing}"
)
def load(self, load_path: str, storage_type: Optional[str] = None):
"""Load dataset from the given path.
Auto-detects the storage format if not specified.
Args: Args:
load_path: Path to the data directory or file load_path: Path to the HDF5 data file
storage_type: Force a specific storage type ("h5", "bin"),
or None for auto-detection
Raises:
KeyError: If the loaded storage is missing required keys.
""" """
if storage_type is None: self.segments = load_h5(load_path)
storage_type = detect_format(load_path) self.fetcher = MultiSegmentFetcher(self.segments)
self.storage = StoreFactory.create(storage_type) self.total_samples = len(self.fetcher)
self._load_path = load_path
self.storage.load(load_path)
self._validate_keys()
@property
def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset."""
if self.storage is None:
return 0
return len(self.storage)
@property
def keys(self) -> List[str]:
"""Return the available data keys."""
if self.storage is None:
return []
return self.storage.keys
def get_index(self, index: int) -> tuple: def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample. """Calculate begin and end indices for a sample.
@ -91,16 +145,10 @@ class BaseDataset(Dataset, ABC):
Returns: Returns:
Tuple of (begin_idx, end_idx) Tuple of (begin_idx, end_idx)
""" """
if self.storage is None: assert self.total_samples > self.window_size
raise RuntimeError("Dataset not loaded, call load() first")
total = len(self.storage)
if total <= self.window_size:
raise ValueError(
f"Data too short: {total} tokens <= window_size {self.window_size}"
)
begin_idx = min(index * self.stride, total - 1 - self.window_size) begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
end_idx = min(begin_idx + self.window_size, total - 1) end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
return begin_idx, end_idx return begin_idx, end_idx
@ -113,12 +161,10 @@ class BaseDataset(Dataset, ABC):
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int: def __len__(self) -> int:
if self.storage is None: assert self.total_samples is not None
if self.total_samples <= self.window_size:
return 0 return 0
total = len(self.storage) return (self.total_samples - 1 - self.window_size) // self.stride + 1
if total <= self.window_size:
return 0
return (total - 1 - self.window_size) // self.stride + 1
class DatasetFactory(BaseFactory["BaseDataset"]): class DatasetFactory(BaseFactory["BaseDataset"]):
@ -137,7 +183,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
""" """
@classmethod @classmethod
def _validate_component(cls, dataset_cls: type): def _validate_component(cls, dataset_cls: type) -> None:
"""Validate that the dataset class inherits from BaseDataset.""" """Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset): if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
@ -163,7 +209,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: str, load_path: str,
window_size: int, window_size: int,
stride: Optional[int] = None, stride: Optional[int] = None,
storage_type: Optional[str] = None,
) -> "BaseDataset": ) -> "BaseDataset":
"""Create and load a dataset in one step. """Create and load a dataset in one step.
@ -172,7 +217,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file load_path: Path to the data file
window_size: Window size for data sampling window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size) stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "bin") or None for auto-detection
Returns: Returns:
Loaded dataset instance Loaded dataset instance
@ -181,7 +225,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size stride = window_size
dataset = cls.create(train_type, window_size, stride) dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path, storage_type=storage_type) dataset.load(load_path)
return dataset return dataset
@ -191,6 +235,10 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
return cls.list_registered() return cls.list_registered()
# ============== Dataset Classes ==============
# All dataset classes are registered at class definition time using the decorator
@DatasetFactory.register("seq") @DatasetFactory.register("seq")
class SEQDataset(BaseDataset): class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training.""" """Dataset for sequential next-token prediction training."""
@ -198,12 +246,8 @@ class SEQDataset(BaseDataset):
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence"]
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, "sequence") return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -221,27 +265,21 @@ class SFTDataset(BaseDataset):
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence", "loss_mask", "position_ids"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence") x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence") y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
position_ids = self._fetch_data(begin_idx, end_idx, "position_ids") dtype=torch.long
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask") )
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
return { return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
"input_ids": x.to(dtype=torch.long),
"target_ids": y.to(dtype=torch.long),
"position_ids": position_ids.to(dtype=torch.long),
"loss_mask": loss_mask.to(dtype=torch.bool),
}
@DatasetFactory.register("dpo") @DatasetFactory.register("dpo")
@ -251,12 +289,8 @@ class DPODataset(BaseDataset):
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int): def __getitem__(self, index: int):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -285,21 +319,15 @@ class GRPODataset(BaseDataset):
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["prompts", "responses", "masks", "rewards"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long) prompts = self._fetch_data(begin_idx, end_idx, "prompts")
responses = self._fetch_data(begin_idx, end_idx, "responses").to( responses = self._fetch_data(begin_idx, end_idx, "responses")
dtype=torch.long masks = self._fetch_data(begin_idx, end_idx, "masks")
)
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
rewards = self._fetch_data(begin_idx, end_idx, "rewards") rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return { return {

View File

@ -43,7 +43,6 @@ class ResumableDistributedSampler(Sampler[int]):
offset = 0 if drop_last else self.num_replicas - 1 offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas self.total_size = self.num_samples_per_replica * self.num_replicas
self.iter = self.iter % self.num_samples_per_replica
self._indices = None self._indices = None
@ -75,10 +74,5 @@ class ResumableDistributedSampler(Sampler[int]):
self.epoch += 1 self.epoch += 1
self._indices = None self._indices = None
@property
def _remaining(self):
remaining = self.num_samples_per_replica - self.iter
return max(remaining, 0)
def __len__(self): def __len__(self):
return self._remaining return self.num_samples_per_replica

View File

@ -1,271 +0,0 @@
"""Storage backends for different data formats.
Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
"""
import bisect
import glob
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Union
import h5py
import numpy as np
import torch
from torch import Tensor
from astrai.factory import BaseFactory
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
meta = {}
for key, tensors in tensor_group.items():
cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
with open(os.path.join(file_path, "meta.json"), "w") as f:
json.dump(meta, f)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
with open(os.path.join(file_path, "meta.json"), "r") as f:
meta = json.load(f)
segments: Dict[str, List[Tensor]] = {}
for key, info in meta.items():
arr = np.memmap(
os.path.join(file_path, f"{key}.bin"),
dtype=info["dtype"],
mode="r+",
shape=tuple(info["shape"]),
)
segments[key] = [torch.from_numpy(arr)]
return segments
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
Args:
load_path: Directory or file path
Returns:
Format string ("h5" or "bin")
Raises:
FileNotFoundError: If no supported data files are found
"""
root = Path(load_path)
if root.is_file():
suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"):
return "h5"
raise ValueError(f"Unsupported file format: {suffix}")
h5_files = [
Path(p)
for pattern in ("*.h5", "*.hdf5")
for p in glob.glob(str(root / "**" / pattern), recursive=True)
]
if h5_files:
return "h5"
bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
if bin_files:
has_meta = (root / "meta.json").exists() or len(
[Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
) > 0
if has_meta:
return "bin"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class Store(ABC):
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
Each key maps to one or more tensor segments (no forced concatenation).
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
total element count across all keys.
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
via ``_normalize()``.
"""
def __init__(self):
self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
@abstractmethod
def load(self, path: str) -> None:
raise NotImplementedError
@property
def keys(self) -> List[str]:
return list(self._data.keys())
def __len__(self) -> int:
return self._length
def fetch(
self,
begin: int,
end: int,
keys: Union[str, List[str]],
):
if not self._data:
raise RuntimeError("Store not loaded")
if not (0 <= begin < self._length and 0 <= end <= self._length):
raise ValueError(
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
)
if isinstance(keys, str):
return self._fetch_key(keys, begin, end)
return {k: self._fetch_key(k, begin, end) for k in keys}
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
"""Fetch slice [begin, end) across potentially multiple segments."""
segments = self._data[key]
cum = self._cum[key]
seg_start = bisect.bisect_right(cum, begin)
seg_end = bisect.bisect_left(cum, end)
results = []
for i in range(seg_start, seg_end + 1):
prev = cum[i - 1] if i > 0 else 0
s = max(begin - prev, 0)
e = min(end - prev, segments[i].shape[0])
results.append(segments[i][s:e])
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
"""Register segments and pre-compute cumulative lengths.
Does NOT concatenate segments are kept as-is to avoid OOM on
large datasets. Sets ``self._length`` to the minimum total
element count across all keys.
"""
for key, tensors in raw.items():
self._data[key] = tensors
cum = []
total = 0
for t in tensors:
total += t.shape[0]
cum.append(total)
self._cum[key] = cum
self._length = (
min((cum[-1] if cum else 0) for cum in self._cum.values())
if self._cum
else 0
)
class StoreFactory(BaseFactory["Store"]):
"""Factory for creating Store instances by type name.
Example::
@StoreFactory.register("custom")
class CustomStore(Store):
...
"""
@classmethod
def _validate_component(cls, store_cls: type):
if not issubclass(store_cls, Store):
raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StoreFactory.register("h5")
class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, path: str):
self._normalize(load_h5(path))
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
No per-process memory duplication all DataLoader workers share the
same OS page-cache pages.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
"""
def load(self, path: str):
self._mmap_refs = []
root = Path(path)
all_raw: Dict[str, List[Tensor]] = {}
meta_paths = [
Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
]
for meta_path in meta_paths:
raw = load_bin(str(meta_path.parent))
for key, tensors in raw.items():
if key not in all_raw:
all_raw[key] = []
all_raw[key].extend(tensors)
if not meta_paths:
raise FileNotFoundError(f"No meta.json found under {path}")
self._normalize(all_raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)

View File

@ -1,6 +1,5 @@
"""Base factory class for extensible component registration.""" """Base factory class for extensible component registration."""
import inspect
from abc import ABC from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
@ -23,7 +22,7 @@ class Registry:
component_cls: Type, component_cls: Type,
category: Optional[str] = None, category: Optional[str] = None,
priority: int = 0, priority: int = 0,
): ) -> None:
"""Register a component class with optional category and priority.""" """Register a component class with optional category and priority."""
if name in self._entries: if name in self._entries:
raise ValueError(f"Component '{name}' is already registered") raise ValueError(f"Component '{name}' is already registered")
@ -123,10 +122,6 @@ class BaseFactory(ABC, Generic[T]):
def create(cls, name: str, *args, **kwargs) -> T: def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name. """Create a component instance by name.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args: Args:
name: Registered name of the component name: Registered name of the component
*args: Positional arguments passed to component constructor *args: Positional arguments passed to component constructor
@ -144,21 +139,10 @@ class BaseFactory(ABC, Generic[T]):
f"Supported types: {sorted(cls._registry.list_names())}" f"Supported types: {sorted(cls._registry.list_names())}"
) )
component_cls = cls._registry.get(name) component_cls = cls._registry.get(name)
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_kwargs:
valid = {
p.name
for p in sig.parameters.values()
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
}
kwargs = {k: v for k, v in kwargs.items() if k in valid}
return component_cls(*args, **kwargs) return component_cls(*args, **kwargs)
@classmethod @classmethod
def _validate_component(cls, component_cls: Type[T]): def _validate_component(cls, component_cls: Type[T]) -> None:
"""Validate that the component class is valid for this factory. """Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation. Override this method in subclasses to add custom validation.
@ -171,26 +155,6 @@ class BaseFactory(ABC, Generic[T]):
""" """
pass pass
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
"""Get the registered component class by name without instantiating it.
Args:
name: Registered name of the component
Returns:
The component class itself
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
)
return cls._registry.get(name)
@classmethod @classmethod
def list_registered(cls) -> list: def list_registered(cls) -> list:
"""List all registered component names. """List all registered component names.

View File

@ -1,85 +1,25 @@
"""Inference module for continuous batching. """Inference module for continuous batching."""
Layers: from astrai.inference.engine import (
- core/: Core inference loop (cache, executor, scheduler, task) GenerationRequest,
- api/: HTTP orchestration (ProtocolHandler, server) InferenceEngine,
- protocols/: Response builders (OpenAI, Anthropic)
- transport/: SSE transport utilities
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
"""
from astrai.inference.api import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
GenContext,
MessagesRequest,
ProtocolHandler,
StopChecker,
get_app,
run_server,
) )
from astrai.inference.api.anthropic import AnthropicResponseBuilder from astrai.inference.scheduler import (
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.core import (
STOP,
Allocator,
Executor,
InferenceScheduler, InferenceScheduler,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
Task, Task,
TaskManager,
TaskStatus, TaskStatus,
TaskTable, apply_sampling_strategies,
page_hash,
)
from astrai.inference.engine import GenerationRequest, InferenceEngine
from astrai.inference.sample import (
BaseSamplingStrategy,
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
sample,
) )
__all__ = [ __all__ = [
# Engine
"InferenceEngine", "InferenceEngine",
"GenerationRequest", # Scheduler
"InferenceScheduler", "InferenceScheduler",
"Executor",
"STOP",
"Task", "Task",
"TaskManager",
"TaskStatus", "TaskStatus",
"Allocator", # Request
"KVCache", "GenerationRequest",
"KvcacheView", # Sampling
"PagePool", "apply_sampling_strategies",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
"sample",
"BaseSamplingStrategy",
"TemperatureStrategy",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
"ProtocolHandler",
"StopChecker",
"GenContext",
"OpenAIResponseBuilder",
"AnthropicResponseBuilder",
"ChatMessage",
"ChatCompletionRequest",
"AnthropicMessage",
"MessagesRequest",
"get_app",
"run_server",
] ]

View File

@ -1,27 +0,0 @@
"""Inference API: protocol handler, stop checker, and FastAPI server.
``app`` is no longer a module-level global. Use :func:`get_app` to access the
lazy singleton FastAPI instance.
"""
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
from astrai.inference.api.server import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
get_app,
run_server,
)
__all__ = [
"ProtocolHandler",
"StopChecker",
"GenContext",
"AnthropicMessage",
"ChatCompletionRequest",
"ChatMessage",
"MessagesRequest",
"get_app",
"run_server",
]

View File

@ -1,141 +0,0 @@
"""Anthropic message completion response builder."""
import time
import uuid
from typing import Any, Dict, List, Tuple, Union
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class AnthropicResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages: List[Dict[str, str]] = []
system = getattr(request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in request.messages:
text = _extract_text(m.content)
if text:
messages.append({"role": m.role, "content": text})
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
ctx = GenContext(
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
created=int(time.time()),
model=request.model,
prompt_tokens=0,
)
stop_sequences = getattr(request, "stop_sequences", None) or []
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_chunk(self, token: str) -> str:
return sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
events: List[str] = []
if stop.matched:
trimmed = stop.body[: stop.body.rfind(stop.matched)]
unyielded = trimmed[len(stop.yielded) :]
if unyielded:
events.append(
sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": unyielded},
},
event="content_block_delta",
)
)
events.append(
sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
"stop_sequence": stop.matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
if stop.matched:
content = content[: content.rfind(stop.matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
"stop_sequence": stop.matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}

View File

@ -1,140 +0,0 @@
"""OpenAI chat completion response builder."""
import logging
import time
import uuid
from typing import Any, Dict, List, Tuple
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
logger = logging.getLogger(__name__)
_UNSUPPORTED_PARAMS = (
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
)
class OpenAIResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages = [{"role": m.role, "content": m.content} for m in request.messages]
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
self._model = request.model
for param in _UNSUPPORTED_PARAMS:
value = getattr(request, param, None)
fields = getattr(type(request), "model_fields", {})
default = fields[param].default if param in fields else None
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
param,
value,
)
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
param,
value,
)
ctx = GenContext(
resp_id=self._resp_id,
created=int(time.time()),
model=self._model,
prompt_tokens=0,
)
stop = request.stop
stop_sequences = (
[] if stop is None else [stop] if isinstance(stop, str) else stop
)
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_chunk(self, token: str) -> str:
return sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
)
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
return {
"id": self._resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}

View File

@ -1,182 +0,0 @@
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
ProtocolHandler orchestrates the async generation loop and delegates
protocol-specific formatting to a ResponseBuilder.
"""
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
lines: List[str] = []
if event:
lines.append(f"event: {event}")
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
lines.append("")
return "\n".join(lines)
def sse_done() -> str:
return "data: [DONE]\n\n"
@dataclass
class GenContext:
"""Per-generation metadata passed to builder format methods."""
resp_id: str
created: int
model: str
prompt_tokens: int
completion_tokens: int = 0
@dataclass
class StopInfo:
"""Stop-check result passed to format_stream_end / format_response."""
matched: Optional[str] = None
body: str = ""
yielded: str = ""
class StopChecker:
"""Scans accumulated text for stop sequence matches."""
def __init__(self, sequences: List[str]):
self._sequences = [s for s in sequences if s]
def check(self, text: str) -> Optional[str]:
for seq in self._sequences:
if seq in text:
return seq
return None
class ResponseBuilder(ABC):
"""Interface for protocol-specific response formatting.
A new protocol requires one concrete builder implementing 5 methods.
"""
@abstractmethod
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
"""Return (prompt, ctx, stop_sequences) for a generation request."""
@abstractmethod
def format_stream_start(self, ctx: GenContext) -> List[str]:
"""SSE events that open the stream."""
@abstractmethod
def format_chunk(self, token: str) -> str:
"""SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
"""SSE events that close the stream."""
@abstractmethod
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
"""JSON response body for non-streaming mode."""
class ProtocolHandler:
"""Orchestrates the generation loop, delegates formatting to a builder.
Usage::
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
response = await handler.handle()
"""
def __init__(
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
):
self.request = request
self.engine = engine
self.builder = builder
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
agen = self.engine.generate_async(
prompt=prompt,
max_tokens=self.request.max_tokens,
temperature=self.request.temperature,
top_p=self.request.top_p,
top_k=self.request.top_k,
)
if self.request.stream:
return self._handle_stream(agen, ctx, stop_sequences)
else:
return await self._handle_non_stream(agen, ctx, stop_sequences)
def _handle_stream(
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> StreamingResponse:
checker = StopChecker(stop_sequences)
async def event_stream():
for event in self.builder.format_stream_start(ctx):
yield event
body = ""
yielded = ""
matched = None
async for token in agen:
body += token
matched = checker.check(body)
if matched:
break
ctx.completion_tokens += 1
yield self.builder.format_chunk(token)
yielded += token
stop = StopInfo(matched=matched, body=body, yielded=yielded)
for event in self.builder.format_stream_end(ctx, stop):
yield event
yield sse_done()
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
async def _handle_non_stream(
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> Dict[str, Any]:
checker = StopChecker(stop_sequences)
chunks: List[str] = []
body = ""
matched = None
async for token in agen:
chunks.append(token)
body += token
matched = checker.check(body)
if matched:
break
ctx.completion_tokens += 1
content = "".join(chunks)
stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop)

View File

@ -1,187 +0,0 @@
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
``app`` is lazily constructed importing this module does NOT create a FastAPI instance.
Use :func:`get_app` to access the singleton.
"""
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException
from pydantic import BaseModel, Field
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import ProtocolHandler
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_app_instance: Optional[FastAPI] = None
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
class AnthropicMessage(BaseModel):
role: str
content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
"""Anthropic Messages API request body."""
model: str = "astrai"
max_tokens: int = Field(default=1024, ge=1)
messages: List[AnthropicMessage]
system: Optional[str] = None
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop_sequences: Optional[List[str]] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
router = APIRouter()
def _create_engine(
param_path: Path,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
) -> InferenceEngine:
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
return engine
def get_app() -> FastAPI:
"""Return the singleton FastAPI instance (lazily created on first call)."""
global _app_instance
if _app_instance is None:
_app_instance = FastAPI(
title="AstrAI Inference Server",
version="0.2.0",
lifespan=lifespan,
)
_app_instance.include_router(router)
_app_instance.state.server_config = {}
_app_instance.state.engine = None
return _app_instance
def _get_engine() -> InferenceEngine:
engine = get_app().state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return engine
@router.get("/health")
async def health():
app = get_app()
return {
"status": "ok",
"model_loaded": app.state.engine is not None,
}
@router.get("/stats")
async def get_stats():
return _get_engine().get_stats()
@router.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
return await handler.handle()
@router.post("/v1/messages")
async def create_message(request: MessagesRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
return await handler.handle()
def run_server(
param_path: Path,
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
app = get_app()
app.state.server_config = {
"device": device,
"dtype": dtype,
"param_path": param_path,
"max_batch_size": max_batch_size,
}
uvicorn.run(
app,
host=host,
port=port,
reload=reload,
)

View File

@ -1,32 +0,0 @@
"""Inference core: cache, executor, scheduler, task management."""
from astrai.inference.core.cache import (
Allocator,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
from astrai.inference.core.executor import Executor
from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
__all__ = [
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
"Executor",
"InferenceScheduler",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
]

View File

@ -1,368 +0,0 @@
import threading
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size
end = min(start + page_size, len(token_ids))
h = 0
for i in range(start, end):
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
return h
class Allocator:
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
def __init__(self, n_pages: int):
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self._lru: OrderedDict[int, None] = OrderedDict()
self.on_evict: Optional[Callable[[int], None]] = None
self._lock = threading.Lock()
def alloc(self) -> int:
with self._lock:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
if self._lru:
idx, _ = self._lru.popitem(last=False)
if self.on_evict:
self.on_evict(idx)
self._refs[idx] = 1
self._free_mask &= ~(1 << idx)
return idx
return -1
def free(self, idx: int, keep_cached: bool = False):
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
if keep_cached:
self._lru[idx] = None
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int):
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
def ref_count(self, idx: int) -> int:
with self._lock:
return self._refs[idx]
def touch(self, idx: int):
with self._lock:
self._lru.move_to_end(idx)
class PrefixCache:
"""Hash-based prefix matching: maps page hashes to physical page indices."""
def __init__(self, page_size: int):
self._page_size = page_size
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int):
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def has_page(self, idx: int) -> bool:
with self._lock:
return idx in self._page_to_hash
def lookup(self, token_ids: List[int]) -> List[int]:
with self._lock:
full_pages = len(token_ids) // self._page_size
hits: List[int] = []
for i in range(full_pages):
h = page_hash(token_ids, i, self._page_size)
p = self._hash_to_page.get(h)
if p is None:
break
hits.append(p)
return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
class PagePool:
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
def __init__(self, allocator: Allocator, prefix: PrefixCache):
self._alloc = allocator
self._prefix = prefix
self._alloc.on_evict = prefix.evict
@property
def allocator(self) -> Allocator:
return self._alloc
@property
def prefix(self) -> PrefixCache:
return self._prefix
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int):
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int):
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
hits = self._prefix.lookup(token_ids)
for p in hits:
self._alloc.touch(p)
return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self._prefix.record(page_idx, token_ids, logical_page_idx)
class TaskTable:
"""Maps task_ids to page tables and cached token counts."""
def __init__(self, page_size: int):
self._page_size = page_size
self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int):
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
def get(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.get(task_id, [])
def get_cached(self, task_id: str) -> int:
with self._lock:
return self._cached.get(task_id, 0)
def pop(self, task_id: str) -> Tuple[List[int], int]:
with self._lock:
pages = self._pages.pop(task_id, [])
cached = self._cached.pop(task_id, 0)
return pages, cached
def get_ref(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.setdefault(task_id, [])
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
with self._lock:
states = [self._pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
class Storage:
"""KV-cache tensor storage with paged write/gather."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def write(
self,
layer_id: int,
page_table: Tensor,
start_pos: int,
k: Tensor,
v: Tensor,
):
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
valid = phys_pages >= 0
if not valid.all():
if valid.any():
valid_pages = phys_pages[valid]
self.k_cache[layer_id, valid_pages, offset : offset + chunk] = k[
valid, written : written + chunk
]
self.v_cache[layer_id, valid_pages, offset : offset + chunk] = v[
valid, written : written + chunk
]
written += chunk
continue
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(
self, layer_id: int, page_table: Tensor, total_len: int
) -> Tuple[Tensor, Tensor]:
safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2)
v = v.flatten(1, 2)
if (page_table < 0).any():
invalid = (
(page_table < 0)
.unsqueeze(-1)
.expand(-1, -1, self.page_size)
.flatten(1, 2)
)
invalid = invalid[:, :, None, None].expand_as(k)
k = k.masked_fill(invalid, 0.0)
v = v.masked_fill(invalid, 0.0)
k = k[:, :total_len]
v = v[:, :total_len]
return k, v
class KvcacheView:
"""Bundles Storage + page_table + total_len for attention layers."""
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
self._storage = storage
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor):
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._storage.gather(layer_id, self._page_table, self._total_len)
class KVCache:
"""Facade: page management + KV-cache I/O for continuous batching."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
self._table = TaskTable(page_size)
self._storage = Storage(
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
)
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hits = self._pool.lookup(prompt_ids)
cached = len(hits) * self.page_size
for p in hits:
self._pool.inc_ref(p)
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self._pool.free(hp)
for np in new_pages:
self._pool.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str):
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
page_table = self._table.get(task_id)
needed = (pos + 1 + self.page_size - 1) // self.page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
):
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._pool.record(page_table[i], prompt_ids, i)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
return KvcacheView(self._storage, page_table, total_len)

View File

@ -1,94 +0,0 @@
import logging
from typing import List, Optional
import torch
from astrai.inference.core.cache import KVCache
from astrai.inference.core.task import Task
from astrai.inference.sample import sample
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class Executor:
"""Model forward passes for prefill and decode phases."""
def __init__(
self,
model: AutoModel,
tokenizer: AutoTokenizer,
page_cache: KVCache,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
self.model = model
self.tokenizer = tokenizer
self.page_cache = page_cache
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
if start_pos >= prompt_len:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
input_ids = torch.tensor(
[t.prompt_ids[start_pos:prompt_len] for t in tasks],
dtype=torch.long,
device=self.device,
)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
input_ids,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
def execute_decode(self, tasks: List[Task]) -> List[int]:
if not tasks:
return []
input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
dtype=torch.long,
device=self.device,
)
position_ids = torch.tensor(
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
)
total_len = position_ids.max().item() + 1
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
position_ids=position_ids.unsqueeze(1),
)
logits = outputs["logits"][:, -1, :]
return sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()

View File

@ -1,212 +0,0 @@
import logging
import threading
from typing import Any, Dict, List, Optional, Tuple
import torch
from astrai.inference.core.cache import KVCache
from astrai.inference.core.executor import Executor
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class InferenceScheduler:
"""Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
def __init__(
self,
model: AutoModel,
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048,
page_size: int = 64,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
config = model.config
if max_seq_len is not None:
self.max_seq_len = max_seq_len
elif config.max_len is not None:
self.max_seq_len = config.max_len
else:
raise ValueError(
"max_seq_len must be provided either as argument "
"or in model config (config.max_len)"
)
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
n_pages = (
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size
self._page_cache = KVCache(
config.n_layers,
n_pages,
page_size,
config.n_kv_heads,
config.dim // config.n_heads,
self.device,
self.dtype,
)
self._task_mgr = TaskManager(
tokenizer=tokenizer,
max_batch_size=max_batch_size,
max_seq_len=self.max_seq_len,
max_prompt_len=max_prompt_len,
)
self._executor = Executor(
model=model,
tokenizer=tokenizer,
page_cache=self._page_cache,
device=self.device,
dtype=self.dtype,
)
self._running = False
self._fatal_error: Optional[Exception] = None
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str):
for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
finished = self._task_mgr.remove_finished_tasks(stop_ids)
for task in finished:
self._page_cache.task_free(task.task_id)
active = self._task_mgr.get_active_tasks()
available = self._task_mgr.max_batch_size - len(active)
if available > 0:
candidates = self._task_mgr.pull_candidates(available)
failed = []
for task in candidates:
if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
self._task_mgr.activate(task)
else:
failed.append(task)
if failed:
self._task_mgr.return_to_waiting(failed)
if not self._task_mgr.has_work():
self._task_mgr.wait_for_tasks(timeout=1.0)
continue
to_prefill = [
t
for t in self._task_mgr.get_active_tasks()
if t.output_tokens == 0
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
]
if to_prefill:
for t in to_prefill:
t.input_tokens = len(t.prompt_ids)
groups: Dict[Tuple[int, int], List[Task]] = {}
for t in to_prefill:
key = (
len(t.prompt_ids),
self._page_cache.task_cached(t.task_id),
)
groups.setdefault(key, []).append(t)
for (prompt_len, start_pos), group in groups.items():
self._executor.execute_prefill(group, prompt_len, start_pos)
start_logical_page = start_pos // self._page_cache.page_size
for t in group:
self._page_cache.task_record_hashes(
t.task_id,
t.prompt_ids,
start_logical_page=start_logical_page,
)
pos_groups: Dict[int, List[Task]] = {}
for t in self._task_mgr.get_active_tasks():
pos_groups.setdefault(t.next_pos, []).append(t)
if pos_groups:
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
group = sorted(pos_groups[best_key], key=lambda t: t.task_id)
valid: List[Task] = []
for t in group:
if self._page_cache.task_extend(t.task_id, t.next_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if valid:
next_tokens = self._executor.execute_decode(valid)
for t, ntok in zip(valid, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
extend_ok = self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(
self._task_mgr.tokenizer.decode([ntok])
)
if not extend_ok:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
for t in valid:
if t.is_finished(stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
except Exception as e:
self._fatal_error = e
self._running = False
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues()
def start(self):
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self):
self._running = False
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0)
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -1,209 +0,0 @@
import logging
import threading
import time
import uuid
from collections import deque
from enum import Enum
from typing import Any, Callable, Deque, Dict, List, Optional
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
STOP = object()
class TaskStatus(Enum):
"""Task lifecycle states."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
"""Single generation request: prompt, sampling params, output state."""
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return False
class TaskManager:
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
def __init__(
self,
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: int = 8192,
max_prompt_len: int = 512,
):
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.max_prompt_len = max_prompt_len
self.waiting_queue: Deque[Task] = deque()
self.active_tasks: List[Task] = []
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
if len(prompt_ids) >= self.max_seq_len:
if stream_callback:
stream_callback(STOP)
return task_id
if max_tokens is None:
max_tokens = self.max_seq_len - len(prompt_ids)
else:
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> List[Task]:
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = deque(
t for t in self.waiting_queue if t.task_id != task_id
)
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
return removed_active
def get_stats(self) -> Dict[str, Any]:
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
with self._lock:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
return finished
def pull_candidates(self, n: int) -> List[Task]:
to_add: List[Task] = []
with self._lock:
take = min(n, len(self.waiting_queue))
for _ in range(take):
to_add.append(self.waiting_queue.popleft())
return to_add
def activate(self, task: Task):
task.status = TaskStatus.RUNNING
with self._lock:
self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]):
with self._lock:
for task in reversed(tasks):
self.waiting_queue.appendleft(task)
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0):
with self._lock:
if self.waiting_queue or self.active_tasks:
return
self._task_event.clear()
self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]:
with self._lock:
return list(self.active_tasks)
def get_waiting_tasks(self) -> List[Task]:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self):
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self):
self._task_event.set()

View File

@ -1,66 +1,17 @@
"""Unified inference engine for continuous batching.""" """Unified inference engine."""
import asyncio
import gc import gc
import logging
import threading import threading
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union from typing import Any, Dict, Generator, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.core.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes."""
def __init__(self, count: int = 1):
self._cond = threading.Condition()
self._event = threading.Event()
self.tokens: List[Tuple[int, str]] = []
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
with self._cond:
self.tokens.append((idx, token))
if token is not STOP:
self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[Tuple[int, str]]:
with self._cond:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0):
with self._cond:
if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout
):
raise TimeoutError(
f"Generation timeout after {timeout}s "
f"({self._completed}/{self._total} completed)"
)
def get_results(self) -> List[str]:
with self._cond:
return self.results.copy()
class GenerationRequest: class GenerationRequest:
@ -72,26 +23,73 @@ class GenerationRequest:
top_k: int = 50, top_k: int = 50,
top_p: float = 1.0, top_p: float = 1.0,
temperature: float = 1.0, temperature: float = 1.0,
max_tokens: Optional[int] = None, max_len: int = 1024,
stream: bool = False, stream: bool = False,
): ):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature > 0):
raise ValueError("temperature must be a positive number")
self.messages = messages self.messages = messages
self.top_k = top_k self.top_k = top_k
self.top_p = top_p self.top_p = top_p
self.temperature = temperature self.temperature = temperature
self.max_tokens = max_tokens self.max_len = max_len
self.stream = stream self.stream = stream
self._validate()
def _validate(self):
"""Validate request parameters."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class _Result:
"""Unified result holder for streaming/non-streaming modes."""
def __init__(self, count: int = 1, stream: bool = False):
self._stream = stream
self._lock = threading.Lock()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count if count > 1 else [""]
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
def append(self, token: str, idx: int = 0):
with self._lock:
if self._stream:
self.tokens.append(token)
else:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
self._event.set()
def pop_all(self) -> List[str]:
with self._lock:
tokens = self.tokens.copy()
self.tokens.clear()
if not tokens:
self._event.clear()
return tokens
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]:
with self._lock:
return self.results.copy()
class InferenceEngine: class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler.""" """Unified inference engine for continuous batching."""
def __init__( def __init__(
self, self,
@ -99,26 +97,55 @@ class InferenceEngine:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 1, max_batch_size: int = 1,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048, max_prefix_len: int = 512,
page_size: int = 128, cache_capacity: int = 1000,
): ):
"""
Initialize inference engine with separate model and tokenizer.
Args:
model: The language model for inference (nn.Module, e.g., Transformer)
tokenizer: The tokenizer for encoding/decoding text
config: Model configuration
max_batch_size: Maximum batch size for continuous batching
max_seq_len: Maximum sequence length (defaults to config.max_len)
max_prefix_len: Maximum prefix length for cache (default: 512)
cache_capacity: Maximum number of cached prefixes (default: 1000)
"""
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
# Get device and dtype from model parameters
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
# Model has no parameters, use default device/dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.scheduler = InferenceScheduler( self.scheduler = InferenceScheduler(
model=self.model, model=self.model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len, max_prefix_len=max_prefix_len,
page_size=page_size, cache_capacity=cache_capacity,
device=device,
dtype=dtype,
) )
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start() self.scheduler.start()
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit."""
self.shutdown() self.shutdown()
return False return False
@ -126,162 +153,139 @@ class InferenceEngine:
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
stream: bool = False, stream: bool = False,
max_tokens: Optional[int] = None, max_tokens: int = 1024,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> Union[Generator, str, List[str]]: abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], str, List[str]]:
"""Unified generation interface.
Args:
abort_on_exception: If True, abort the generation when consumer
stops iterating (GeneratorExit/StopIteration). Default: True.
"""
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
if stream: if stream:
return self._generate_streaming( return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts,
is_batch,
max_tokens,
temperature,
top_p,
top_k,
abort_on_exception,
) )
else: else:
return self._generate_non_streaming( return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts, is_batch, max_tokens, temperature, top_p, top_k
) )
def generate_async(
self,
prompt: str,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
async def _agen():
loop = asyncio.get_event_loop()
while True:
token = await loop.run_in_executor(None, self._next_token, sync_gen)
if token is None:
break
yield token
return _agen()
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
try:
return next(gen)
except StopIteration:
return None
def generate_with_request( def generate_with_request(
self, request: GenerationRequest self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object."""
# Use tokenizer's chat template with messages
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate( return self.generate(
prompt=prompt, prompt=prompt,
stream=request.stream, stream=request.stream,
max_tokens=request.max_tokens, max_tokens=request.max_len,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
top_k=request.top_k, top_k=request.top_k,
) )
def _submit_tasks(
self,
prompts: List[str],
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Tuple[GenerateResult, List[str]]:
n = len(prompts)
result = GenerateResult(count=n)
task_ids = []
for i, p in enumerate(prompts):
cb = self._make_callback(result, i)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=cb,
)
task_ids.append(task_id)
return result, task_ids
@staticmethod
def _make_callback(result: GenerateResult, idx: int):
def cb(token):
result.append(token, idx)
return cb
def _generate_streaming( def _generate_streaming(
self, self,
prompts: List[str], prompts: List[str],
is_batch: bool, is_batch: bool,
max_tokens: Optional[int], max_tokens: int,
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Generator: abort_on_exception: bool = True,
result, task_ids = self._submit_tasks( ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
prompts, max_tokens, temperature, top_p, top_k """Generate with streaming output.
Args:
abort_on_exception: If True, abort the task when generator is
stopped early by consumer (GeneratorExit/StopIteration).
"""
if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet")
result = _Result(stream=True)
task_id = self.scheduler.add_task(
prompt=prompts[0],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=result.append,
) )
n = len(prompts)
remaining = n
finished = [False] * n
def gen(): def gen():
nonlocal remaining
try: try:
while remaining > 0: while True:
items = result.pop_all() tokens = result.pop_all()
for idx, token in items: for token in tokens:
if token is STOP: if token == "[DONE]":
if not finished[idx]: return
finished[idx] = True yield token
remaining -= 1
else:
yield (idx, token) if is_batch else token
if remaining > 0:
result.wait(timeout=0.05) result.wait(timeout=0.05)
finally: except Exception:
for tid in task_ids: # Consumer stopped iterating - abort the task
self.scheduler.remove_task(tid) if abort_on_exception:
self.scheduler.remove_task(task_id)
raise
gen.task_id = task_id
return gen() return gen()
def _generate_non_streaming( def _generate_non_streaming(
self, self,
prompts: List[str], prompts: List[str],
is_batch: bool, is_batch: bool,
max_tokens: Optional[int], max_tokens: int,
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
result, task_ids = self._submit_tasks( """Generate without streaming."""
prompts, max_tokens, temperature, top_p, top_k result = _Result(count=len(prompts))
for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function
def make_callback(idx):
def callback(token):
result.append(idx, token)
return callback
self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
) )
try: result.wait()
result.wait_completion() results = result.get_results()
except TimeoutError: return results if is_batch else results[0]
for tid in task_ids:
self.scheduler.remove_task(tid)
raise
for tid in task_ids:
self.scheduler.remove_task(tid)
res = result.get_results()
return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Get engine statistics."""
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self): def shutdown(self) -> None:
"""Shutdown the engine and release all resources."""
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -1,190 +0,0 @@
"""Composable sampling strategies for logit transformation.
Implements the Strategy pattern: each sampling technique
(temperature, top-k, top-p) is a pluggable strategy that
can be composed into a pipeline.
All strategies accept both scalar and per-sample tensor
parameters, so a single pipeline works for any batch size.
"""
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from torch import Tensor
class BaseSamplingStrategy(ABC):
"""Abstract base for a logit transformation strategy."""
@abstractmethod
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Applies the strategy to logits.
Args:
logits: Raw logits tensor (batch, vocab_size).
filter_value: Value assigned to filtered-out positions.
Returns:
Transformed logits tensor.
"""
class TemperatureStrategy(BaseSamplingStrategy):
"""Divides logits by temperature to control randomness.
Args:
temperature: Scalar or ``[batch]`` tensor.
"""
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
t = self.temperature
if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1)
t = torch.clamp(t, min=1e-8)
if (t != 1.0).any():
logits = logits / t
elif t != 1.0:
logits = logits / max(t, 1e-8)
return logits
class TopKStrategy(BaseSamplingStrategy):
"""Keeps only the top-k logits, setting the rest to filter_value.
Args:
top_k: Scalar or ``[batch]`` tensor (0 disables).
"""
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
tk = self.top_k
if isinstance(tk, Tensor):
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
max_k = int(tk.max().item())
if max_k <= 0:
return logits
max_k = min(max_k, logits.size(-1))
values, _ = torch.topk(logits, max_k, dim=-1)
per_row_k = tk.clamp(max=max_k)
thresholds = torch.full_like(logits[..., -1:], -float("inf"))
positive = per_row_k > 0
if positive.any():
row_idx = torch.arange(logits.size(0), device=logits.device)[positive]
thresholds[positive] = values[
row_idx, per_row_k[positive] - 1
].unsqueeze(-1)
logits[logits < thresholds] = filter_value
return logits
if tk > 0:
k = min(tk, logits.size(-1))
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
logits[logits < thresholds] = filter_value
return logits
class TopPStrategy(BaseSamplingStrategy):
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
cumulative probability exceeds top_p.
Args:
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
"""
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, sorted_indices, remove)
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")):
tp = self.top_p
if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True)
if (tp < 1.0).any():
logits = self._apply(logits, tp.view(-1, 1), filter_value)
elif tp < 1.0:
logits = self._apply(logits, tp, filter_value)
return logits
class SamplingPipeline(BaseSamplingStrategy):
"""Composes multiple sampling strategies into a single transformation.
Strategies are applied sequentially in the order they are provided,
matching the original temperature -> top-k -> top-p ordering.
Usage::
pipeline = SamplingPipeline([
TemperatureStrategy(0.8),
TopKStrategy(50),
TopPStrategy(0.95),
])
logits = pipeline.apply(logits)
token = pipeline.sample(logits) # softmax + multinomial
"""
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits
@torch.no_grad()
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Apply strategies then sample (softmax + multinomial).
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return torch.multinomial(
torch.softmax(self.apply(logits, filter_value), dim=-1),
num_samples=1,
).squeeze(-1)
@torch.inference_mode()
def sample(
logits: Tensor,
temperature: Union[float, Tensor] = 1.0,
top_k: Union[int, Tensor] = 0,
top_p: Union[float, Tensor] = 1.0,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies then sample (softmax + multinomial).
Shortcut for ``SamplingPipeline(...).sample(logits)``.
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return SamplingPipeline(
[
TemperatureStrategy(temperature),
TopKStrategy(top_k),
TopPStrategy(top_p),
]
).sample(logits, filter_value)

View File

@ -0,0 +1,637 @@
"""Inference scheduler for continuous batching."""
import threading
import time
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer
class RadixNode:
"""Radix tree node for prefix cache."""
def __init__(self):
self.children: Dict[int, "RadixNode"] = {} # token_id -> child node
self.hash: Optional[int] = None # 64-bit hash of the prefix
self.slot: int = -1 # KV Cache slot, valid only for leaf nodes
self.ref_count: int = 0 # number of tasks referencing this prefix
self.last_access: float = 0.0 # timestamp for LRU
self.token_sequence: list = [] # full token sequence from root to this node
class PrefixCacheManager:
"""Prefix cache manager using Radix tree with LRU eviction."""
def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
self.root = RadixNode()
self.base = base
self.mod = mod
self.max_capacity = max_capacity
self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU
def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
"""Insert a prefix, increase ref_count if already exists, otherwise create new node."""
node = self.root
path = []
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
node.children[token_id] = RadixNode()
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
node.hash = h
path.append(token_id)
node.token_sequence = list(
path
) # store full sequence for exact verification
# Leaf node: set slot and increase ref_count
if node.slot == -1:
node.slot = slot
node.ref_count += 1
node.last_access = time.time()
self._update_lru(node)
self._evict_if_needed()
def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
"""Find longest matching prefix, return (prefix_len, slot).
During traversal, compute hash per token and compare with node hash.
If hash matches, perform full token sequence verification to avoid
hash collision errors.
"""
node = self.root
best_len = 0
best_slot = -1
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
break
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
if node.hash == h: # hash matches
# Exact verification: compare full token sequence
if node.token_sequence == token_ids[: i + 1]:
best_len = i + 1
best_slot = node.slot
node.last_access = time.time()
self._update_lru(node)
if best_len > 0:
return (best_len, best_slot)
return None
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
node = self.root
for token_id in token_ids:
if token_id not in node.children:
return
node = node.children[token_id]
if node.ref_count > 0:
node.ref_count -= 1
if node.ref_count == 0:
node.slot = -1 # slot can be reused
def _update_lru(self, node: RadixNode) -> None:
"""Update LRU list, move node to most recently used position."""
self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
self.lru.append((node.last_access, node))
def _evict_if_needed(self) -> None:
"""If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
if len(self.lru) <= self.max_capacity:
return
# Sort by timestamp
self.lru.sort(key=lambda x: x[0])
for ts, node in self.lru:
if node.ref_count == 0:
# Remove leaf node from tree (need to recursively delete empty branches)
self._remove_node(node)
self.lru.remove((ts, node))
if len(self.lru) <= self.max_capacity:
break
def _remove_node(
self,
node: RadixNode,
parent: Optional[RadixNode] = None,
child_key: Optional[int] = None,
) -> None:
"""Remove node from tree, including empty parent nodes."""
# First, recursively remove all children
for child_key, child_node in list(node.children.items()):
self._remove_node(child_node, node, child_key)
# Clear the node's leaf properties
node.slot = -1
node.hash = None
node.token_sequence = []
node.children.clear()
# If this node has no children and has a parent, remove the reference from parent
if parent is not None and child_key is not None and len(node.children) == 0:
if child_key in parent.children:
del parent.children[child_key]
class TaskStatus:
"""Task state for continuous batching."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
"""Individual task for continuous batching."""
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.slot: int = -1
self.prefix_len: int = 0 # prefix cache matched length
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
return (
bool(self.output_ids and self.output_ids[-1] in stop_ids)
or self.output_tokens >= self.max_tokens
)
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
# Clone logits to avoid inplace updates on inference tensor
logits = logits.clone()
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler:
"""Inference scheduler with continuous batching support."""
def __init__(
self,
model: AutoModel,
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
config = model.config
self.model = model
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.max_prefix_len = max_prefix_len
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
# Initialize prefix cache
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
num_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_kv_heads,
head_dim,
),
device=self.device,
dtype=self.dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_kv_heads,
head_dim,
),
device=self.device,
dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._running = False
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
# Truncate if exceeds max_prefix_len
if len(prompt_ids) > self.max_prefix_len:
prompt_ids = prompt_ids[: self.max_prefix_len]
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
# Find longest matching prefix from cache
match = self.prefix_cache.find_longest_prefix(prompt_ids)
if match:
prefix_len, slot = match
task.prefix_len = prefix_len
task.slot = slot
else:
task.prefix_len = 0
task.slot = -1
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock:
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch."""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
# Release prefix cache reference
if task.prefix_len > 0:
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
task.slot = -1
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
def _refill_active_batch(self) -> None:
"""Refill active batch with waiting tasks."""
available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
return
with self._lock:
to_add = [
self.waiting_queue.pop(0)
for _ in range(min(available_slots, len(self.waiting_queue)))
]
for task in to_add:
task.slot = self._allocate_slot()
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
def _allocate_slot(self) -> int:
"""Allocate an available slot for a task."""
for i in range(self.max_batch_size):
if not any(t.slot == i for t in self.active_tasks):
return i
return -1
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase with incremental prefill support."""
if not tasks:
return
# Group tasks by prefix cache status
fully_cached, partial, full = [], [], []
for task in tasks:
total_len, prefix_len = len(task.prompt_ids), task.prefix_len
if prefix_len == total_len:
fully_cached.append(task)
elif prefix_len > 0:
partial.append(task)
else:
full.append(task)
# Handle fully cached tasks
for t in fully_cached:
t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
if full:
self._execute_full_prefill(full)
if partial:
self._execute_partial_prefill(partial)
def _execute_full_prefill(self, tasks: List[Task]) -> None:
"""Execute full prefill for tasks without prefix cache."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros(
len(tasks), max_len, dtype=torch.long, device=self.device
)
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
if self.tokenizer.pad_id is not None:
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
else:
input_mask = torch.ones(
input_ids.shape, dtype=torch.bool, device=self.device
)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
)
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
# Insert new prefix into cache
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
"""Execute incremental prefill for tasks with partial prefix cache match."""
for task in tasks:
total_len = len(task.prompt_ids)
prefix_len = task.prefix_len
if prefix_len >= total_len:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Get new tokens that need prefill
new_ids = task.prompt_ids[prefix_len:]
new_len = len(new_ids)
if new_len == 0:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Build input for incremental prefill
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
# Input mask should cover from position 0 to prefix_len + new_len
# The prefix part uses cached KV, new part needs computation
input_mask = torch.ones(
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
)
task.input_tokens = total_len
task.output_tokens = 0
# Insert full prefix into cache (ref_count already increased in add_task)
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
input_tensor = input_ids.unsqueeze(1)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
with torch.inference_mode():
outputs = self.model(
input_tensor,
input_mask=active_mask,
persistent_key_values=self.kv_cache,
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
if task.stream_callback:
task.stream_callback("[DONE]")
def _run_generation_loop(self) -> None:
"""Main generation loop."""
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
continue
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
if new_tasks:
self._execute_prefill(new_tasks)
decode_tasks = new_tasks
start_pos = max(t.input_tokens for t in decode_tasks)
if decode_tasks:
self._execute_decode(decode_tasks, start_pos)
if not self.active_tasks and not self.waiting_queue:
self._task_event.wait(timeout=0.05)
self._task_event.clear()
def start(self) -> None:
"""Start the generation loop."""
if not self._running:
self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0)
# Clear KV cache to free GPU memory
if self.kv_cache is not None:
k_cache, v_cache = self.kv_cache
if k_cache is not None:
k_cache.detach()
if v_cache is not None:
v_cache.detach()
# Clear seq mask
self.seq_mask.detach()
# Clear task lists
self.waiting_queue.clear()
self.active_tasks.clear()
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}

321
astrai/inference/server.py Normal file
View File

@ -0,0 +1,321 @@
"""
Inference Server with Continuous Batching Support
FastAPI server for inference with continuous batching.
Provides OpenAI-compatible chat completion endpoints.
"""
import json
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
_engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
_server_config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
def configure_server(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Configure server settings before starting.
Args:
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
_server_config["device"] = device
_server_config["dtype"] = dtype
_server_config["param_path"] = param_path
_server_config["max_batch_size"] = max_batch_size
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
global _model_param, _engine
# Startup: Load model with configured settings
try:
load_model(
param_path=_server_config["param_path"],
device=_server_config["device"],
dtype=_server_config["dtype"],
max_batch_size=_server_config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
# Shutdown: Cleanup engine
if _engine:
_engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def load_model(
param_path: Optional[Path] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
"""Load model parameters and initialize inference engine."""
global _model_param, _engine
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
# Load tokenizer separately
tokenizer = AutoTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
# Initialize inference engine with separate model and tokenizer
_engine = InferenceEngine(
model=_model_param,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
# Pydantic models for API request/response
class ChatMessage(BaseModel):
role: str # "user", "assistant", "system"
content: str
class ChatCompletionRequest(BaseModel):
messages: List[ChatMessage]
temperature: float = Field(0.8, ge=0.0, le=2.0)
top_p: float = Field(0.95, ge=0.0, le=1.0)
top_k: int = Field(50, ge=0)
max_tokens: int = Field(2048, ge=1)
stream: bool = False
system_prompt: Optional[str] = None
class CompletionResponse(BaseModel):
id: str = "chatcmpl-default"
object: str = "chat.completion"
created: int = 0
model: str = "astrai"
choices: List[Dict[str, Any]]
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": _model_param is not None,
"engine_ready": _engine is not None,
}
@app.get("/stats")
async def get_stats():
"""Get inference engine statistics."""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint.
Supports both streaming and non-streaming modes with continuous batching.
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
if request.stream:
# Streaming response (use synchronous generator)
generator = _engine.generate(
prompt=prompt,
stream=True,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def generate_stream():
for token in generator:
if token == "[DONE]":
break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
# Non-streaming response
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
# Build OpenAI-style response
import time
resp = CompletionResponse(
id=f"chatcmpl-{int(time.time())}",
created=int(time.time()),
choices=[
{
"index": 0,
"message": {"role": "assistant", "content": result},
"finish_reason": "stop",
}
],
)
return resp
@app.post("/generate")
async def generate(
query: str,
history: Optional[List[List[str]]] = None,
temperature: float = 0.8,
top_p: float = 0.95,
top_k: int = 50,
max_len: int = 2048,
stream: bool = False,
):
"""Simple generation endpoint.
Args:
query: Input query string
history: Conversation history as list of [user, assistant] pairs
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
max_len: Maximum tokens to generate
stream: Enable streaming output
Returns:
dict: Generation result with response field
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Build messages for chat template
messages = []
if history:
# Convert history format: List[List[str]] -> List[Dict]
for h in history:
if len(h) >= 2:
messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query})
# Use tokenizer's chat template
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream:
# Synchronous streaming
result = _engine.generate(
prompt=prompt,
stream=True,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
def stream_generator():
for token in result:
yield token + "\n"
return StreamingResponse(stream_generator(), media_type="text/plain")
else:
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
return {"response": result}
def run_server(
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Run the FastAPI server with uvicorn.
Args:
host: Server host address
port: Server port number
reload: Enable auto-reload for development
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
configure_server(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
uvicorn.run(
"astrai.inference.server:app",
host=host,
port=port,
reload=reload,
)

View File

@ -1,18 +1,12 @@
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.components.attention import GQA from astrai.model.module import (
from astrai.model.components.decoder_block import DecoderBlock GQA,
from astrai.model.components.linear import Linear MLP,
from astrai.model.components.lora import ( DecoderBlock,
LoRAConfig, Linear,
inject_lora, RMSNorm,
load_lora,
merge_lora,
save_lora,
) )
from astrai.model.components.mlp import MLP from astrai.model.transformer import Transformer
from astrai.model.components.norm import RMSNorm
from astrai.model.encoder import EmbeddingEncoder
from astrai.model.transformer import AutoRegressiveLM
__all__ = [ __all__ = [
# Modules # Modules
@ -22,13 +16,6 @@ __all__ = [
"GQA", "GQA",
"DecoderBlock", "DecoderBlock",
# Models # Models
"AutoRegressiveLM", "Transformer",
"EmbeddingEncoder",
"AutoModel", "AutoModel",
# LoRA
"LoRAConfig",
"inject_lora",
"merge_lora",
"save_lora",
"load_lora",
] ]

View File

@ -4,22 +4,17 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Self, Union from typing import Dict, Self, Type, Union
import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config.model_config import BaseModelConfig, ConfigFactory from astrai.config import ModelConfig
from astrai.factory import BaseFactory
from astrai.serialization import load_model_config, load_model_weights, save_model
@contextmanager @contextmanager
def _disable_random_init(enable: bool = True): def _disable_random_init(enable: bool = True):
if not enable: init_functions = [
yield
return
names = (
"xavier_normal_", "xavier_normal_",
"xavier_uniform_", "xavier_uniform_",
"kaiming_normal_", "kaiming_normal_",
@ -29,66 +24,110 @@ def _disable_random_init(enable: bool = True):
"constant_", "constant_",
"normal_", "normal_",
"uniform_", "uniform_",
) ]
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)} original_funcs = {}
for n in orig: for name in init_functions:
setattr(nn.init, n, lambda *a, **kw: None) if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try: try:
yield yield
finally: finally:
for n, fn in orig.items(): if enable:
setattr(nn.init, n, fn) for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
class AutoModel(BaseFactory["AutoModel"], nn.Module): class AutoModel(nn.Module):
""" """
Autoregressive language model base class. Autoregressive language model base class.
Provides model loading/saving, registration, and generation. Provides model loading/saving and generation capabilities.
""" """
def __init__(self, config: BaseModelConfig): # Model registry - stored as class attribute
_registry: Dict[str, Type["AutoModel"]] = {}
def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
self.config = config self.config = config
@classmethod
def register(cls, model_type: str):
"""
Class method decorator to register model type.
Usage:
@AutoModel.register('transformer')
class Transformer(AutoModel):
...
"""
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry[model_type.lower()] = sub_cls
return sub_cls
return decorator
@classmethod
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string."""
model_type = model_type.lower()
if model_type not in cls._registry:
available = list(cls._registry.keys())
raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}"
)
return cls._registry[model_type]
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
path: Union[str, Path], path: Union[str, Path],
disable_random_init: bool = True, disable_random_init: bool = True,
strict: bool = True,
) -> nn.Module: ) -> nn.Module:
model_path = Path(path) model_path = Path(path)
# Load config
config = ModelConfig()
config_path = model_path / "config.json" config_path = model_path / "config.json"
if not config_path.exists(): if config_path.exists():
config.load(str(config_path))
else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
raw = load_model_config(str(model_path)) # If called from base class, use model_type to determine actual model class
config = ConfigFactory.load(raw) if cls is AutoModel:
model_type = config.model_type or "autoregressive_lm" model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
actual_cls = AutoModel.get_component_class(model_type) else:
raise ValueError(
f"Cannot call from_pretrained() on subclass {cls.__name__}"
)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):
model = actual_cls(config) model = actual_cls(config)
# Load weights
weights_path = model_path / "model.safetensors" weights_path = model_path / "model.safetensors"
if weights_path.exists(): if weights_path.exists():
state_dict = load_model_weights(str(model_path)) state_dict = st.load_file(str(weights_path))
model.load_state_dict(state_dict, strict=strict) model.load_state_dict(state_dict, strict=False)
return model return model
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, Path], save_directory: Union[str, Path],
): ) -> None:
save_model( save_path = Path(save_directory)
config=self.config.to_dict(), save_path.mkdir(parents=True, exist_ok=True)
state_dict=self.state_dict(),
save_directory=str(save_directory), # Save config
) self.config.save(str(save_path / "config.json"))
# Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
def to(self, *args, **kwargs) -> Self: def to(self, *args, **kwargs) -> Self:
"""Move model to device/dtype.""" """Move model to device/dtype."""

View File

@ -1,25 +0,0 @@
from astrai.model.components.attention import GQA, MLA, repeat_kv
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import (
RotaryEmbedding,
apply_rotary_emb,
get_rotary_emb,
)
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"Embedding",
"GQA",
"MLA",
"DecoderBlock",
"RotaryEmbedding",
"apply_rotary_emb",
"get_rotary_emb",
"repeat_kv",
]

View File

@ -1,212 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import apply_rotary_emb
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_heads, n_rep, head_dim)
.reshape(bs, slen, n_heads * n_rep, head_dim)
)
class AttnFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, attn_type: str, **kwargs) -> nn.Module:
return super().create(attn_type, **kwargs)
@AttnFactory.register("gqa")
class GQA(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
use_qk_norm: bool,
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
assert dim % n_heads == 0
assert n_heads % n_kv_heads == 0
self.head_dim = dim // n_heads
self.layer_id = layer_id
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(dim, n_heads * self.head_dim)
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
if self.use_gated_attention:
self.gate = Linear(dim, dim)
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
is_causal = attn_mask is None
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
.contiguous()
.flatten(2)
)
if self.use_gated_attention:
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out)
return out
@AttnFactory.register("mla")
class MLA(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
norm_eps: float,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
self.layer_id = layer_id
self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (2 * self.head_dim),
)
self.o_proj = Linear(dim, dim, bias=False)
if use_gated_attention:
self.gate = Linear(dim, dim, bias=False)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = attn_mask is None
q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
kv_compressed = self.kv_a_proj(x)
kv_compressed = self.kv_norm(kv_compressed)
kv = self.kv_b_proj(kv_compressed)
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
k_nope, k_rope, v = torch.split(
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
)
q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_nope_head_dim :],
)
q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb)
q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1)
if self.use_qk_norm:
q = self.q_norm(q)
k = self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask, is_causal=is_causal
)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention:
attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out)
return out

View File

@ -1,59 +0,0 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.attention import AttnFactory
from astrai.model.components.mlp import FFNFactory
from astrai.model.components.norm import RMSNorm
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: float,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
attn_type: str = "gqa",
ffn_type: str = "mlp",
**kwargs,
):
super().__init__()
self.attention = AttnFactory.create(
attn_type,
dim=dim,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
use_qk_norm=use_qk_norm,
norm_eps=norm_eps,
use_gated_attention=use_gated_attention,
layer_id=layer_id,
**kwargs,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -1,16 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def reset_parameters(self):
nn.init.normal_(self.weight, mean=0.0, std=0.02)
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -1,21 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / (fan_in**0.5)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)

View File

@ -1,194 +0,0 @@
import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional, Set
import torch
import torch.nn as nn
import torch.nn.functional as F
from astrai.model.components.linear import Linear
from astrai.serialization import (
load_json,
load_safetensors,
save_json,
save_safetensors,
)
logger = logging.getLogger(__name__)
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
TARGET_MODULES_FFN = {"up", "gate", "down"}
@dataclass
class LoRAConfig:
r: int = 16
alpha: int = 32
target_modules: tuple = ("q_proj", "v_proj")
class LoRALinear(nn.Module):
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
super().__init__()
self.register_parameter("weight", base.weight)
self.weight.requires_grad_(False)
self.bias = base.bias
if self.bias is not None:
self.bias.requires_grad_(False)
self.r = r
self.scaling = alpha / r
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
self._merged = False
def forward(self, x):
out = F.linear(x, self.weight, self.bias)
if not self._merged:
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
return out
def merge(self):
if self._merged:
return
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self._merged = True
del self.lora_A
del self.lora_B
def _collect_lora_info(model: nn.Module) -> dict:
names = {}
for n, m in model.named_modules():
if isinstance(m, Linear):
_, _, child = n.rpartition(".")
names.setdefault(child, []).append(n)
return names
def _get_lora_count(model: nn.Module) -> int:
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
def inject_lora(
model: nn.Module,
r: int = 16,
alpha: int = 32,
target_modules: Optional[Set[str]] = None,
) -> LoRAConfig:
if target_modules is None:
target_modules = TARGET_MODULES_ATTN
available = _collect_lora_info(model)
injected = 0
for name, module in list(model.named_modules()):
if not isinstance(module, Linear):
continue
parent_name, _, child_name = name.rpartition(".")
if child_name not in target_modules:
continue
parent = model.get_submodule(parent_name) if parent_name else model
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
injected += 1
if injected == 0:
logger.warning(
"No LoRA layers injected. Available Linear child names: %s. "
"target_modules: %s. Check model type and target_modules.",
sorted(available),
sorted(target_modules),
)
else:
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
def merge_lora(model: nn.Module):
n = 0
for module in model.modules():
if isinstance(module, LoRALinear):
module.merge()
n += 1
if n == 0:
logger.warning("No LoRA layers to merge.")
else:
logger.info("Merged %d LoRA layers", n)
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
lora_sd = {
k: v
for k, v in model.state_dict().items()
if k.endswith((".lora_A", ".lora_B"))
}
if not lora_sd:
raise RuntimeError(
"No LoRA parameters found in model. "
"The model may not have been injected or was already merged."
)
path = Path(save_dir)
path.mkdir(parents=True, exist_ok=True)
save_safetensors(lora_sd, path / "adapter_model.safetensors")
save_json(asdict(config), path / "adapter_config.json")
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
path = Path(load_dir)
raw = load_json(path / "adapter_config.json")
config = LoRAConfig(
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
)
existing = _get_lora_count(model)
if existing > 0:
logger.warning(
"Model already has %d LoRA layers. Skipping injection, "
"loading weights onto existing layers only.",
existing,
)
else:
inject_lora(
model,
r=config.r,
alpha=config.alpha,
target_modules=set(config.target_modules),
)
weights = load_safetensors(path / "adapter_model.safetensors")
try:
missing, unexpected = model.load_state_dict(weights, strict=False)
except RuntimeError as e:
msg = str(e)
if "size mismatch" in msg:
raise RuntimeError(
f"LoRA weight shapes do not match the model. "
f"The adapter config (r={config.r}) may not match the injected layers. "
f"Original error: {msg}"
) from e
raise
injected = _get_lora_count(model)
if injected == 0:
raise RuntimeError(
"No LoRA layers found after loading. "
"Inject LoRA before calling load_lora, or check the adapter config."
)
if missing:
lora_missing = [k for k in missing if "lora" in k]
if lora_missing:
raise RuntimeError(
f"LoRA weight keys not found in model: {lora_missing}. "
f"The adapter config (r={config.r}) may not match the model."
)
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
if unexpected:
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
logger.info("LoRA adapter loaded from %s", load_dir)
return config

View File

@ -1,93 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
return super().create(ffn_type, dim, dim_ffn, **kwargs)
@FFNFactory.register("mlp")
class MLP(nn.Module):
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.up = Linear(dim, dim_ffn)
self.gate = Linear(dim, dim_ffn)
self.down = Linear(dim_ffn, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
def __init__(
self,
dim: int,
dim_ffn: int,
n_routed_experts: int,
n_shared_experts: int = 1,
n_activated_experts: int = 2,
topk_method: str = "greedy",
):
super().__init__()
self.dim = dim
self.n_routed_experts = n_routed_experts
self.n_shared_experts = n_shared_experts
self.n_activated_experts = n_activated_experts
self.topk_method = topk_method
self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
)
self.routed_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
)
def forward(self, x: Tensor) -> Tensor:
bsz, seq_len, dim = x.shape
x_flat = x.view(-1, dim)
shared_out = self._shared_forward(x_flat)
routed_out = self._routed_forward(x_flat)
out = (shared_out + routed_out).view(bsz, seq_len, dim)
return out
def _shared_forward(self, x: Tensor) -> Tensor:
if self.n_shared_experts == 0:
return torch.zeros_like(x)
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
def _routed_forward(self, x: Tensor) -> Tensor:
N, D = x.shape
K = self.n_activated_experts
router_logits = self.router(x)
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
for expert_idx in range(self.n_routed_experts):
expert_mask = topk_indices == expert_idx
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue
expert_input = x[token_idx]
expert_output = self.routed_experts[expert_idx](expert_input)
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
output.index_add_(0, token_idx, expert_output * weights)
return output

View File

@ -1,15 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)

View File

@ -1,71 +0,0 @@
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def ntk_base(base: float, dim: int, factor: float) -> float:
return base * (factor ** (dim / (dim - 2)))
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim: int,
max_len: int,
base: float = 10000,
rope_scaling: Optional[Dict] = None,
):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self.rope_scaling = rope_scaling
if rope_scaling is not None:
scaling_type = rope_scaling.get("type", "ntk")
factor = rope_scaling.get("factor", 1.0)
if scaling_type == "ntk":
self.base = ntk_base(base, dim, factor)
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
position_freq_cis = self.freqs_cis[position_ids].float()
return torch.view_as_complex(position_freq_cis)

View File

@ -1,99 +0,0 @@
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import EncoderConfig
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
from astrai.model.transformer import process_attention_mask
@AutoModel.register("embedding")
class EmbeddingEncoder(AutoModel):
def __init__(self, config: EncoderConfig):
super().__init__(config)
self.config = config
rope_dim = config.dim // config.n_heads
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
[
DecoderBlock(
config.dim,
config.n_heads,
config.dim_ffn,
config.n_kv_heads,
config.norm_eps,
config.use_qk_norm,
config.use_gated_attention,
layer_id,
)
for layer_id in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, config.norm_eps)
self.pooling_type = config.pooling_type or "mean"
self.normalize_embeddings = config.normalize_embeddings or False
self.apply(self._init_weights)
def _init_weights(self, module):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
state_dict = dict(state_dict)
state_dict.pop("lm_head.weight", None)
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def forward(
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
) -> Tensor:
assert input_ids.ndim == 2
B, S = input_ids.shape
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
hidden_states = self.norm(x)
if self.pooling_type == "cls":
pooled = hidden_states[:, 0]
elif self.pooling_type == "last":
if input_mask is not None:
lengths = input_mask.sum(dim=1) - 1
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
else:
pooled = hidden_states[:, -1]
else:
if input_mask is not None:
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
min=1.0
)
else:
pooled = hidden_states.mean(dim=1)
if self.normalize_embeddings:
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
return pooled

382
astrai/model/module.py Normal file
View File

@ -0,0 +1,382 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_heads, n_rep, head_dim)
.reshape(bs, slen, n_heads * n_rep, head_dim)
)
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (optional): The device to create tensors on. Defaults to None.
Returns:
Tensor: The rotary embedding tensor.
"""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""
Apply rotary embedding to the input tensor using cos/sin form.
Args:
x (Tensor): The input tensor (shape [..., seq_len, dim]).
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
Returns:
Tensor: The output tensor (rotated, same shape as input).
"""
dtype = x.dtype
cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self.max_len_cached = None
self._set_rotary_buffer(self.max_len, None)
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin)
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
class GQA(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
use_qk_norm: bool,
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
assert dim % n_heads == 0
assert n_heads % n_kv_heads == 0
self.head_dim = dim // n_heads
self.layer_id = layer_id
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(dim, n_heads * self.head_dim)
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
if self.use_gated_attention:
self.gate = Linear(dim, dim)
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None:
k_cache, v_cache = kv_cache
# copy to cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
# get cache
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
.contiguous()
.flatten(2)
)
if self.use_gated_attention:
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out)
return out
class MLA(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
self.layer_id = layer_id
self.n_rep = n_heads // n_kv_heads
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# KV (k_nope, k_rope, v)
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
)
self.o_proj = Linear(dim, dim, bias=False)
if use_gated_attention:
self.gate = Linear(dim, dim, bias=False)
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
kv_compressed = self.kv_a_proj(x)
kv_compressed = self.kv_norm(kv_compressed)
kv = self.kv_b_proj(kv_compressed)
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
k_nope, k_rope, v = torch.split(
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
)
q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_rope_head_dim :],
)
q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb)
q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1)
if kv_cache is not None:
k_cache, v_cache = kv_cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention:
attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out)
return out
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.attention = GQA(
dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0,
) -> Tensor:
# attention
attn_output = self.attention(
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
)
x = attn_output + x
# feed forward
x = self.mlp(self.post_attention_norm(x)) + x
return x
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -1,63 +1,83 @@
from typing import Any, Dict, Mapping, Optional from typing import Any, Mapping, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import AutoRegressiveLMConfig from astrai.config.model_config import ModelConfig
from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock from astrai.model.module import (
from astrai.model.components.embedding import Embedding DecoderBlock,
from astrai.model.components.linear import Linear Embedding,
from astrai.model.components.norm import RMSNorm Linear,
from astrai.model.components.rope import RotaryEmbedding RMSNorm,
RotaryEmbedding,
)
def process_attention_mask( def process_attention_mask(
seq_mask: Tensor,
input_tensor: Tensor, input_tensor: Tensor,
position_ids: Optional[Tensor], start_pos: int = 0,
input_mask: Optional[Tensor] = None,
is_causal: bool = False, is_causal: bool = False,
) -> Optional[Tensor]: ) -> Tensor:
if position_ids is None: """
return None Create attention mask for GQA
if input_mask is not None and input_mask.dim() > 2: Args:
return input_mask seq_mask (Tensor): A tensor indicating whether each position is valid or not.
input_tensor (Tensor): The input tensor.
start_pos (int): The starting position of the sequence.
is_causal (bool): Whether the attention is causal or not.
Returns:
Tensor: The attention mask tensor.
"""
device = input_tensor.device device = input_tensor.device
B = input_tensor.size(0) dtype = input_tensor.dtype
T = position_ids.max().item() + 1 seq_len = input_tensor.size(1)
if input_mask is None: if seq_mask is None:
if position_ids.min().item() == 0 and is_causal: if start_pos != 0:
return None # for single prompt chat
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device) seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else: else:
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1) return None
if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
# (bsz, seq_len, start_pos + seq_len)
if is_causal: if is_causal:
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device) expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
attend = attend & causal
return attend.unsqueeze(1) attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask
@AutoModel.register("autoregressive_lm") @AutoModel.register("transformer")
class AutoRegressiveLM(AutoModel): class Transformer(AutoModel):
"""Autoregressive language model with paged KV cache.""" """
Transformer language model.
"""
def __init__(self, config: AutoRegressiveLMConfig): def __init__(self, config: ModelConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
rope_dim = (
config.qk_rope_head_dim
if config.attn_type == "mla"
else config.dim // config.n_heads
)
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding( self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling config.dim // config.n_heads, config.max_len
) )
self.embed_tokens = Embedding(config.vocab_size, config.dim) self.embed_tokens = Embedding(config.vocab_size, config.dim)
@ -72,15 +92,6 @@ class AutoRegressiveLM(AutoModel):
config.use_qk_norm, config.use_qk_norm,
config.use_gated_attention, config.use_gated_attention,
layer_id, layer_id,
attn_type=config.attn_type,
ffn_type=config.ffn_type,
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.topk_method,
kv_lora_rank=config.kv_lora_rank,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
) )
for layer_id in range(config.n_layers) for layer_id in range(config.n_layers)
] ]
@ -89,28 +100,32 @@ class AutoRegressiveLM(AutoModel):
self.norm = RMSNorm(config.dim, config.norm_eps) self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size) self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight is True: if self.config.tie_weight:
self.lm_head.weight = self.embed_tokens.weight self.lm_head.weight = self.embed_tokens.weight
self.apply(self._init_weights) self._init_weights()
def _init_weights(self, module): def _init_weights(self):
if hasattr(module, "reset_parameters"): for param in self.parameters():
module.reset_parameters() if param.dim() > 1:
nn.init.normal_(param, mean=0.0, std=0.006)
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False): def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = "lm_head.weight" lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight" embed_key = "embed_tokens.weight"
# Make a copy to avoid modifying the original state_dict
state_dict = dict(state_dict) state_dict = dict(state_dict)
if self.config.tie_weight is True: if self.config.tie_weight:
# same tensor for embed and lm_head # same tensor
if embed_key in state_dict: if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key] state_dict[lm_head_key] = state_dict[embed_key]
else: else:
# If lm_head.weight exists in checkpoint, use it directly
# If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict: if lm_head_key not in state_dict and embed_key in state_dict:
# clone to avoid sharing gradients # use clone to avoid sharing the same tensor
state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign) return super().load_state_dict(state_dict, strict, assign)
@ -120,7 +135,7 @@ class AutoRegressiveLM(AutoModel):
destination=destination, prefix=prefix, keep_vars=keep_vars destination=destination, prefix=prefix, keep_vars=keep_vars
) )
if self.config.tie_weight is True: if self.config.tie_weight:
lm_head_key = prefix + "lm_head.weight" lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict: if lm_head_key in state_dict:
del state_dict[lm_head_key] del state_dict[lm_head_key]
@ -131,17 +146,18 @@ class AutoRegressiveLM(AutoModel):
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None, persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
position_ids: Optional[Tensor] = None, start_pos: int = 0,
) -> Dict[str, Tensor]: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, position_ids) rotary_emb = self.rotary_embedding(x, start_pos)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
for layer in self.layers: for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache) x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)

View File

@ -1,13 +1,3 @@
from astrai.parallel.executor import (
AccumOptimizer,
AccumScheduler,
BaseExecutor,
DDPExecutor,
ExecutorFactory,
FSDPExecutor,
GradientState,
NoneExecutor,
)
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import ( from astrai.parallel.setup import (
get_current_device, get_current_device,
@ -27,12 +17,4 @@ __all__ = [
"spawn_parallel_fn", "spawn_parallel_fn",
"RowParallelLinear", "RowParallelLinear",
"ColumnParallelLinear", "ColumnParallelLinear",
"ExecutorFactory",
"BaseExecutor",
"GradientState",
"AccumOptimizer",
"AccumScheduler",
"NoneExecutor",
"DDPExecutor",
"FSDPExecutor",
] ]

View File

@ -1,272 +0,0 @@
"""Unified training executor — parallel strategy + gradient accumulation."""
import contextlib
import logging
import os
from contextlib import contextmanager
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from astrai.factory import BaseFactory
from astrai.parallel.setup import get_rank, get_world_size
logger = logging.getLogger(__name__)
class GradientState:
def __init__(self, grad_accum_steps: int = 1):
self.num_steps = max(grad_accum_steps, 1)
self._step: int = 0
self._sync_gradients: bool = True
@property
def sync_gradients(self) -> bool:
return self._sync_gradients
def _do_sync(self):
self._step += 1
self._sync_gradients = self._step % self.num_steps == 0
class AccumOptimizer:
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
self.optimizer = optimizer
self.gradient_state = gradient_state
def step(self, closure=None):
if self.gradient_state.sync_gradients:
self.optimizer.step(closure)
def zero_grad(self):
if self.gradient_state.sync_gradients:
self.optimizer.zero_grad()
@property
def param_groups(self):
return self.optimizer.param_groups
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, d):
self.optimizer.load_state_dict(d)
class AccumScheduler:
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
self.scheduler = scheduler
self.gradient_state = gradient_state
def step(self):
if self.gradient_state.sync_gradients:
self.scheduler.step()
def state_dict(self):
return self.scheduler.state_dict()
def load_state_dict(self, d):
self.scheduler.load_state_dict(d)
def get_last_lr(self):
return self.scheduler.get_last_lr()
class BaseExecutor:
def __init__(self, grad_accum_steps: int = 1):
self.gradient_state = GradientState(grad_accum_steps)
def prepare(
self,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
dataloader: Optional[DataLoader] = None,
scheduler: Optional[LRScheduler] = None,
) -> Tuple[
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
]:
model = self._prepare_model(model)
if optimizer is not None:
optimizer = AccumOptimizer(optimizer, self.gradient_state)
if scheduler is not None:
scheduler = AccumScheduler(scheduler, self.gradient_state)
return model, optimizer, dataloader, scheduler
def _prepare_model(self, model: nn.Module) -> nn.Module:
return model
def _no_sync(self, model: nn.Module):
return contextlib.nullcontext()
@contextmanager
def accumulate(self, model: nn.Module):
self.gradient_state._do_sync()
if not self.gradient_state.sync_gradients:
with self._no_sync(model):
yield
else:
yield
def backward(self, loss: torch.Tensor):
loss.backward()
def unwrap_model(self, model: nn.Module):
return model.state_dict()
@property
def use_distributed(self) -> bool:
return get_world_size() > 1
@property
def sync_gradients(self) -> bool:
return self.gradient_state.sync_gradients
@property
def grad_accum_steps(self) -> int:
return self.gradient_state.num_steps
class ExecutorFactory(BaseFactory[BaseExecutor]):
pass
@ExecutorFactory.register("none")
class NoneExecutor(BaseExecutor):
pass
@ExecutorFactory.register("ddp")
class DDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
dim: int = 0,
broadcast_buffers: bool = True,
init_sync: bool = True,
process_group=None,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
delay_all_reduce_named_params=None,
param_to_hook_all_reduce=None,
mixed_precision=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._ddp_kwargs = dict(
dim=dim,
broadcast_buffers=broadcast_buffers,
init_sync=init_sync,
process_group=process_group,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
delay_all_reduce_named_params=delay_all_reduce_named_params,
param_to_hook_all_reduce=param_to_hook_all_reduce,
mixed_precision=mixed_precision,
device_mesh=device_mesh,
)
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("DDP backend selected but world_size=1, model not wrapped")
return model
local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
**self._ddp_kwargs,
)
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, DDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module):
if isinstance(model, DDP):
return model.module.state_dict()
return model.state_dict()
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
ignored_states=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = {
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=True,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
return model
self._original_model = model
device_id = torch.device("cuda", get_rank())
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, FSDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module):
if isinstance(model, FSDP) and self.use_distributed:
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
return model.state_dict()
return model.state_dict()

View File

@ -1,8 +1,7 @@
import os import os
from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
from typing import Callable from typing import Callable, List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -31,11 +30,11 @@ def get_rank() -> int:
def setup_parallel( def setup_parallel(
rank: int, rank: int,
world_size: int, world_size: int,
local_rank: int,
backend: str = "nccl", backend: str = "nccl",
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
device_ids: Optional[List[int]] = None,
): ):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
@ -43,18 +42,19 @@ def setup_parallel(
return return
if world_size <= 1: if world_size <= 1:
device_id = torch.device(device_type, local_rank)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
yield None yield None
return return
device_id = torch.device(device_type, local_rank) if device_ids is None:
device_ids = [i for i in range(world_size)]
rank = device_ids[rank % len(device_ids)]
device_id = torch.device(device_type, device_ids[rank])
os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id) os.environ["LOCAL_DEVICE"] = str(device_id)
@ -96,120 +96,34 @@ def only_on_rank(rank, sync=False):
return decorator return decorator
def _run_single_rank( def wrapper_spawn_func(
rank: int, rank: int,
world_size: int, world_size: int,
backend: str, backend: str,
master_addr: str, master_addr: str,
master_port: str, master_port: str,
device_type: str, device_type: str,
device_ids: List[int],
func: Callable, func: Callable,
kwargs: dict, kwargs: dict,
): ):
try:
with setup_parallel( with setup_parallel(
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
local_rank=rank,
backend=backend, backend=backend,
master_addr=master_addr, master_addr=master_addr,
master_port=master_port, master_port=master_port,
device_type=device_type, device_type=device_type,
device_ids=device_ids,
): ):
func(**kwargs) func(**kwargs)
except Exception as e:
class LaunchStrategy(ABC): print(f"Error in rank {rank}: {e}")
"""Strategy for launching a function in a distributed context."""
def __init__(
self,
world_size: int,
backend: str,
master_addr: str,
master_port: str,
device_type: str,
start_method: str,
):
self.world_size = world_size
self.backend = backend
self.master_addr = master_addr
self.master_port = master_port
self.device_type = device_type
self.start_method = start_method
@abstractmethod
def launch(self, func: Callable, **kwargs):
raise NotImplementedError
class TorchrunStrategy(LaunchStrategy):
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
def launch(self, func: Callable, **kwargs):
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank))
with setup_parallel(
rank=rank,
world_size=world_size,
local_rank=local_rank,
backend=self.backend,
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
master_port=os.environ.get("MASTER_PORT", self.master_port),
device_type=self.device_type,
):
func(**kwargs)
class LocalStrategy(LaunchStrategy):
"""Local launcher — single-process or mp.start_processes."""
def launch(self, func: Callable, **kwargs):
args = (
self.world_size,
self.backend,
self.master_addr,
self.master_port,
self.device_type,
func,
kwargs,
)
if self.world_size == 1:
_run_single_rank(0, *args)
return
ctx = mp.start_processes(
_run_single_rank,
args=args,
nprocs=self.world_size,
start_method=self.start_method,
join=False,
)
try:
while not ctx.join():
pass
except BaseException:
for p in ctx.processes:
p.terminate()
ctx.join()
raise raise
def _detect_launcher() -> str:
"""Detect the distributed launcher from environment.
Returns one of: "torchelastic", "torchrun", "external", "local".
"""
if dist.is_torchelastic_launched():
return "torchelastic"
if "LOCAL_WORLD_SIZE" in os.environ:
return "torchrun"
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
return "external"
return "local"
def spawn_parallel_fn( def spawn_parallel_fn(
func: Callable, func: Callable,
world_size: int, world_size: int,
@ -217,16 +131,40 @@ def spawn_parallel_fn(
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
start_method: str = "spawn", device_ids: Optional[List[int]] = None,
**kwargs, **kwargs,
): ):
launcher = _detect_launcher() # clear environment variables
if launcher in ("torchelastic", "torchrun", "external"): for key in [
strategy = TorchrunStrategy( "MASTER_ADDR",
world_size, backend, master_addr, master_port, device_type, start_method "MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ:
del os.environ[key]
if world_size == 1:
device_ids = device_ids or [0]
device_id = torch.device(device_type, device_ids[0])
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return
wrapper_spawn_func_args = (
world_size,
backend,
master_addr,
master_port,
device_type,
device_ids,
func,
kwargs,
) )
else:
strategy = LocalStrategy( mp.spawn(
world_size, backend, master_addr, master_port, device_type, start_method wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
) )
strategy.launch(func, **kwargs)

View File

@ -1,14 +0,0 @@
from astrai.preprocessing.builder import (
BaseMaskBuilder,
MaskBuilderFactory,
SectionedMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
__all__ = [
"BaseMaskBuilder",
"MaskBuilderFactory",
"SectionedMaskBuilder",
"Pipeline",
"filter_by_length",
]

View File

@ -1,338 +0,0 @@
"""Mask building strategies for preprocessing pipeline.
The single :class:`SectionedMaskBuilder` handles all input formats
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
for single-output or ``input.sources`` for multi-output.
"""
from abc import ABC, abstractmethod
from typing import Optional
from astrai.factory import BaseFactory
class BaseMaskBuilder(ABC):
"""Convert a JSONL item into token ids and optional loss_mask."""
@abstractmethod
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
Returns ``None`` to skip the item entirely.
"""
...
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
if not domain_key:
return "__default__"
val = item.get(domain_key, "__default__")
return val if isinstance(val, str) else "__default__"
def _resolve_action(action: str, role: str, config) -> str:
"""Resolve action to "train" or "mask".
- ``"train"`` / ``"mask"`` literal
- ``"$role"`` look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
"""
if action == "$role":
return config.mask.get(role, config.mask_default)
return action
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder supporting single and multi-output modes.
Single-output (backward-compatible)::
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
{"sequence": [...], "loss_mask": [...], "domain": "..."}
Multi-output (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [
{"field": "chosen", "action": "$role", "template": true}
]},
"rejected": {"sections": [
{"field": "rejected", "action": "$role", "template": true}
]}
}}}
{"chosen": [...], "chosen_mask": [...],
"rejected": [...], "rejected_mask": [...], "domain": "..."}
Output spec fields::
sections list of section specs (same format as single-output)
list_field True when the JSONL field holds a list of values to
tokenise individually and concatenate (GRPO responses)
mask_key explicit output key for the loss mask
(default: ``"{output_key}_mask"``)
dtype explicit tensor dtype for this output key
(default: "int32")
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
sources_spec = getattr(config.input, "sources", None)
if sources_spec:
return self._build_multi(item, sources_spec, config, tokenizer)
return self._build_single(item, config, tokenizer)
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
sections = config.input.sections
if not sections:
return None
ids, mask = self._process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
return None
result: dict = {
"sequence": ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in mask):
result["loss_mask"] = mask
return result
def _build_multi(
self, item: dict, sources_spec: dict, config, tokenizer
) -> Optional[dict]:
result: dict = {}
any_output = False
for output_key, spec in sources_spec.items():
sections = spec.get("sections", [])
if not sections:
continue
if self._is_value_section(sections):
ids = self._extract_raw_value(item, sections)
if ids is None:
continue
result[output_key] = ids
any_output = True
continue
list_field = spec.get("list_field", False)
mask_key = spec.get("mask_key", f"{output_key}_mask")
if list_field:
ids, mask = self._process_list_field(item, sections, config, tokenizer)
else:
ids, mask = self._process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
continue
result[output_key] = ids
if not all(m == 1 for m in mask):
result[mask_key] = mask
elif "mask_key" in spec:
result[mask_key] = mask
any_output = True
if not any_output:
return None
result["domain"] = _extract_domain(item, config.output.domain_key)
return result
@staticmethod
def _is_value_section(sections: list) -> bool:
return len(sections) == 1 and sections[0].get("action") == "value"
@staticmethod
def _extract_raw_value(item: dict, sections: list):
"""Extract a raw value from a JSONL field without tokenisation.
Used for GRPO rewards where the field contains float values.
"""
sec = sections[0]
field = sec["field"]
raw = item.get(field)
if raw is None:
return None
if isinstance(raw, list):
return [float(v) for v in raw]
return [float(raw)]
def _process_sections(
self,
item: dict,
sections: list,
config,
tokenizer,
*,
is_top_level: bool = False,
):
"""Process a list of sections into ``(ids, loss_mask)``.
Returns ``(None, None)`` if the item should be skipped.
"""
all_ids: list[int] = []
loss_mask: list[int] = []
has_template = any(s.get("template") for s in sections)
is_text_config = not has_template and all(
s["action"] == "train" for s in sections
)
if is_top_level and has_template and tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id)
loss_mask.append(0)
first_section = True
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
add_special = sec.get(
"add_special_tokens", not use_template and first_section
)
if use_template:
success = self._append_template_section(
item, field, action, tokenizer, config, all_ids, loss_mask
)
if not success:
continue
else:
success = self._append_text_section(
item,
field,
action,
tokenizer,
add_special,
is_text_config,
config,
all_ids,
loss_mask,
)
if not success:
continue
first_section = False
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None, None
if is_top_level and has_template and len(all_ids) <= 1:
return None, None
return all_ids, loss_mask
def _append_template_section(
self, item, field, action, tokenizer, config, all_ids, loss_mask
):
messages = item.get(field)
if not isinstance(messages, list) or not messages:
return False
for msg in messages:
role = msg.get("role", "")
act = _resolve_action(action, role, config)
rendered = tokenizer.apply_chat_template(
[msg], tokenize=False, add_generation_prompt=False
)
ids = tokenizer.encode(rendered, add_special_tokens=False)
all_ids.extend(ids)
val = 1 if act == "train" else 0
loss_mask.extend([val] * len(ids))
return True
def _append_text_section(
self,
item,
field,
action,
tokenizer,
add_special,
is_text_config,
config,
all_ids,
loss_mask,
):
text = str(item.get(field, ""))
if not text.strip():
return False
if is_text_config:
pp = config.preprocessing
if pp.min_chars > 0 and len(text) < pp.min_chars:
return False
if len(text) > pp.max_chars:
return False
ids = tokenizer.encode(text, add_special_tokens=add_special)
all_ids.extend(ids)
val = 1 if action == "train" else 0
loss_mask.extend([val] * len(ids))
return True
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
all_ids: list[int] = []
loss_mask: list[int] = []
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
values = item.get(field)
if not isinstance(values, list):
continue
for val in values:
if use_template:
if isinstance(val, list):
wrapper = {field: val}
self._append_template_section(
wrapper,
field,
action,
tokenizer,
config,
all_ids,
loss_mask,
)
else:
wrapper = {field: str(val)}
self._append_text_section(
wrapper,
field,
action,
tokenizer,
False,
False,
config,
all_ids,
loss_mask,
)
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None, None
return all_ids, loss_mask

View File

@ -1,257 +0,0 @@
"""Config-driven JSONL preprocessing pipeline.
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
sharding and flush to ``.h5`` / ``.bin`` storage.
"""
import json
import os
from collections import defaultdict
from itertools import chain
from typing import List, Optional, Tuple
import torch
import tqdm
from astrai.config.preprocess_config import PipelineConfig
from astrai.dataset.storage import save_bin, save_h5
from astrai.preprocessing.builder import SectionedMaskBuilder
from astrai.tokenize import AutoTokenizer
_STR_TO_DTYPE: dict[str, torch.dtype] = {
"bool": torch.bool,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
}
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
return min_len <= len(text) <= max_len
def _truncate(seq: list, max_len: int, mode: str) -> list:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
def pack_sequences(
sequences: List[list],
max_packed_len: int,
strategy: str,
truncation_mode: str,
) -> List[Tuple[int, int]]:
"""Pack *sequences* into bins and return a reorder plan.
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
All keys (sequence, loss_mask, ) must be reordered and truncated
identically according to this plan.
Supported *strategy* values:
- ``"simple"``: sequential, no reordering.
- ``"bfd"``: best-fit decreasing bin packing.
"""
n = len(sequences)
if strategy == "simple":
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []
bin_lengths: List[int] = []
for orig_idx in order:
seq_len = min(len(sequences[orig_idx]), max_packed_len)
best_bin = None
best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths):
remain = max_packed_len - bl
if seq_len <= remain < best_remain:
best_remain = remain
best_bin = i
if best_bin is not None:
bins[best_bin].append(orig_idx)
bin_lengths[best_bin] += seq_len
else:
bins.append([orig_idx])
bin_lengths.append(seq_len)
plan: List[Tuple[int, int]] = []
for bin_indices in bins:
for orig_idx in bin_indices:
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
return plan
class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
Usage::
config = PipelineConfig.from_json("sft_pipeline.json")
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
"""
def __init__(
self,
config: PipelineConfig,
input_paths: list[str],
output_dir: str,
tokenizer_path: str,
):
os.makedirs(output_dir, exist_ok=True)
self.config = config
self.paths = input_paths
self.output_dir = output_dir
self.tokenizer_path = tokenizer_path
self.mask_builder = SectionedMaskBuilder()
def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer)
def run(self):
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
domains: dict = defaultdict(lambda: defaultdict(list))
total_tokens = 0
shard_idx: dict[str, int] = defaultdict(int)
count = 0
pp = self.config.preprocessing
for item in tqdm.tqdm(
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
):
if pp.max_items and count >= pp.max_items:
break
result = self.transform(item)
if result is None:
continue
domain = result.pop("domain", "__default__")
is_multi = bool(getattr(self.config.input, "sources", None))
if is_multi:
ids = self._primary_ids(result)
else:
ids = result.pop("sequence")
result["sequence"] = ids
if not ids:
continue
bucket = domains[domain]
self._align_bucket(bucket, result, ids, is_multi)
for key, val in result.items():
bucket[key].append(val)
count += 1
total_tokens += len(ids)
if total_tokens >= self.config.output.max_tokens_per_shard:
self._flush(domains, shard_idx)
domains.clear()
total_tokens = 0
if total_tokens > 0:
self._flush(domains, shard_idx)
print(f"Done. {count} documents tokenized.")
@staticmethod
def _primary_ids(result: dict) -> list:
"""Return the first list-valued entry in *result* as the primary id
sequence for token counting."""
for val in result.values():
if isinstance(val, list) and val and isinstance(val[0], int):
return val
return []
@staticmethod
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
"""Pad previously-accumulated keys that are missing from *result*."""
for key in list(bucket.keys()):
if key in result:
continue
if is_multi:
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
bucket[key].append(pad)
else:
bucket[key].append([1] * len(ids))
def _iter_items(self):
for path in self.paths:
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _flush(self, domains, shard_idx):
for domain, keys in domains.items():
idx = shard_idx[domain]
chunk_dir = os.path.join(self.output_dir, domain)
pp = self.config.preprocessing
if pp.packing_strategy != "simple" and "sequence" in keys:
plan = pack_sequences(
keys["sequence"],
pp.max_packed_len,
pp.packing_strategy,
pp.truncation_mode,
)
reordered = defaultdict(list)
for orig_idx, truncated_len in plan:
for k, vals in keys.items():
reordered[k].append(
_truncate(
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
)
)
keys = reordered
tensors = {}
for key, ids_list in keys.items():
dt = _STR_TO_DTYPE.get(
self.config.output.dtype.get(key, "int32"), torch.int32
)
tensors[key] = [
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
]
pid_mode = self.config.output.position_ids_mode
if pid_mode and pid_mode != "none" and "sequence" in tensors:
pos_ids = []
if pid_mode == "doc_reset":
for item in keys["sequence"]:
pos_ids.extend(range(len(item)))
else:
total = sum(len(item) for item in keys["sequence"])
pos_ids = list(range(total))
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
fmt = self.config.output.storage_format
if fmt == "bin":
save_bin(shard_path, tensors)
else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
tqdm.tqdm.write(
f" saved {domain}/shard_{idx:04d} "
f"({tensors[first_key][0].numel():,} tokens)"
)

View File

@ -1,21 +0,0 @@
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class OptimizerProtocol(Protocol):
def step(self, closure=None): ...
def zero_grad(self): ...
@property
def param_groups(self) -> Any: ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
@runtime_checkable
class SchedulerProtocol(Protocol):
def step(self): ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
def get_last_lr(self): ...

View File

@ -1,182 +1,106 @@
import io
import json import json
import time import os
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Union from typing import Any, Dict, List
import h5py
import safetensors.torch as st import safetensors.torch as st
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import Tensor
from astrai.parallel.setup import get_rank from astrai.parallel.setup import get_rank
_META_FILE = "meta.json"
_CONFIG_FILE = "config.json" def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
_WEIGHTS_FILE = "model.safetensors" os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def save_safetensors(state_dict: dict, path: Union[str, Path]): def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
st.save_file(state_dict, str(path)) tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return st.load_file(str(path))
rank = get_rank()
if rank == 0:
state_dict = st.load_file(str(path))
else:
state_dict = {}
tmp = [state_dict]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_json(data: dict, path: Union[str, Path]):
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
with open(str(path), "r") as f:
return json.load(f)
rank = get_rank()
if rank == 0:
with open(str(path), "r") as f:
data = json.load(f)
else:
data = {}
tmp = [data]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_torch(obj: Any, path: Union[str, Path]):
torch.save(obj, str(path))
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
path = Path(path)
if not broadcast or not dist.is_initialized():
return load_safetensors(path)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(path)
specs = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
else:
state_dict = {}
specs = []
specs_list = [specs]
dist.broadcast_object_list(specs_list, src=0)
specs = specs_list[0]
for key, shape, dtype_name in specs:
dtype = getattr(torch, dtype_name)
if rank != 0:
tensor = torch.empty(shape, dtype=dtype, device="cpu")
else:
tensor = state_dict[key].contiguous().cpu()
dist.broadcast(tensor, src=0)
if rank != 0:
state_dict[key] = tensor
return state_dict
@dataclass
class Checkpoint: class Checkpoint:
state_dict: Dict[str, Any] = field(default_factory=dict) def __init__(
epoch: int = 0 self,
iteration: int = 0 state_dict: Dict[str, Any],
extra: Dict[str, Any] = field(default_factory=dict) epoch: int = 0,
meta: Dict[str, Any] = field(default_factory=dict) iteration: int = 0,
config: Dict[str, Any] = field(default_factory=dict) ):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
def save(
self,
save_dir: str,
) -> None:
def save(self, save_dir: str):
save_path = Path(save_dir) save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)
if get_rank() != 0: rank = get_rank()
return if rank == 0:
meta = { meta = {
"epoch": self.epoch, "epoch": self.epoch,
"iteration": self.iteration, "iteration": self.iteration,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
**self.meta,
} }
save_json(meta, save_path / _META_FILE) with open(save_path / "meta.json", "w") as f:
save_json(self.config, save_path / _CONFIG_FILE) json.dump(meta, f, indent=2)
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
for key, value in self.extra.items(): st.save_file(self.state_dict, save_path / "state_dict.safetensors")
save_torch(value, save_path / f"{key}.pt")
@classmethod @classmethod
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": def load(
cls,
save_dir: str,
) -> "Checkpoint":
rank = get_rank()
save_path = Path(save_dir) save_path = Path(save_dir)
meta = load_json(save_path / _META_FILE, broadcast) meta = {}
config = load_json(save_path / _CONFIG_FILE, broadcast) if rank == 0:
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast) with open(Path(save_dir) / "meta.json", "r") as f:
meta = json.load(f)
extra = {} if dist.is_initialized():
for f in sorted(save_path.iterdir()): meta_list = [meta]
if f.suffix == ".pt": dist.broadcast_object_list(meta_list, src=0)
extra[f.stem] = load_torch(f, broadcast=broadcast) meta = meta_list[0]
state_dict = st.load_file(save_path / "state_dict.safetensors")
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,
epoch=meta.get("epoch", 0), epoch=meta["epoch"],
iteration=meta.get("iteration", 0), iteration=meta["iteration"],
extra=extra,
config=config,
) )

View File

@ -1,10 +1,13 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from jinja2 import Template from jinja2 import Template
# Message type for chat messages
type MessageType = Dict[str, Any] type MessageType = Dict[str, Any]
@dataclass
class ChatTemplate: class ChatTemplate:
"""A chat template with Jinja2 rendering support. """A chat template with Jinja2 rendering support.
@ -12,24 +15,23 @@ class ChatTemplate:
name: Unique identifier for the template. name: Unique identifier for the template.
template_str: Jinja2 template string. template_str: Jinja2 template string.
description: Optional description. description: Optional description.
default_variables: Optional dictionary of default variable values. default_variables: Optional dictionary of default variable values
that will be passed to the template if not overridden during rendering.
special_tokens: Optional dictionary mapping token names to their string values. special_tokens: Optional dictionary mapping token names to their string values.
These tokens are automatically added to the template variables.
""" """
def __init__( name: str
self, template_str: str
name: str = "", description: str = ""
template_str: str = "", default_variables: Dict[str, Any] = None
description: str = "", special_tokens: Dict[str, str] = None
default_variables: Optional[Dict[str, Any]] = None,
special_tokens: Optional[Dict[str, str]] = None, def __post_init__(self):
): if self.default_variables is None:
self.name = name self.default_variables = {}
self.template_str = template_str if self.special_tokens is None:
self.description = description self.special_tokens = {}
self.default_variables = default_variables or {}
self.special_tokens = special_tokens or {}
self._compiled: Template = Template(template_str)
@classmethod @classmethod
def from_string( def from_string(
@ -41,7 +43,7 @@ class ChatTemplate:
) -> "ChatTemplate": ) -> "ChatTemplate":
"""Create a ChatTemplate instance directly from a template string.""" """Create a ChatTemplate instance directly from a template string."""
return cls( return cls(
name="", name="", # empty name for adhoc templates
template_str=template_str, template_str=template_str,
description=description, description=description,
default_variables=default_variables, default_variables=default_variables,
@ -71,4 +73,5 @@ class ChatTemplate:
if system_prompt is not None: if system_prompt is not None:
variables["system_prompt"] = system_prompt variables["system_prompt"] = system_prompt
return self._compiled.render(**variables) jinja_template = Template(self.template_str)
return jinja_template.render(**variables)

View File

@ -51,26 +51,9 @@ class AutoTokenizer:
self.set_chat_template(config["chat_template"]) self.set_chat_template(config["chat_template"])
@classmethod @classmethod
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer": def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory. """Load tokenizer from pretrained directory."""
Raises:
FileNotFoundError: If tokenizer.json is missing.
RuntimeError: If tokenizer failed to initialize.
"""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
if not tokenizer_file.exists():
raise FileNotFoundError(
f"Tokenizer file not found: {tokenizer_file}. "
"A valid tokenizer.json is required."
)
instance = cls(path) instance = cls(path)
if instance._tokenizer is None:
raise RuntimeError(
f"Failed to load tokenizer from {path}. "
"The tokenizer.json may be corrupted or incompatible."
)
return instance return instance
def save_pretrained(self, save_path: str): def save_pretrained(self, save_path: str):
@ -81,11 +64,6 @@ class AutoTokenizer:
save_path: Path to save the tokenizer save_path: Path to save the tokenizer
""" """
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
save_path = Path(save_path) save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)

View File

@ -1,4 +1,3 @@
from astrai.trainer.optim import Muon
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import ( from astrai.trainer.train_callback import (
@ -10,8 +9,6 @@ from astrai.trainer.trainer import Trainer
__all__ = [ __all__ = [
# Main trainer # Main trainer
"Trainer", "Trainer",
# Optimizer
"Muon",
# Strategy factory # Strategy factory
"StrategyFactory", "StrategyFactory",
"BaseStrategy", "BaseStrategy",

View File

@ -1,42 +1,75 @@
from typing import Any, Callable, Dict from typing import Dict
import torch
import torch.nn as nn import torch.nn as nn
def _grad_stat(
model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any
) -> dict:
results = {}
for name, param in model.named_parameters():
results[name] = default
if param.grad is not None:
results[name] = fn(param.grad.data)
return results
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0) """Compute gradient norm for each parameter in the model."""
norms = {}
for name, param in model.named_parameters():
norms[name] = 0.0
if param.grad:
norm = param.grad.data.norm(norm_type).item()
norms[name] = norm
return norms
def grad_std(model: nn.Module) -> Dict[str, float]: def grad_std(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.std().item(), 0.0) """Compute standard deviation of gradients for each parameter."""
stds = {}
for name, param in model.named_parameters():
stds[name] = 0.0
if param.grad:
std = param.grad.data.std().item()
stds[name] = std
return stds
def grad_max(model: nn.Module) -> Dict[str, float]: def grad_max(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.max().item(), -float("inf")) """Find the maximum absolute gradient value for each parameter."""
max_vals = {}
for name, param in model.named_parameters():
max_vals[name] = -float("inf")
if param.grad:
max_val = param.grad.data.max().item()
max_vals[name] = max_val
return max_vals
def grad_min(model: nn.Module) -> Dict[str, float]: def grad_min(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.min().item(), float("inf")) """Find the minimum absolute gradient value for each parameter."""
min_vals = {}
for name, param in model.named_parameters():
min_vals[name] = float("inf")
if param.grad:
min_val = param.grad.data.min().item()
min_vals[name] = min_val
return min_vals
def grad_mean(model: nn.Module) -> Dict[str, float]: def grad_mean(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.mean().item(), 0.0) """Compute mean of gradients for each parameter."""
means = {}
for name, param in model.named_parameters():
means[name] = 0.0
if param.grad:
mean = param.grad.data.mean().item()
means[name] = mean
return means
def grad_nan_num(model: nn.Module) -> Dict[str, int]: def grad_nan_num(model: nn.Module) -> Dict[str, int]:
return _grad_stat(model, lambda g: g.isnan().sum().item(), 0) """Count the number of NaNs in gradients for each parameter."""
nan_nums = {}
for name, param in model.named_parameters():
nan_nums[name] = 0
if param.grad:
nan_num = param.grad.isnan().sum().item()
nan_nums[name] = nan_num
return nan_nums
def ctx_get_loss(ctx): def ctx_get_loss(ctx):
@ -47,10 +80,6 @@ def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]["lr"] return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_val_loss(ctx):
return ctx.val_loss
def ctx_get_grad_norm(ctx): def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model) return grad_norm(ctx.model)

View File

@ -1,143 +0,0 @@
import torch
from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2
X = G
scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale
if steps == 0:
return X
a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps):
A = X @ X.T
B = A @ X
X = a * X + b * B + c * (A @ B)
return X
class Muon(Optimizer):
def __init__(
self,
params,
lr: float = 2e-3,
momentum: float = 0.95,
weight_decay: float = 0.0,
nesterov: bool = True,
ns_steps: int = 5,
adamw_lr: float = None,
adamw_betas: tuple = (0.9, 0.95),
adamw_eps: float = 1e-8,
adamw_wd: float = 0.0,
):
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
nesterov=nesterov,
ns_steps=ns_steps,
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
adamw_betas=adamw_betas,
adamw_eps=adamw_eps,
adamw_wd=adamw_wd,
)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_2d, params_1d = [], []
grads_2d, grads_1d = [], []
for p in group["params"]:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2:
params_2d.append(p)
grads_2d.append(p.grad)
else:
params_1d.append(p)
grads_1d.append(p.grad)
if params_2d:
self._muon_update_foreach(params_2d, grads_2d, group)
if params_1d:
self._adamw_update_foreach(params_1d, grads_1d, group)
return loss
def _muon_update_foreach(self, params_2d, grads_2d, group):
lr = group["lr"]
momentum = group["momentum"]
wd = group["weight_decay"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
if wd != 0:
torch._foreach_mul_(params_2d, 1 - lr * wd)
if nesterov:
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
bufs = []
for p, grad in zip(params_2d, grads_2d):
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
bufs.append(state["momentum_buffer"])
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
for p, buf in zip(params_2d, bufs):
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
def _adamw_update_foreach(self, params_1d, grads_1d, group):
lr = group["adamw_lr"]
betas = group["adamw_betas"]
eps = group["adamw_eps"]
wd = group["adamw_wd"]
steps: list[int] = []
exp_avgs, exp_avg_sqs = [], []
has_state = []
for p in params_1d:
state = self.state[p]
if not state:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
has_state.append(False)
else:
has_state.append(True)
state["step"] += 1
steps.append(state["step"])
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
beta1, beta2 = betas
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
bias_correction1 = [1 - beta1**s for s in steps]
bias_correction2 = [1 - beta2**s for s in steps]
if wd != 0:
torch._foreach_mul_(params_1d, 1 - lr * wd)
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
denom = torch._foreach_sqrt(denom)
torch._foreach_add_(denom, eps)
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)

View File

@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
""" """
@classmethod @classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]): def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
"""Validate that the scheduler class inherits from BaseScheduler.""" """Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler): if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

View File

@ -1,5 +1,6 @@
"""Training strategy implementations with factory pattern.""" """Training strategy implementations with factory pattern."""
import copy
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
@ -7,14 +8,26 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
def create_ref_model(model_fn, state_dict: dict) -> nn.Module: def unwrap_model(model: nn.Module) -> nn.Module:
"""Create a frozen reference model from model_fn + full state dict.""" """Unwrap DDP wrapper if present to get the original model."""
ref_model = model_fn() if isinstance(model, DDP):
ref_model.load_state_dict(state_dict) return model.module
return model
def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
ref_model.requires_grad_(False) ref_model.requires_grad_(False)
ref_model.eval() ref_model.eval()
return ref_model return ref_model
@ -68,22 +81,6 @@ def get_logprobs(
return token_logprobs * shifted_mask return token_logprobs * shifted_mask
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
S = position_ids.size(1)
device = position_ids.device
boundaries = position_ids[:, 1:] <= position_ids[:, :-1]
doc_ids = torch.cat(
[
torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device),
boundaries.long().cumsum(dim=1),
],
dim=1,
)
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
return (same_doc & causal).unsqueeze(1)
class BaseStrategy(ABC): class BaseStrategy(ABC):
"""Abstract base class for training strategies.""" """Abstract base class for training strategies."""
@ -92,8 +89,6 @@ class BaseStrategy(ABC):
): ):
self.model = model self.model = model
self.device = device self.device = device
self.executor = kwargs.pop("executor", None)
self.model_fn = kwargs.pop("model_fn", None)
self.extra_kwargs = kwargs self.extra_kwargs = kwargs
@abstractmethod @abstractmethod
@ -128,7 +123,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
""" """
@classmethod @classmethod
def _validate_component(cls, strategy_cls: type): def _validate_component(cls, strategy_cls: type) -> None:
"""Validate that the strategy class inherits from BaseStrategy.""" """Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy): if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
@ -196,19 +191,15 @@ class SFTStrategy(BaseStrategy):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
input_ids, target_ids, position_ids, loss_mask = ( input_ids, target_ids, loss_mask = (
batch["input_ids"], batch["input_ids"],
batch["target_ids"], batch["target_ids"],
batch["position_ids"],
batch["loss_mask"], batch["loss_mask"],
) )
ignore_index = -100 ignore_index = -100
input_mask = make_doc_boundary_mask(position_ids) logits = self.model(input_ids=input_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
logits = self.model(
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
)["logits"]
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.flatten(0, 1).float(), input=logits.flatten(0, 1).float(),
@ -237,9 +228,7 @@ class DPOStrategy(BaseStrategy):
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model( self.ref_model = create_ref_model(model)
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.beta = beta self.beta = beta
self.reduction = reduction self.reduction = reduction
@ -276,9 +265,7 @@ class DPOStrategy(BaseStrategy):
class GRPOStrategy(BaseStrategy): class GRPOStrategy(BaseStrategy):
"""Group Relative Policy Optimization strategy. """Group Relative Policy Optimization strategy.
On-policy GRPO following DeepSeek-R1: the policy model is updated while Implements GRPO with clipping and KL penalty.
a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref),
clipped PPO objective. Call ``sync_ref_model()`` after each data-generation round.
""" """
def __init__( def __init__(
@ -289,29 +276,16 @@ class GRPOStrategy(BaseStrategy):
kl_coef: float = 0.01, kl_coef: float = 0.01,
group_size: int = 4, group_size: int = 4,
reduction: str = "mean", reduction: str = "mean",
sync_interval: int = 200,
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model( self.ref_model = create_ref_model(model)
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.clip_eps = clip_eps self.clip_eps = clip_eps
self.kl_coef = kl_coef self.kl_coef = kl_coef
self.group_size = group_size self.group_size = group_size
self.reduction = reduction self.reduction = reduction
self.sync_interval = sync_interval
self._step = 0
def sync_ref_model(self):
"""Copy current model weights to ref model."""
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1
if self._step % self.sync_interval == 0:
self.sync_ref_model()
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
prompts = batch["prompts"] prompts = batch["prompts"]
responses = batch["responses"] responses = batch["responses"]
@ -323,6 +297,7 @@ class GRPOStrategy(BaseStrategy):
masks_flat = masks.view(-1, response_len) masks_flat = masks.view(-1, response_len)
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1) prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
# Shape: (batch_size * group_size, seq_len + response_len)
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1) full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1) full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
@ -337,13 +312,14 @@ class GRPOStrategy(BaseStrategy):
) )
log_probs_ref = log_probs_ref.view(batch_size, group_size) log_probs_ref = log_probs_ref.view(batch_size, group_size)
# Compute advantages from rewards with normalization
eps = torch.finfo(log_probs_policy.dtype).eps eps = torch.finfo(log_probs_policy.dtype).eps
mean = rewards.mean(dim=-1, keepdim=True) mean = rewards.mean(dim=-1, keepdim=True)
std = rewards.std(dim=-1, keepdim=True) std = rewards.std(dim=-1, keepdim=True)
advantages = (rewards - mean) / (std + eps) advantages = (rewards - mean) / (std + eps)
ratio = torch.exp(log_probs_policy - log_probs_ref) # PPO-style clipped surrogate objective
ratio = torch.exp(0) # Off-policy: policy_model = old_model
surr1 = ratio * advantages surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages

View File

@ -1,21 +1,15 @@
import json import json
import logging
import os import os
import sys
import time import time
from pathlib import Path from pathlib import Path
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable from typing import Callable, List, Optional, Protocol, runtime_checkable
import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from tqdm import tqdm from tqdm import tqdm
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device, get_rank
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import ( from astrai.trainer.metric_util import (
ctx_get_grad_max, ctx_get_grad_max,
@ -26,12 +20,9 @@ from astrai.trainer.metric_util import (
ctx_get_grad_std, ctx_get_grad_std,
ctx_get_loss, ctx_get_loss,
ctx_get_lr, ctx_get_lr,
ctx_get_val_loss,
) )
from astrai.trainer.train_context import TrainContext from astrai.trainer.train_context import TrainContext
logger = logging.getLogger(__name__)
@runtime_checkable @runtime_checkable
class TrainCallback(Protocol): class TrainCallback(Protocol):
@ -51,15 +42,18 @@ class TrainCallback(Protocol):
def on_epoch_end(self, context: TrainContext): def on_epoch_end(self, context: TrainContext):
"""Called at the end of each epoch.""" """Called at the end of each epoch."""
def on_step_begin(self, context: TrainContext):
"""Called at the beginning of each step."""
def on_step_end(self, context: TrainContext):
"""Called at the end of each step."""
def on_batch_begin(self, context: TrainContext): def on_batch_begin(self, context: TrainContext):
"""Called at the beginning of each batch.""" """Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
"""Called at the end of each batch.""" """Called at the end of each batch."""
def on_optimizer_step(self, context: TrainContext):
"""Called on every optimizer step (sync step only)."""
def on_error(self, context: TrainContext): def on_error(self, context: TrainContext):
"""Called when an error occurs during training.""" """Called when an error occurs during training."""
@ -75,6 +69,12 @@ class CallbackFactory(BaseFactory[TrainCallback]):
callback = CallbackFactory.create("my_callback", **kwargs) callback = CallbackFactory.create("my_callback", **kwargs)
""" """
@classmethod
def _validate_component(cls, callback_cls: type) -> None:
"""Validate that the callback class inherits from TrainCallback."""
if not issubclass(callback_cls, TrainCallback):
raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
@CallbackFactory.register("gradient_clipping") @CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback): class GradientClippingCallback(TrainCallback):
@ -85,43 +85,28 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float): def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
def on_optimizer_step(self, context: TrainContext): def on_step_begin(self, context: TrainContext):
_ = context
clip_grad_norm_(context.model.parameters(), self.max_grad_norm) clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@CallbackFactory.register("gradient_checkpointing") @CallbackFactory.register("scheduler")
class GradientCheckpointingCallback(TrainCallback): class SchedulerCallback(TrainCallback):
""" """
Activation checkpointing callback trades compute for memory Scheduler callback for trainer.
by recomputing specified module activations during the backward pass.
Args:
modules: Module types to apply checkpointing to.
""" """
def __init__(self, modules: Optional[List[type]] = None): def __init__(self):
self.modules = tuple(modules) if modules else () pass
def _enable(self, module: nn.Module):
if self.modules and isinstance(module, self.modules):
fn = module.forward
module._original_forward = fn
module.forward = lambda *a, **kw: torch_checkpoint(
fn, *a, use_reentrant=False, **kw
)
@staticmethod
def _disable(module: nn.Module):
if hasattr(module, "_original_forward"):
module.forward = module._original_forward
del module._original_forward
def on_train_begin(self, context: TrainContext): def on_train_begin(self, context: TrainContext):
context.model.apply(self._enable) for group in context.optimizer.param_groups:
logger.info("Gradient checkpointing enabled") if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
def on_train_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
context.model.apply(self._disable) if context.scheduler:
context.scheduler.step()
@CallbackFactory.register("checkpoint") @CallbackFactory.register("checkpoint")
@ -130,38 +115,36 @@ class CheckpointCallback(TrainCallback):
Checkpoint callback for trainer. Checkpoint callback for trainer.
""" """
extra_keys = ("optimizer", "scheduler")
def __init__( def __init__(
self, self,
save_dir: str, save_dir: str,
interval: int, interval: int,
weight_only: bool = False, weight_only: bool = False,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.state_dict_fn = state_dict_fn
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
state_dict = context.executor.unwrap_model(context.model)
self.last_ckpt_iter = context.iteration
if get_rank() == 0:
save_path = os.path.join( save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}" self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
) )
extra = self.save_extra_fn(context) state_dict = (
context.checkpoint = Checkpoint( self.state_dict_fn(context.model)
state_dict=state_dict, if self.state_dict_fn
epoch=context.epoch, else context.model.state_dict()
iteration=context.iteration,
extra=extra,
config=context.model_config,
) )
context.checkpoint = Checkpoint(
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
)
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval: if context.iteration - self.last_ckpt_iter >= self.interval:
@ -174,15 +157,6 @@ class CheckpointCallback(TrainCallback):
def on_error(self, context: TrainContext): def on_error(self, context: TrainContext):
self._save_checkpoint(context) self._save_checkpoint(context)
@staticmethod
def save_extra(context: TrainContext) -> dict:
extra = {}
for name in CheckpointCallback.extra_keys:
obj = getattr(context, name, None)
if obj:
extra[name] = obj.state_dict()
return extra
@CallbackFactory.register("progress_bar") @CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback): class ProgressBarCallback(TrainCallback):
@ -190,12 +164,8 @@ class ProgressBarCallback(TrainCallback):
Progress bar callback for trainer. Progress bar callback for trainer.
""" """
def __init__( def __init__(self, num_epoch: int):
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
):
self.num_epoch = num_epoch self.num_epoch = num_epoch
self.log_interval = log_interval
self.file = file
self.progress_bar: tqdm = None self.progress_bar: tqdm = None
@only_on_rank(0) @only_on_rank(0)
@ -204,18 +174,16 @@ class ProgressBarCallback(TrainCallback):
context.dataloader, context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True, dynamic_ncols=True,
file=self.file or sys.stdout,
) )
@only_on_rank(0) @only_on_rank(0)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
postfix = { self.progress_bar.set_postfix(
{
"loss": f"{context.loss:.4f}", "loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}", "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
} }
if context.val_loss > 0: )
postfix["val_loss"] = f"{context.val_loss:.4f}"
self.progress_bar.set_postfix(postfix)
self.progress_bar.update(1) self.progress_bar.update(1)
@only_on_rank(0) @only_on_rank(0)
@ -247,7 +215,6 @@ class MetricLoggerCallback(TrainCallback):
self._metric_funcs = { self._metric_funcs = {
"loss": ctx_get_loss, "loss": ctx_get_loss,
"lr": ctx_get_lr, "lr": ctx_get_lr,
"val_loss": ctx_get_val_loss,
"grad_norm": ctx_get_grad_norm, "grad_norm": ctx_get_grad_norm,
"grad_std": ctx_get_grad_std, "grad_std": ctx_get_grad_std,
"grad_max": ctx_get_grad_max, "grad_max": ctx_get_grad_max,
@ -258,7 +225,7 @@ class MetricLoggerCallback(TrainCallback):
def _get_log_data(self, context: TrainContext): def _get_log_data(self, context: TrainContext):
return { return {
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.iteration, "iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics}, **{m: self._metric_funcs[m](context) for m in self.metrics},
@ -291,43 +258,3 @@ class MetricLoggerCallback(TrainCallback):
def on_error(self, context): def on_error(self, context):
self._save_log(context.epoch, context.iteration) self._save_log(context.epoch, context.iteration)
@CallbackFactory.register("validation")
class ValidationCallback(TrainCallback):
def _run_validation(self, context: TrainContext):
context.model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in context.val_dataloader:
loss = context.strategy(batch)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / max(num_batches, 1)
if context.world_size > 1 and dist.is_initialized():
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
avg_loss = loss_tensor.item()
context.val_loss = avg_loss
context.model.train()
step_count = context.iteration // context.config.grad_accum_steps
logger.info(
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
)
def on_optimizer_step(self, context: TrainContext):
if context.val_dataloader is None:
return
cfg = context.config
if cfg.val_step <= 0:
return
step_count = context.iteration // cfg.grad_accum_steps
if step_count % cfg.val_step == 0:
self._run_validation(context)

View File

@ -1,18 +1,15 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self from typing import Optional, Self
import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader, random_split from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
from astrai.dataset import ResumableDistributedSampler from astrai.dataset import ResumableDistributedSampler
from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.serialization import Checkpoint
from astrai.serialization import Checkpoint, load_json, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -21,18 +18,13 @@ class TrainContext:
model: nn.Module = field(default=None) model: nn.Module = field(default=None)
strategy: BaseStrategy = field(default=None) strategy: BaseStrategy = field(default=None)
dataloader: DataLoader = field(default=None) dataloader: DataLoader = field(default=None)
optimizer: OptimizerProtocol = field(default=None) optimizer: Optimizer = field(default=None)
scheduler: SchedulerProtocol = field(default=None) scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
model_config: dict = field(default_factory=dict)
executor: BaseExecutor = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
iteration: int = field(default=0) iteration: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
val_dataloader: DataLoader = field(default=None)
val_loss: float = field(default=0.0)
world_size: int = field(default=1) world_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)
@ -40,144 +32,68 @@ class TrainContext:
class TrainContextBuilder: class TrainContextBuilder:
def __init__( def __init__(self, config: TrainConfig):
self,
config: TrainConfig,
):
self.config = config self.config = config
self._resume_dir: Optional[str] = None self._context = TrainContext(
model=config.model,
world_size=get_world_size(),
rank=get_rank(),
)
def with_resume_dir(self, resume_dir: Optional[str]) -> Self: device = get_current_device()
self._resume_dir = resume_dir self._context.model = self._context.model.to(device=device)
if self.config.nprocs > 1:
fn = self.config.parallel_wrapper
self._context.model = fn(self._context.model)
self._context.optimizer = self.config.optimizer_fn(self._context.model)
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None:
checkpoint = Checkpoint(
state_dict=self._context.model.state_dict(),
)
else:
# resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.model.load_state_dict(checkpoint.state_dict)
self._context.checkpoint = checkpoint
return self
def with_dataloader(self) -> Self:
# fix: change batch level iteration to sample level offset
config = self.config
sampler_offset = self._context.iteration * config.batch_size
resumeable_sampler = ResumableDistributedSampler(
data_source=config.dataset,
start_epoch=self._context.epoch,
start_iter=sampler_offset,
seed=config.random_seed,
)
dataloader = DataLoader(
config.dataset,
batch_size=config.batch_size,
sampler=resumeable_sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
)
self._context.dataloader = dataloader
return self
def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create(
model=self._context.model,
train_type=self.config.strategy,
device=get_current_device(),
**self.config.extra_kwargs,
)
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
cfg = self.config return self._context
device = get_current_device()
executor = ExecutorFactory.create(
cfg.parallel_mode,
grad_accum_steps=cfg.grad_accum_steps,
**cfg.executor_kwargs,
)
model = cfg.model_fn()
model = model.to(device=device)
model_config = {}
if self._resume_dir:
config_path = Path(self._resume_dir) / "config.json"
if config_path.exists():
model_config = load_json(config_path)
if not model_config and hasattr(model, "config"):
model_config = model.config.to_dict()
context = TrainContext(
model=model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
model_config=model_config,
executor=executor,
)
if self._resume_dir is not None:
resume_path = Path(self._resume_dir)
if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict
if checkpoint.config:
context.model_config = checkpoint.config
else:
checkpoint = None
state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False)
if checkpoint is not None:
context.epoch = cfg.start_epoch
context.iteration = cfg.start_batch
context.checkpoint = checkpoint
if cfg.lora is not None:
inject_lora(
model,
r=cfg.lora.r,
alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules),
)
context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
train_dataset = cfg.dataset
val_dataset = cfg.val_dataset
if val_dataset is None and cfg.val_split is not None:
n_total = len(cfg.dataset)
n_val = max(1, int(n_total * cfg.val_split))
n_train = n_total - n_val
generator = torch.Generator().manual_seed(cfg.random_seed)
train_dataset, val_dataset = random_split(
cfg.dataset, [n_train, n_val], generator=generator
)
sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler(
data_source=train_dataset,
start_epoch=context.epoch,
start_iter=sampler_offset,
seed=cfg.random_seed,
)
context.dataloader = DataLoader(
train_dataset,
batch_size=cfg.batch_per_device,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
if val_dataset is not None:
val_sampler = ResumableDistributedSampler(
data_source=val_dataset,
start_epoch=0,
start_iter=0,
seed=cfg.random_seed,
shuffle=False,
)
context.val_dataloader = DataLoader(
val_dataset,
batch_size=cfg.batch_per_device,
sampler=val_sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create(
model=context.model,
train_type=cfg.strategy,
device=device,
executor=executor,
model_fn=cfg.model_fn,
**cfg.extra_kwargs,
)
return context

View File

@ -3,6 +3,7 @@ from typing import List, Optional
from astrai.config import TrainConfig from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import ( from astrai.trainer.train_callback import (
CallbackFactory, CallbackFactory,
TrainCallback, TrainCallback,
@ -24,28 +25,22 @@ class Trainer:
def _get_default_callbacks(self) -> List[TrainCallback]: def _get_default_callbacks(self) -> List[TrainCallback]:
cfg = self.train_config cfg = self.train_config
callbacks = [ return [
CallbackFactory.create(
"gradient_checkpointing",
modules=cfg.gradient_checkpointing_modules,
),
CallbackFactory.create(
"checkpoint",
cfg.ckpt_dir,
cfg.ckpt_interval,
),
CallbackFactory.create(
"metric_logger",
log_dir=cfg.log_dir,
save_interval=cfg.ckpt_interval,
log_interval=cfg.log_interval,
metrics=cfg.metrics,
),
CallbackFactory.create("progress_bar", cfg.n_epoch), CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"), CallbackFactory.create("scheduler"),
] ]
return callbacks
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build()
)
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks: for callback in self.callbacks:
@ -53,57 +48,55 @@ class Trainer:
if method: if method:
method(context) method(context)
def _trainer_loop(self, resume_dir: Optional[str] = None): def train(self, checkpoint: Optional[Checkpoint] = None):
context = ( config = self.train_config
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build() spawn_parallel_fn(
self._train_impl,
backend=config.backend,
world_size=config.nprocs,
master_addr=config.master_addr,
master_port=config.master_port,
device_type=config.device_type,
device_ids=config.device_ids,
checkpoint=checkpoint,
) )
executor = context.executor
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint)
self._call_callbacks("on_train_begin", context) self._call_callbacks("on_train_begin", context)
try: try:
context.model.train() context.model.train()
# 1.epoch
for epoch in range(context.epoch, context.config.n_epoch): for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch context.epoch = epoch
self._call_callbacks("on_epoch_begin", context) self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader: for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context) if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step
with executor.accumulate(context.model): self._call_callbacks("on_step_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
context.iteration += 1
self._call_callbacks("on_batch_end", context)
if executor.sync_gradients:
self._call_callbacks("on_optimizer_step", context)
context.optimizer.step() context.optimizer.step()
context.optimizer.zero_grad() context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
if context.scheduler: # 3. batch
context.scheduler.step() self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
context.iteration += 1
# to make the loss normalized by accumulation steps
stand_loss = loss / self.train_config.accumulation_steps
stand_loss.backward()
self._call_callbacks("on_batch_end", context)
self._call_callbacks("on_epoch_end", context) self._call_callbacks("on_epoch_end", context)
except Exception as e: except Exception as e:
logger.error("Training failed: %s", str(e), exc_info=True) logger.error(f"Training failed: {str(e)}", exc_info=True)
self._call_callbacks("on_error", context) self._call_callbacks("on_error", context)
raise raise
finally: finally:
self._call_callbacks("on_train_end", context) self._call_callbacks("on_train_end", context)
def train(self, resume_dir: Optional[str] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
resume_dir=resume_dir,
)

View File

@ -1,44 +0,0 @@
services:
server:
build:
context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
command: python -m scripts.tools.server --port 8000 --device cuda
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
restart: unless-stopped
server-cpu:
profiles: [cpu]
build:
context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
command: python -m scripts.tools.server --port 8000 --device cpu
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 120s
restart: unless-stopped

View File

@ -11,6 +11,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text(): def generate_text():
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT) model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
@ -21,15 +22,16 @@ def generate_text():
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
for token in engine.generate( response = engine.generate(
prompt=query, prompt=query,
stream=True, stream=False,
max_tokens=2048, max_tokens=2048,
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
top_k=50, top_k=50,
): )
print(token, end="", flush=True)
print(response)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -24,23 +24,12 @@ def batch_generate():
"请问什么是显卡", "请问什么是显卡",
] ]
prompts = [
tokenizer.apply_chat_template(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": q},
],
tokenize=False,
)
for q in inputs
]
engine = InferenceEngine( engine = InferenceEngine(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
responses = engine.generate( responses = engine.generate(
prompt=prompts, prompt=inputs,
stream=False, stream=False,
max_tokens=2048, max_tokens=2048,
temperature=0.8, temperature=0.8,

View File

@ -15,7 +15,7 @@ def chat():
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
messages = [{"role": "system", "content": "You are a helpful assistant."}] messages = []
engine = InferenceEngine(model=model, tokenizer=tokenizer) engine = InferenceEngine(model=model, tokenizer=tokenizer)
while True: while True:

View File

@ -16,7 +16,6 @@ NC='\033[0m' # No Color
IMAGE_NAME="astrai" IMAGE_NAME="astrai"
IMAGE_TAG="latest" IMAGE_TAG="latest"
REGISTRY="" REGISTRY=""
CONTAINER_ID=""
# Print colored messages # Print colored messages
print_info() { print_info() {
@ -176,10 +175,6 @@ main() {
PORT="$2" PORT="$2"
shift 2 shift 2
;; ;;
--container)
CONTAINER_ID="$2"
shift 2
;;
--gpu) --gpu)
GPU=true GPU=true
shift shift
@ -202,7 +197,6 @@ main() {
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)" echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
echo " --context PATH Build context (default: .)" echo " --context PATH Build context (default: .)"
echo " --port PORT Port for run (default: 8000)" echo " --port PORT Port for run (default: 8000)"
echo " --container ID Container ID for logs"
echo " --gpu Enable GPU support" echo " --gpu Enable GPU support"
echo " --help Show this help message" echo " --help Show this help message"
echo "" echo ""
@ -211,7 +205,6 @@ main() {
echo " $0 build --tag v1.0.0" echo " $0 build --tag v1.0.0"
echo " $0 run --port 8080" echo " $0 run --port 8080"
echo " $0 run --gpu" echo " $0 run --gpu"
echo " $0 logs --container abc123"
echo " $0 push --registry ghcr.io/username" echo " $0 push --registry ghcr.io/username"
exit 0 exit 0
;; ;;
@ -244,7 +237,7 @@ main() {
show_info show_info
;; ;;
logs) logs)
show_logs "$CONTAINER_ID" show_logs "$2"
;; ;;
"") "")
print_error "No command specified. Use --help for usage" print_error "No command specified. Use --help for usage"

View File

@ -1,13 +1,9 @@
"""Benchmark AutoRegressiveLM with KVCache"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from astrai.config import AutoRegressiveLMConfig from astrai.model.transformer import ModelConfig, Transformer
from astrai.inference import KVCache
from astrai.model.transformer import AutoRegressiveLM
@dataclass @dataclass
@ -21,27 +17,29 @@ class BenchmarkResult:
class GenerationBenchmark: class GenerationBenchmark:
def __init__( def __init__(
self, self,
config: AutoRegressiveLMConfig, config: ModelConfig,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.float16,
page_size: int = 128,
): ):
self.config = config self.config = config
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype) self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval() self.model.eval()
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size def _initialize_kv_cache(self, batch_size: int) -> list:
self._page_cache = KVCache( """初始化KV缓存"""
config = self.config
shape = (
batch_size,
config.max_len,
config.n_layers, config.n_layers,
n_pages,
page_size,
config.n_kv_heads, config.n_kv_heads,
head_dim, config.dim // config.n_heads,
device,
dtype,
) )
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint( prompt_ids = torch.randint(
@ -51,6 +49,7 @@ class GenerationBenchmark:
device=self.device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
gen_ids = torch.randint( gen_ids = torch.randint(
low=0, low=0,
high=self.config.vocab_size, high=self.config.vocab_size,
@ -58,6 +57,7 @@ class GenerationBenchmark:
device=self.device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
return prompt_ids, gen_ids return prompt_ids, gen_ids
@torch.inference_mode() @torch.inference_mode()
@ -67,11 +67,13 @@ class GenerationBenchmark:
prompt_length: int = 512, prompt_length: int = 512,
num_trials: int = 10, num_trials: int = 10,
) -> BenchmarkResult: ) -> BenchmarkResult:
for _ in range(3): for _ in range(3):
prompt_ids, _ = self._prepare_inputs( prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length batch_size, prompt_length, prompt_length
) )
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
torch.cuda.synchronize() torch.cuda.synchronize()
total_time = 0.0 total_time = 0.0
@ -81,20 +83,20 @@ class GenerationBenchmark:
prompt_ids, _ = self._prepare_inputs( prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length batch_size, prompt_length, prompt_length
) )
start = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
start.record() start_event.record()
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
end.record() end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start.elapsed_time(end) / 1000 trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time total_time += trial_time
print( print(
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tok/s)" f"({prompt_length / trial_time:.1f} tokens/s)"
) )
return BenchmarkResult( return BenchmarkResult(
@ -105,7 +107,7 @@ class GenerationBenchmark:
"benchmark_type": "prefill", "benchmark_type": "prefill",
"batch_size": batch_size, "batch_size": batch_size,
"prompt_length": prompt_length, "prompt_length": prompt_length,
"dtype": str(self.dtype), "dtype": self.dtype,
"device": self.device, "device": self.device,
}, },
) )
@ -118,74 +120,41 @@ class GenerationBenchmark:
gen_length: int = 128, gen_length: int = 128,
num_trials: int = 5, num_trials: int = 5,
) -> BenchmarkResult: ) -> BenchmarkResult:
total_time = 0.0 total_time = 0.0
total_tokens = batch_size * gen_length * num_trials total_tokens = batch_size * gen_length * num_trials
page_size = self._page_cache.page_size
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs( prompt_ids, gen_ids = self._prepare_inputs(
batch_size, batch_size, prompt_length, prompt_length + gen_length
prompt_length,
prompt_length + gen_length,
)
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
total = n_pages * batch_size
pages = []
for _ in range(total):
p = self._page_cache._pool.alloc()
assert p >= 0, "OOM"
pages.append(p)
page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long,
device=self.device,
)
cv = self._page_cache.bind(page_table, total_len=prompt_length)
_ = self.model(
prompt_ids,
paged_cache=cv,
position_ids=torch.arange(
prompt_length, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_size, -1),
) )
kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize() torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
start.record()
current_pos = prompt_length current_pos = prompt_length
for i in range(gen_length): for i in range(gen_length):
input_token = gen_ids[:, i : i + 1] input_token = gen_ids[:, i : i + 1]
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
_ = self.model( _ = self.model(
input_token, input_token, persistent_key_values=kv_cache, start_pos=current_pos
paged_cache=cv,
position_ids=torch.full(
(batch_size, 1),
current_pos,
dtype=torch.long,
device=self.device,
),
) )
current_pos += 1 current_pos += 1
end.record()
end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start.elapsed_time(end) / 1000 trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time total_time += trial_time
for idx in pages:
self._page_cache._pool.free(idx)
print( print(
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tok/s)" f"({gen_length / trial_time:.1f} tokens/s)"
) )
return BenchmarkResult( return BenchmarkResult(
@ -197,26 +166,36 @@ class GenerationBenchmark:
"batch_size": batch_size, "batch_size": batch_size,
"prompt_length": prompt_length, "prompt_length": prompt_length,
"gen_length": gen_length, "gen_length": gen_length,
"dtype": str(self.dtype), "dtype": self.dtype,
"device": self.device, "device": self.device,
}, },
) )
def print_benchmark_result(result: BenchmarkResult): def print_benchmark_result(result: BenchmarkResult):
btype = result.metadata["benchmark_type"] """打印基准测试结果"""
print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}") benchmark_type = result.metadata["benchmark_type"]
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
print(f"Total Tokens Processed: {result.total_tokens:,}") print(f"Total Tokens Processed: {result.total_tokens:,}")
print(f"Time Consumed: {result.total_time:.3f}s") print(f"Time Consumed: {result.total_time:.3f}s")
print(f"Throughput: {result.tokens_per_second:,.1f} tok/s") print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
for k, v in result.metadata.items():
if k != "benchmark_type": if benchmark_type == "prefill":
print(f"{k.replace('_', ' ').title()}: {v}") print(
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
)
elif benchmark_type == "decoding":
print(
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
)
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print("-" * 80) print("-" * 80)
if __name__ == "__main__": if __name__ == "__main__":
config = AutoRegressiveLMConfig( config = ModelConfig(
vocab_size=10000, vocab_size=10000,
dim=1536, dim=1536,
n_heads=24, n_heads=24,
@ -230,20 +209,15 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running AutoRegressiveLM Generation Benchmark (KVCache)") print("Running Transformer Generation Benchmark")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark( prefill_result = benchmark.run_prefill_benchmark(
batch_size=4, batch_size=4, prompt_length=512, num_trials=5
prompt_length=512,
num_trials=5,
) )
print_benchmark_result(prefill_result) print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark( gen_result = benchmark.run_decoding_benchmark(
batch_size=4, batch_size=4, prompt_length=512, gen_length=128, num_trials=5
prompt_length=512,
gen_length=128,
num_trials=5,
) )
print_benchmark_result(gen_result) print_benchmark_result(gen_result)

View File

@ -1,336 +0,0 @@
"""HumanEval code generation benchmark.
Generates n completions per problem, extracts function bodies, executes
against hidden tests, and computes pass@k.
Usage::
python scripts/tools/evaluate_humaneval.py --param_path ./params \
--data_path HumanEval.jsonl.gz --output results.json \
--num_samples 200 --temperature 0.8 --max_tokens 512
"""
import argparse
import json
import os
import re
import signal
import sys
from math import prod
from multiprocessing import Process, Queue
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import tqdm
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
HUMANEVAL_URL = (
"https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz"
)
_STOP_SEQUENCES = [
"\nclass ",
"\ndef ",
"\n# ",
"\nif __name__",
"\nprint(",
"\n\n\n",
]
def _download_humaneval(data_path: str):
if os.path.exists(data_path):
return
import gzip
import urllib.request
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
print(f"Downloading HumanEval from {HUMANEVAL_URL} ...")
tmp = data_path + ".tmp"
urllib.request.urlretrieve(HUMANEVAL_URL, tmp)
with gzip.open(tmp, "rb") as f_in:
with open(data_path, "wb") as f_out:
f_out.write(f_in.read())
os.remove(tmp)
print(f" saved to {data_path}")
def _load_problems(data_path: str) -> List[dict]:
problems = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
problems.append(json.loads(line))
return problems
def _extract_function_body(code: str, entry_point: str) -> Optional[str]:
"""Extract the function body from a completion."""
pattern = rf"def\s+{re.escape(entry_point)}\b[^:]*:"
match = re.search(pattern, code)
if not match:
# Use the full code as-is if we can't find the function
return code
body_start = match.end()
lines = code[body_start:].split("\n")
body_lines = []
started = False
for line in lines:
stripped = line.rstrip()
if not stripped and not started:
continue
if not stripped and started:
body_lines.append("")
continue
if not started:
started = True
if stripped.lstrip() == stripped and started:
break
body_lines.append(stripped)
body = "\n".join(body_lines)
if not body.strip():
return None
return body
def _trim_stop_sequences(text: str) -> str:
for stop in _STOP_SEQUENCES:
idx = text.find(stop)
if idx != -1:
text = text[:idx]
return text
def _execute_code(problem: dict, completion: str, timeout: float = 3.0) -> bool:
"""Run the completion against hidden tests in a subprocess."""
def _worker(queue, full_code):
try:
namespace = {}
exec(full_code, namespace)
check = namespace.get("check")
if check is None:
queue.put(False)
return
check(namespace.get(problem["entry_point"]))
queue.put(True)
except Exception:
queue.put(False)
full_code = problem["prompt"] + completion + "\n" + problem["test"]
queue: Queue = Queue()
proc = Process(target=_worker, args=(queue, full_code))
proc.start()
proc.join(timeout)
if proc.is_alive():
proc.terminate()
proc.join()
return False
try:
return queue.get_nowait()
except Exception:
return False
def _pass_at_k(n: int, c: int, k: int) -> float:
"""Unbiased estimator of pass@k."""
if n - c < k:
return 1.0
return 1.0 - float(prod(1.0 - k / np.arange(n - c + 1, n + 1)))
def _deduplicate(completions: List[str]) -> List[str]:
seen = set()
unique = []
for c in completions:
if c not in seen:
seen.add(c)
unique.append(c)
return unique
def _generate(
engine: InferenceEngine,
prompt: str,
num_samples: int,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
batch_size: int,
) -> List[str]:
batches = [prompt] * min(batch_size, num_samples)
completions = []
remaining = num_samples
while remaining > 0:
current = min(batch_size, remaining)
batch_prompts = batches[:current]
outputs = engine.generate(
prompt=batch_prompts,
stream=False,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
if isinstance(outputs, str):
outputs = [outputs]
completions.extend(outputs)
remaining -= current
return _deduplicate(completions)
def evaluate(
engine: InferenceEngine,
problems: List[dict],
num_samples: int,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
batch_size: int,
k_values: Tuple[int, ...] = (1, 10, 100),
) -> Dict:
results = {}
all_pass_at_k = {k: [] for k in k_values}
for problem in tqdm.tqdm(problems, desc="HumanEval", unit="problem"):
task_id = problem["task_id"]
prompt = problem["prompt"]
entry_point = problem["entry_point"]
raw_completions = _generate(
engine,
prompt,
num_samples,
max_tokens,
temperature,
top_p,
top_k,
batch_size,
)
completions = []
for raw in raw_completions:
trimmed = _trim_stop_sequences(raw)
body = _extract_function_body(trimmed, entry_point)
if body:
completions.append(body)
passed = 0
for comp in completions:
if _execute_code(problem, comp):
passed += 1
n = len(completions)
c = passed
result = {"task_id": task_id, "n": n, "passed": c}
for k in k_values:
result[f"pass@{k}"] = round(_pass_at_k(n, c, k), 4)
all_pass_at_k[k].append(_pass_at_k(n, c, k))
results[task_id] = result
summary = {}
for k in k_values:
vals = all_pass_at_k[k]
summary[f"pass@{k}"] = round(float(np.mean(vals)), 4)
results["_summary"] = summary
return results
def main():
parser = argparse.ArgumentParser(description="HumanEval benchmark")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_path",
type=str,
default="./humaneval/HumanEval.jsonl",
help="HumanEval JSONL file (auto-download if missing)",
)
parser.add_argument("--output", type=str, default=None, help="Output JSON path")
parser.add_argument(
"--num_samples",
type=int,
default=200,
help="Completions per problem",
)
parser.add_argument(
"--max_tokens", type=int, default=512, help="Max generation tokens"
)
parser.add_argument(
"--temperature", type=float, default=0.8, help="Sampling temperature"
)
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
parser.add_argument(
"--batch_size", type=int, default=1, help="Inference batch size"
)
parser.add_argument(
"--problems",
type=int,
nargs="+",
default=None,
help="Specific problem indices (0-based)",
)
args = parser.parse_args()
_download_humaneval(args.data_path)
problems = _load_problems(args.data_path)
if args.problems:
problems = [problems[i] for i in args.problems if i < len(problems)]
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
model.to(device="cuda", dtype=torch.bfloat16)
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=args.batch_size,
)
results = evaluate(
engine=engine,
problems=problems,
num_samples=args.num_samples,
max_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
batch_size=args.batch_size,
k_values=(1, 10, 100),
)
summary = results.pop("_summary")
print(f"\n{'=' * 60}")
for k, v in summary.items():
print(f" {k}: {v:.2%}")
print(f"{'=' * 60}")
if args.output:
results["_summary"] = summary
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"Results saved to {args.output}")
engine.shutdown()
if __name__ == "__main__":
main()

View File

@ -1,319 +0,0 @@
"""MMLU evaluation via log-likelihood ranking."""
import argparse
import csv
import json
import os
import shutil
import tarfile
import requests
import torch
import torch.nn.functional as F
import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
MMLU_SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def _download_and_extract(url: str, data_dir: str):
tar_path = os.path.join(data_dir, "data.tar")
os.makedirs(data_dir, exist_ok=True)
print(f"Downloading MMLU data from {url}...")
resp = requests.get(url, stream=True, timeout=300)
resp.raise_for_status()
total = int(resp.headers.get("content-length", 0))
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
with open(tar_path, "wb") as f:
for chunk in resp.iter_content(chunk_size=8192):
f.write(chunk)
bar.update(len(chunk))
print("Extracting...")
with tarfile.open(tar_path, "r") as tf:
tf.extractall(data_dir)
os.remove(tar_path)
def download_mmlu(data_dir: str):
_download_and_extract(MMLU_URL, data_dir)
src = os.path.join(data_dir, "data")
if os.path.exists(src):
for item in os.listdir(src):
src_item = os.path.join(src, item)
dst_item = os.path.join(data_dir, item)
if os.path.exists(dst_item):
if os.path.isdir(dst_item):
shutil.rmtree(dst_item)
else:
os.remove(dst_item)
os.rename(src_item, dst_item)
os.rmdir(src)
print(f"MMLU data saved to {data_dir}")
def _strip_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix) :].strip()
return text
def load_csv(path: str) -> list[dict]:
data = []
with open(path, "r", encoding="utf-8") as f:
for row in csv.reader(f):
if len(row) < 6:
continue
if row[0].strip().lower() == "question":
continue
data.append(
{
"question": row[0].strip(),
"A": _strip_prefix(row[1].strip(), "A)"),
"B": _strip_prefix(row[2].strip(), "B)"),
"C": _strip_prefix(row[3].strip(), "C)"),
"D": _strip_prefix(row[4].strip(), "D)"),
"answer": row[5].strip(),
}
)
return data
def build_prompt(
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
prompt = ""
if n_shot > 0 and dev_data:
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
for item in dev_data[:n_shot]:
prompt += f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {item[k]}\n"
prompt += f"Answer: {item['answer']}\n\n"
prompt += f"Question: {question}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {choices[k]}\n"
prompt += "Answer:"
return prompt
def apply_chat(
tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None
) -> str:
"""Wrap raw MMLU prompt in the model's chat template format.
For few-shot, prepend example Q&A pairs as a second user/assistant exchange.
"""
messages = []
if n_shot > 0 and dev_data:
for item in dev_data[:n_shot]:
q = f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
q += f"{k}. {item[k]}\n"
q += "Answer:"
messages.append({"role": "user", "content": q})
messages.append({"role": "assistant", "content": item["answer"]})
messages.append({"role": "user", "content": raw_prompt})
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
def choice_logprob(
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
choice_text = choice_letter
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
input_ids = context_ids + choice_ids
max_len = model.config.max_len
if len(input_ids) > max_len:
overflow = len(input_ids) - max_len
input_ids = input_ids[overflow:]
ctx_len = len(input_ids) - len(choice_ids)
else:
ctx_len = len(context_ids)
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits = model(input_tensor)["logits"][0]
score = 0.0
for i, tid in enumerate(choice_ids):
pos = ctx_len - 1 + i
if pos >= len(logits):
break
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
return score
def evaluate_subject(
model,
tokenizer,
subject: str,
test_data: list[dict],
dev_data: list[dict] | None,
device: str,
n_shot: int,
) -> tuple[float, int, int]:
correct = 0
total = 0
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
raw_prompt = build_prompt(
item["question"], item, subject, n_shot, dev_data or []
)
context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or [])
context_ids = tokenizer.encode(context)
scores = {
c: choice_logprob(model, tokenizer, context_ids, c, device)
for c in ("A", "B", "C", "D")
}
if max(scores, key=scores.get) == item["answer"]:
correct += 1
total += 1
return correct / total, correct, total
def main():
parser = argparse.ArgumentParser(description="MMLU evaluation")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
)
parser.add_argument("--download", action="store_true", help="Download MMLU data")
parser.add_argument(
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
)
parser.add_argument(
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
)
parser.add_argument("--output", type=str, help="Output JSON path")
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16" if torch.cuda.is_available() else "float32",
help="Torch dtype",
)
args = parser.parse_args()
if args.download or not os.path.exists(args.data_dir):
download_mmlu(args.data_dir)
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
device = args.device
dtype = getattr(torch, args.dtype)
model.to(device=device, dtype=dtype)
model.eval()
subjects = args.subjects or MMLU_SUBJECTS
results = {}
total_correct = 0
total_questions = 0
for subject in subjects:
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
test_path = os.path.join(
args.data_dir, args.split, f"{subject}_{args.split}.csv"
)
if not os.path.exists(test_path):
print(f" Skipping {subject}: test file not found")
continue
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
test_data = load_csv(test_path)
acc, corr, tot = evaluate_subject(
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
)
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
total_correct += corr
total_questions += tot
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
overall = total_correct / total_questions if total_questions else 0
print(f"\n{'=' * 70}")
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
results["_overall"] = {
"accuracy": round(overall, 4),
"correct": total_correct,
"total": total_questions,
}
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {args.output}")
if __name__ == "__main__":
main()

View File

@ -9,7 +9,7 @@ from astrai.tokenize import AutoTokenizer
def processor( def processor(
param_path: str, model_dir: str,
input_json_file: str, input_json_file: str,
output_json_file: str, output_json_file: str,
temperature: float, temperature: float,
@ -18,17 +18,14 @@ def processor(
question_key: str, question_key: str,
response_key: str, response_key: str,
max_tokens: int, max_tokens: int,
batch_size: int,
): ):
# Load model and tokenizer # Load model and tokenizer
model = AutoModel.from_pretrained(param_path) model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(param_path) tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine # Create inference engine
engine = InferenceEngine( engine = InferenceEngine(model=model, tokenizer=tokenizer)
model=model, tokenizer=tokenizer, max_batch_size=batch_size
)
with open(input_json_file, "r", encoding="utf-8") as f: with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f] input_data = [json.loads(line) for line in f]
@ -75,7 +72,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.") parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
parser.add_argument( parser.add_argument(
"--param_path", type=str, required=True, help="Path to the model directory." "--model_dir", type=str, required=True, help="Path to the model directory."
) )
parser.add_argument( parser.add_argument(
"--input_json_file", "--input_json_file",

View File

@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
def process_file( def process_file(
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
): ):
# Load model and tokenizer # Load model and tokenizer
model = AutoModel.from_pretrained(param_path) model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(param_path) tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f: with open(input_file, "r", encoding="utf-8") as f:
@ -44,8 +44,8 @@ def process_file(
for seq in batch_encoded: for seq in batch_encoded:
pad_len = max_len - len(seq) pad_len = max_len - len(seq)
padded_seq = seq + [tokenizer.pad_id] * pad_len padded_seq = [tokenizer.pad_id] * pad_len + seq
mask = [True] * len(seq) + [False] * pad_len mask = [False] * pad_len + [True] * len(seq)
padded_ids.append(padded_seq) padded_ids.append(padded_seq)
masks.append(mask) masks.append(mask)
@ -88,7 +88,7 @@ def process_file(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.") parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument( parser.add_argument(
"--param_path", type=str, required=True, help="Path to the model directory." "--model_dir", type=str, required=True, help="Path to the model directory."
) )
parser.add_argument( parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file." "--input_file", type=str, required=True, help="Path to the input file."

View File

@ -1,38 +0,0 @@
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
import argparse
from astrai.config.preprocess_config import PipelineConfig
from astrai.preprocessing.pipeline import Pipeline
def main():
parser = argparse.ArgumentParser(
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
)
parser.add_argument(
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
)
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
parser.add_argument(
"--config", "-c", required=True, help="Path to pipeline config JSON"
)
parser.add_argument(
"--tokenizer_path",
default="params",
help="Path to tokenizer directory (default: params)",
)
args = parser.parse_args()
config = PipelineConfig.from_json(args.config)
Pipeline(
config=config,
input_paths=args.inputs,
output_dir=args.output_dir,
tokenizer_path=args.tokenizer_path,
).run()
if __name__ == "__main__":
main()

View File

@ -3,7 +3,7 @@ from pathlib import Path
import torch import torch
from astrai.inference import run_server from astrai.inference.server import run_server
def main(): def main():
@ -18,7 +18,7 @@ def main():
"--reload", action="store_true", help="Enable auto-reload for development" "--reload", action="store_true", help="Enable auto-reload for development"
) )
parser.add_argument( parser.add_argument(
"--param_path", "--param-path",
type=Path, type=Path,
default=None, default=None,
help="Path to model parameters (default: project_root/params)", help="Path to model parameters (default: project_root/params)",

View File

@ -2,25 +2,28 @@ import argparse
import os import os
from functools import partial from functools import partial
import safetensors.torch as st
import torch import torch
import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.config import ModelConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM from astrai.model import Transformer
from astrai.model.components.decoder_block import DecoderBlock from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.") parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument( parser.add_argument(
"--train_type", "--train_type",
type=str, type=str,
required=True, required=True,
choices=["seq", "sft", "dpo", "grpo"], choices=["seq", "sft", "dpo"],
help="Train type.", help="Train type.",
) )
parser.add_argument( parser.add_argument(
@ -40,19 +43,19 @@ def parse_args() -> argparse.Namespace:
"--n_epoch", type=int, default=1, help="Number of epochs to train." "--n_epoch", type=int, default=1, help="Number of epochs to train."
) )
parser.add_argument( parser.add_argument(
"--batch_per_device", type=int, default=1, help="Batch size per GPU." "--batch_size", type=int, default=1, help="Batch size for training."
) )
parser.add_argument( parser.add_argument(
"--grad_accum_steps", "--accumulation_steps",
type=int, type=int,
default=1, default=1,
help="Number of iterations between each optimizer step.", help="Number of iterations between each optimizer step.",
) )
parser.add_argument( parser.add_argument(
"--warmup_ratio", "--warmup_steps",
type=float, type=int,
default=0.05, default=1000,
help="Fraction of total steps used for LR warmup.", help="Number of iters between warnings.",
) )
parser.add_argument( parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training." "--max_lr", type=float, default=3e-4, help="Max learning rate for training."
@ -67,13 +70,13 @@ def parse_args() -> argparse.Namespace:
"--adamw_beta1", "--adamw_beta1",
type=float, type=float,
default=0.9, default=0.9,
help="Beta1 for AdamW optimizer.", help="Beta values for AdamW optimizer.",
) )
parser.add_argument( parser.add_argument(
"--adamw_beta2", "--adamw_beta2",
type=float, type=float,
default=0.95, default=0.95,
help="Beta2 for AdamW optimizer.", help="Beta values for AdamW optimizer.",
) )
parser.add_argument( parser.add_argument(
"--adamw_weight_decay", "--adamw_weight_decay",
@ -97,31 +100,18 @@ def parse_args() -> argparse.Namespace:
"--window_size", "--window_size",
type=int, type=int,
default=None, default=None,
help="Max length of the input sequence.", help="the max length of the input sequence.",
) )
parser.add_argument( parser.add_argument(
"--stride", type=int, default=None, help="Step size of the input sequence." "--stride", type=int, default=None, help="the step size of the input sequence."
) )
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument(
"--grpo_clip_eps", type=float, default=0.2, help="GRPO clipping epsilon."
)
parser.add_argument(
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
)
parser.add_argument( parser.add_argument(
"--label_smoothing", "--label_smoothing",
type=float, type=float,
default=0.05, default=0.1,
help="cross_entropy function label smoothing parameter", help="cross_entropy function label smoothing parameter",
) )
parser.add_argument(
"--gradient_checkpointing",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable activation checkpointing for DecoderBlock modules.",
)
parser.add_argument( parser.add_argument(
"--ckpt_interval", "--ckpt_interval",
@ -135,42 +125,6 @@ def parse_args() -> argparse.Namespace:
default="checkpoint", default="checkpoint",
help="Directory to save checkpoints.", help="Directory to save checkpoints.",
) )
parser.add_argument(
"--val_split",
type=float,
default=None,
help="Ratio to split from training dataset for validation (e.g. 0.05).",
)
parser.add_argument(
"--val_step",
type=int,
default=1000,
help="Number of optimizer steps between validation runs.",
)
parser.add_argument(
"--metrics",
nargs="*",
default=["loss", "lr"],
help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.",
)
parser.add_argument(
"--log_dir",
type=str,
default="checkpoint/logs",
help="Directory for metric logs.",
)
parser.add_argument(
"--log_interval",
type=int,
default=100,
help="Number of batch iterations between metric logs.",
)
parser.add_argument(
"--grpo_sync_interval",
type=int,
default=200,
help="GRPO ref model sync interval (steps).",
)
parser.add_argument( parser.add_argument(
"--start_epoch", type=int, default=0, help="Start epoch for training." "--start_epoch", type=int, default=0, help="Start epoch for training."
) )
@ -178,54 +132,30 @@ def parse_args() -> argparse.Namespace:
"--start_batch", type=int, default=0, help="Start batch for training." "--start_batch", type=int, default=0, help="Start batch for training."
) )
parser.add_argument(
"--master_addr",
type=str,
default="localhost",
help="Master node address for distributed training.",
)
parser.add_argument(
"--master_port",
type=str,
default="29500",
help="Master node port for distributed training.",
)
parser.add_argument(
"--backend",
type=str,
default="nccl",
help="Distributed training backend.",
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument(
"--parallel_mode",
type=str,
default="none",
choices=["none", "ddp", "fsdp"],
help="Parallel training strategy (none, ddp, fsdp).",
)
parser.add_argument( parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use." "--device_type", type=str, default="cuda", help="Device type to use."
) )
parser.add_argument(
"--start_method",
type=str,
default="spawn",
choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method.",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
def create_model(config): def ddp_wrap(model: nn.Module):
return AutoRegressiveLM(config).to(dtype=torch.bfloat16) local_rank = get_rank()
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
ddp_model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=False,
)
return ddp_model
def create_optimizer(model, **kwargs) -> optim.Optimizer: def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs) return optim.AdamW(model.parameters(), **kwargs)
def create_scheduler( def create_scheduler(
@ -234,21 +164,8 @@ def create_scheduler(
return SchedulerFactory.create(optimizer, **kwargs) return SchedulerFactory.create(optimizer, **kwargs)
def compute_total_steps( def prepare_checkpoint(model: nn.Module) -> dict:
dataset_len: int, return model.module.state_dict()
n_epoch: int,
batch_per_device: int,
nprocs: int,
grad_accum_steps: int,
) -> int:
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
samples_per_replica = ceil_div(dataset_len, nprocs)
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
return total_steps
def train( def train(
@ -257,23 +174,14 @@ def train(
data_root_path: str, data_root_path: str,
max_lr: float, max_lr: float,
n_epoch: int, n_epoch: int,
batch_per_device: int, batch_size: int,
start_epoch: int, start_epoch: int,
start_batch: int, start_batch: int,
grad_accum_steps: int, accumulation_steps: int,
warmup_ratio: float, warmup_steps: int,
ckpt_interval: int, ckpt_interval: int,
ckpt_dir: str, ckpt_dir: str,
val_split: float,
val_step: int,
metrics: list[str],
log_dir: str,
log_interval: int,
dpo_beta: float, dpo_beta: float,
grpo_clip_eps: float,
grpo_kl_coef: float,
group_size: int,
grpo_sync_interval: int,
adamw_beta1: float, adamw_beta1: float,
adamw_beta2: float, adamw_beta2: float,
adamw_weight_decay: float, adamw_weight_decay: float,
@ -282,44 +190,34 @@ def train(
random_seed: int, random_seed: int,
num_workers: int, num_workers: int,
pin_memory: bool, pin_memory: bool,
gradient_checkpointing: bool,
window_size: int, window_size: int,
stride: int, stride: int,
nprocs: int, nprocs: int,
parallel_mode: str,
device_type: str, device_type: str,
backend: str,
master_addr: str,
master_port: str,
start_method: str,
): ):
assert train_type in ["seq", "sft", "dpo", "grpo"] assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config # Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json") config_path = os.path.join(param_path, "config.json")
config = AutoRegressiveLMConfig.from_file(config_path) if os.path.exists(config_path):
config.load(config_path)
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len
strategy_kwargs = { # Create bare Transformer (for training, no tokenizer needed)
"beta": dpo_beta, model = Transformer(config)
"label_smoothing": label_smoothing,
"clip_eps": grpo_clip_eps,
"kl_coef": grpo_kl_coef,
"group_size": group_size,
"sync_interval": grpo_sync_interval,
}
executor_kwargs = { # Load weights if available
"gradient_as_bucket_view": True, weights_path = os.path.join(param_path, "model.safetensors")
"broadcast_buffers": False, if os.path.exists(weights_path):
} state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False)
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
model_fn = partial(create_model, config)
dataset = DatasetFactory.load( dataset = DatasetFactory.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
@ -336,58 +234,42 @@ def train(
}, },
) )
total_steps = compute_total_steps( total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
)
warmup_steps = int(warmup_ratio * total_steps)
scheduler_fn = partial( scheduler_fn = partial(
create_scheduler, create_scheduler,
**{ **{
"schedule_type": "cosine", "schedule_type": "cosine",
"warmup_steps": min(warmup_steps, total_steps), "warmup_steps": warmup_steps,
"lr_decay_steps": total_steps - min(warmup_steps, total_steps), "lr_decay_steps": total_steps - warmup_steps,
}, },
) )
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
train_config = TrainConfig( train_config = TrainConfig(
model_fn=model_fn, model=model,
strategy=train_type, strategy=train_type,
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
ckpt_dir=ckpt_dir, ckpt_dir=ckpt_dir,
n_epoch=n_epoch, n_epoch=n_epoch,
batch_per_device=batch_per_device, batch_size=batch_size,
start_epoch=start_epoch, start_epoch=start_epoch,
start_batch=start_batch, start_batch=start_batch,
ckpt_interval=ckpt_interval, ckpt_interval=ckpt_interval,
grad_accum_steps=grad_accum_steps, accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
nprocs=nprocs, nprocs=nprocs,
backend=backend, parallel_wrapper=ddp_wrap,
master_addr=master_addr, state_dict_fn=prepare_checkpoint,
master_port=master_port,
parallel_mode=parallel_mode,
device_type=device_type, device_type=device_type,
start_method=start_method,
val_split=val_split,
val_step=val_step,
metrics=metrics,
log_dir=log_dir,
log_interval=log_interval,
gradient_checkpointing_modules=grad_ckpt_modules,
executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs, extra_kwargs=strategy_kwargs,
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
trainer.train(resume_dir=param_path) trainer.train()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -3,22 +3,18 @@ import os
import shutil import shutil
import tempfile import tempfile
import numpy as np
import pytest import pytest
import safetensors.torch as st
import torch import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.config.model_config import AutoRegressiveLMConfig from astrai.config.model_config import ModelConfig
from astrai.model.transformer import AutoRegressiveLM from astrai.model.transformer import Transformer
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
def pytest_configure(config):
config.addinivalue_line("markers", "slow: marks tests as slow")
config.addinivalue_line("markers", "integration: integration tests")
config.addinivalue_line("markers", "unit: fast unit tests")
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer: def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
"""Create a simple tokenizer for testing purposes.""" """Create a simple tokenizer for testing purposes."""
tokenizer = Tokenizer(models.BPE()) tokenizer = Tokenizer(models.BPE())
@ -26,6 +22,7 @@ def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
trainer = trainers.BpeTrainer( trainer = trainers.BpeTrainer(
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"] vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
) )
# Train on empty iterator with single character
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer) tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
auto_tokenizer = AutoTokenizer() auto_tokenizer = AutoTokenizer()
auto_tokenizer._tokenizer = tokenizer auto_tokenizer._tokenizer = tokenizer
@ -37,7 +34,7 @@ class RandomDataset(Dataset):
"""Random dataset for testing purposes.""" """Random dataset for testing purposes."""
def __init__(self, length=None, max_length=64, vocab_size=1000): def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(torch.randint(100, 200, (1,)).item()) self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length self.max_length = max_length
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -55,7 +52,7 @@ class MultiTurnDataset(Dataset):
"""Multi-turn dataset with loss mask for SFT training tests.""" """Multi-turn dataset with loss mask for SFT training tests."""
def __init__(self, length=None, max_length=64, vocab_size=1000): def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(torch.randint(100, 200, (1,)).item()) self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length self.max_length = max_length
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -96,65 +93,46 @@ class EarlyStoppingDataset(Dataset):
} }
@pytest.fixture(scope="session") @pytest.fixture
def test_tokenizer(): def base_test_env(request: pytest.FixtureRequest):
"""Session-scoped tokenizer, created once for the entire test run.""" """Create base test environment with randomly configured model and tokenizer"""
return create_test_tokenizer() func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4]
@pytest.fixture(scope="session") dim = int(np.random.choice(n_dim_choices))
def test_model(): n_heads = int(np.random.choice(n_head_choices))
"""Session-scoped small AutoRegressiveLM model, created once.""" n_kv_heads = n_heads // 2
config = AutoRegressiveLMConfig( dim_ffn = dim * 2
vocab_size=1000,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoRegressiveLM(config).to(device=device)
return { config = {
"model": model, "vocab_size": 1000,
"device": device, "dim": dim,
"config": config, "n_heads": n_heads,
"n_kv_heads": n_kv_heads,
"dim_ffn": dim_ffn,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5,
} }
@pytest.fixture
def base_test_env(test_model, test_tokenizer):
"""Function-scoped test environment with isolated temp directory.
Composes session-scoped model and tokenizer with a per-test temp dir.
"""
test_dir = tempfile.mkdtemp()
config_path = os.path.join(test_dir, "config.json")
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump( json.dump(config, f)
{ device = "cuda" if torch.cuda.is_available() else "cpu"
"vocab_size": 1000, transformer_config = ModelConfig().load(config_path)
"dim": 8, model = Transformer(transformer_config).to(device=device)
"n_heads": 2, tokenizer = create_test_tokenizer()
"n_kv_heads": 1,
"dim_ffn": 16,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
},
f,
)
yield { yield {
"device": test_model["device"], "device": device,
"test_dir": str(test_dir), "test_dir": str(test_dir),
"config_path": config_path, "config_path": config_path,
"transformer_config": test_model["config"], "transformer_config": transformer_config,
"model": test_model["model"], "model": model,
"tokenizer": test_tokenizer, "tokenizer": tokenizer,
} }
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
@ -176,3 +154,43 @@ def multi_turn_dataset():
def early_stopping_dataset(): def early_stopping_dataset():
dataset = EarlyStoppingDataset() dataset = EarlyStoppingDataset()
yield dataset yield dataset
@pytest.fixture
def test_env(request: pytest.FixtureRequest):
"""Create a test environment with saved model and tokenizer files."""
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors")
config = {
"vocab_size": 1000,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
}
with open(config_path, "w") as f:
json.dump(config, f)
tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
tokenizer.save(tokenizer_path)
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config)
st.save_file(model.state_dict(), model_path)
yield {
"test_dir": test_dir,
"model": model,
"tokenizer": tokenizer,
"transformer_config": transformer_config,
}
shutil.rmtree(test_dir)

View File

@ -1,202 +0,0 @@
import tempfile
import pytest
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from astrai.config.preprocess_config import (
InputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS_CONFIG = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
}
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'user' %}"
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'assistant' %}"
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
_INSTRUCTION_SECTIONS = [
{"field": "prompt", "action": "mask", "add_special_tokens": True},
{"field": "response", "action": "train"},
]
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
def _build_chat_tokenizer():
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tr = trainers.BpeTrainer(
vocab_size=512,
min_frequency=1,
special_tokens=_SPECIAL_TOKENS,
)
train_data = [
"hello world",
"Hi there!",
"You are helpful.",
"What is 2+2?",
"Tell me a story about dragons and knights.",
"Sure, here is a tale.",
"Translate to French: Hello",
"Bonjour",
"Artificial Intelligence is a field of computer science.",
"system",
"user",
"assistant",
"<|im_start|>",
"<|im_end|>",
*[chr(i) for i in range(32, 127)],
]
tok.train_from_iterator(train_data, tr)
auto_tok = AutoTokenizer()
auto_tok._tokenizer = tok
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@pytest.fixture(scope="session")
def chat_tokenizer():
return _build_chat_tokenizer()
@pytest.fixture
def temp_dir():
d = tempfile.mkdtemp()
yield d
import shutil
shutil.rmtree(d, ignore_errors=True)
def make_chat_config():
return PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_instruction_config():
return PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_text_config():
return PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=1, max_chars=2_000_000
),
)
def make_dpo_chat_config():
return PipelineConfig(
input=InputConfig(
sources={
"chosen": {
"sections": [
{"field": "chosen", "action": "$role", "template": True}
]
},
"rejected": {
"sections": [
{"field": "rejected", "action": "$role", "template": True}
]
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_grpo_config():
return PipelineConfig(
input=InputConfig(
sources={
"prompts": {
"sections": [
{"field": "prompt", "action": "mask", "template": True}
]
},
"responses": {
"sections": _GRPO_RESPONSE_SECTIONS,
"list_field": True,
"mask_key": "masks",
},
"rewards": {
"sections": [{"field": "rewards", "action": "value"}],
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_grpo_no_template_config():
return PipelineConfig(
input=InputConfig(
sources={
"prompts": {
"sections": [
{
"field": "prompt",
"action": "mask",
"add_special_tokens": True,
}
]
},
"responses": {
"sections": _GRPO_RESPONSE_SECTIONS,
"list_field": True,
"mask_key": "masks",
},
"rewards": {
"sections": [{"field": "rewards", "action": "value"}],
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)

View File

@ -1,4 +1,3 @@
import os
import tempfile import tempfile
import torch import torch
@ -36,30 +35,6 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30 assert loaded_checkpoint.iteration == 30
def test_checkpoint_with_extra():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer.step()
extra = {
"optimizer": optimizer.state_dict(),
"scheduler": {"last_epoch": 5},
}
checkpoint = Checkpoint(
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
loaded = Checkpoint.load(tmpdir)
assert loaded.extra["scheduler"]["last_epoch"] == 5
assert "state" in loaded.extra["optimizer"]
def simple_training(): def simple_training():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)

View File

@ -1,18 +1,8 @@
import os
import numpy as np import numpy as np
import pytest
import torch import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory
from astrai.dataset.storage import ( from astrai.serialization import save_h5
H5Store,
StoreFactory,
detect_format,
load_bin,
save_bin,
save_h5,
)
def test_dataset_loader_random_paths(base_test_env): def test_dataset_loader_random_paths(base_test_env):
@ -74,7 +64,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
) )
assert dpo_dataset is not None assert dpo_dataset is not None
assert dpo_dataset.storage is not None assert hasattr(dpo_dataset, "fetcher")
assert len(dpo_dataset) > 0 assert len(dpo_dataset) > 0
# Test that we can get DPO items without errors # Test that we can get DPO items without errors
@ -98,7 +88,6 @@ def test_sft_dataset_with_random_data(base_test_env):
dummy_data = { dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)], "loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
"position_ids": [torch.arange(seq_length, dtype=torch.int32)],
} }
save_h5(test_dir, "sft_data", dummy_data) save_h5(test_dir, "sft_data", dummy_data)
@ -111,7 +100,7 @@ def test_sft_dataset_with_random_data(base_test_env):
) )
assert sft_dataset is not None assert sft_dataset is not None
assert sft_dataset.storage is not None assert hasattr(sft_dataset, "fetcher")
assert len(sft_dataset) > 0 assert len(sft_dataset) > 0
# Test that we can get SFT items without errors # Test that we can get SFT items without errors
@ -154,291 +143,3 @@ def test_dataset_with_custom_stride(base_test_env):
) )
assert len(dataset) > len(default_stride_dataset) assert len(dataset) > len(default_stride_dataset)
def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"]
seq_length = 200
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
}
save_h5(test_dir, "count_test_data", dummy_data)
dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert dataset.count == seq_length
assert dataset.count > len(dataset) # raw tokens > windows
assert len(dataset) == (seq_length - 1 - 64) // 64 + 1
def test_empty_dataset_count():
"""Test count returns 0 when no data is loaded"""
dataset = SEQDataset(window_size=64, stride=32)
assert dataset.count == 0
assert dataset.keys == []
def test_dataset_too_short_for_window(base_test_env):
"""Dataset shorter than window_size returns __len__ == 0"""
test_dir = base_test_env["test_dir"]
seq_length = 30
save_h5(
test_dir,
"short",
{"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)]},
)
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
assert len(dataset) == 0
assert dataset.count == seq_length
def test_unloaded_dataset_getitem_raises():
"""__getitem__ without load() should fail clearly"""
dataset = SEQDataset(window_size=64, stride=32)
with pytest.raises(RuntimeError, match="not loaded"):
dataset.get_index(0)
def test_unloaded_dataset_len():
"""__len__ without load() returns 0"""
dataset = SEQDataset(window_size=64, stride=32)
assert len(dataset) == 0
def test_store_unloaded_len():
"""Unloaded Store has __len__ == 0"""
store = H5Store()
assert len(store) == 0
assert store.keys == []
def test_store_fetch_begin_equals_end(base_test_env):
"""Store.fetch with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
result = dataset.storage.fetch(10, 10, "sequence")
assert result.numel() == 0
def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError"""
store = H5Store()
with pytest.raises(RuntimeError, match="not loaded"):
store.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
"""detect_format raises FileNotFoundError for bad path"""
with pytest.raises(FileNotFoundError, match="No supported"):
detect_format("/nonexistent/path/xyz")
def test_detect_format_unsupported_file(base_test_env):
"""detect_format raises ValueError for unsupported file extension"""
test_dir = base_test_env["test_dir"]
path = os.path.join(test_dir, "data.txt")
with open(path, "w") as f:
f.write("hello")
with pytest.raises(ValueError, match="Unsupported"):
detect_format(path)
def test_create_store_invalid_type():
"""StoreFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StoreFactory.create("parquet")
def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "multi_seg")
os.makedirs(data_dir, exist_ok=True)
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
save_h5(data_dir, "data", {"sequence": segs})
store = StoreFactory.create("h5")
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7]
def test_save_load_bin_roundtrip(base_test_env):
"""save_bin + load_bin roundtrip preserves data"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
"loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
}
save_bin(test_dir, data)
result = load_bin(test_dir)
assert "sequence" in result
assert "loss_mask" in result
assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
def test_mmap_store_load_and_fetch(base_test_env):
"""MmapStore loads bin data and fetches correctly"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
}
save_bin(test_dir, data)
store = StoreFactory.create("bin")
store.load(test_dir)
assert len(store) == 200
assert "sequence" in store.keys
result = store.fetch(10, 20, "sequence")
assert result.tolist() == data["sequence"][0][10:20].tolist()
def test_mmap_dataset_load(base_test_env):
"""DatasetFactory.load auto-detects bin format"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
}
save_bin(test_dir, data)
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
assert len(dataset) > 0
assert dataset.count == 200
assert dataset[0]["input_ids"].shape[0] == 64
def test_normalize_empty_key():
"""_normalize with empty tensor list does not crash"""
store = H5Store()
store._normalize({"sequence": []})
assert len(store) == 0
assert store.keys == ["sequence"]
def test_normalize_mixed_empty_key():
"""_normalize with empty + non-empty keys returns min=0"""
store = H5Store()
store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
assert len(store) == 0
assert set(store.keys) == {"sequence", "loss_mask"}
def test_grpo_dataset_dtype(base_test_env):
"""GRPODataset returns correct dtypes"""
test_dir = base_test_env["test_dir"]
seq_len = 100
data = {
"prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
"responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
"masks": [torch.ones(seq_len, dtype=torch.int32)],
"rewards": [torch.ones(seq_len, dtype=torch.float32)],
}
save_h5(test_dir, "grpo_dtype", data)
dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
item = dataset[0]
assert item["prompts"].dtype == torch.long
assert item["responses"].dtype == torch.long
assert item["masks"].dtype == torch.bool
assert item["rewards"].dtype == torch.float32
def test_grpo_dataset_load(base_test_env):
"""GRPODataset loads and returns correct keys"""
test_dir = base_test_env["test_dir"]
seq_len = 200
data = {
"prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
"responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
"masks": [torch.ones(seq_len, dtype=torch.int64)],
"rewards": [torch.rand(seq_len, dtype=torch.float32)],
}
save_h5(test_dir, "grpo_test", data)
dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
assert len(dataset) > 0
item = dataset[0]
assert "prompts" in item
assert "responses" in item
assert "masks" in item
assert "rewards" in item
assert item["prompts"].shape[0] == 64
assert item["responses"].shape[0] == 64
def test_detect_format_bin_dir(base_test_env):
"""detect_format returns 'bin' for directory with .bin + meta.json"""
test_dir = base_test_env["test_dir"]
save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
assert detect_format(test_dir) == "bin"
def test_store_fetch_multi_key(base_test_env):
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
test_dir = base_test_env["test_dir"]
save_h5(
test_dir,
"multi_key",
{
"sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
"loss_mask": [torch.ones(100, dtype=torch.int64)],
},
)
store = StoreFactory.create("h5")
store.load(test_dir)
result = store.fetch(10, 20, ["sequence", "loss_mask"])
assert isinstance(result, dict)
assert result["sequence"].shape[0] == 10
assert result["loss_mask"].shape[0] == 10
def test_store_fetch_out_of_bounds(base_test_env):
"""Store.fetch raises ValueError for out-of-bounds indices"""
test_dir = base_test_env["test_dir"]
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
store = StoreFactory.create("h5")
store.load(test_dir)
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(-1, 10, "sequence")
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(0, 51, "sequence")
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(50, 50, "sequence")
def test_dataset_load_explicit_storage_type(base_test_env):
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
test_dir = base_test_env["test_dir"]
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
assert len(dataset) > 0
assert dataset.count == 200

View File

@ -1,396 +0,0 @@
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.builder import (
MaskBuilderFactory,
SectionedMaskBuilder,
)
from tests.data.conftest import (
_CHAT_SECTIONS,
_INSTRUCTION_SECTIONS,
_TEXT_SECTIONS,
make_chat_config,
make_dpo_chat_config,
make_grpo_config,
make_instruction_config,
make_text_config,
)
def test_chat_simple(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "sequence" in result
assert "loss_mask" in result
assert len(result["sequence"]) == len(result["loss_mask"])
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["sequence"])
trained = sum(result["loss_mask"])
assert trained > 0
assert trained < total
def test_chat_mask_only_assistant(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
assert len(ids) == len(mask)
trained = [i for i, m in enumerate(mask) if m == 1]
masked = [i for i, m in enumerate(mask) if m == 0]
assert len(trained) > 0
assert len(masked) > 0
def test_chat_all_masked(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == 0
def test_chat_all_trained(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={},
mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
def test_chat_empty_messages(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_chat_domain_extraction(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"source": "wiki",
}
result = builder.build(item, config, chat_tokenizer)
assert result["domain"] == "wiki"
def test_chat_truncation(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{
"role": "user",
"content": "Tell me a very long story about dragons and knights and magic.",
},
{"role": "assistant", "content": "Sure! Here is a tale..."},
]
}
result = builder.build(item, config, chat_tokenizer)
assert len(result["sequence"]) <= 10
assert len(result["loss_mask"]) == len(result["sequence"])
def test_instruction_basic(test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
def test_instruction_prompt_masked(test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 0 for m in mask[:p_len])
if p_len < len(ids):
assert all(m == 1 for m in mask[p_len:])
def test_instruction_train_on_prompt(test_tokenizer):
config = PipelineConfig(
input=InputConfig(
sections=[
{"field": "prompt", "action": "train", "add_special_tokens": True},
{"field": "response", "action": "mask"},
]
),
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 1 for m in mask[:p_len])
def test_text_basic(test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "sequence" in result
assert len(result["sequence"]) > 0
assert "loss_mask" not in result
def test_text_empty(test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_text_too_short(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_text_truncation(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer)
assert len(result["sequence"]) <= 3
def test_sectioned_chat(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
assert sum(result["loss_mask"]) > 0
assert 0 in result["loss_mask"]
def test_sectioned_instruction(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
)
builder = SectionedMaskBuilder()
item = {"prompt": "Q: Why?", "response": "A: Because."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
mask = result["loss_mask"]
assert mask[0] == 0
assert mask[-1] == 1
def test_sectioned_text(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "Hello world, this is a test."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "loss_mask" not in result
def test_sectioned_text_too_short(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_factory_registered():
names = MaskBuilderFactory._registry.list_names()
assert "sectioned" in names
def test_factory_create():
builder = MaskBuilderFactory.create("sectioned")
assert isinstance(builder, SectionedMaskBuilder)
def test_dpo_chat_basic(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
item = {
"chosen": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
],
"rejected": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "5"},
],
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "chosen" in result
assert "rejected" in result
assert "chosen_mask" in result
assert "rejected_mask" in result
assert "domain" in result
assert len(result["chosen"]) == len(result["chosen_mask"])
assert len(result["rejected"]) == len(result["rejected_mask"])
assert sum(result["chosen_mask"]) > 0
assert sum(result["rejected_mask"]) > 0
def test_dpo_chosen_only_trained(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
item = {
"chosen": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"rejected": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Go away"},
],
}
result = builder.build(item, config, chat_tokenizer)
assert 0 in result["chosen_mask"]
assert 1 in result["chosen_mask"]
assert 0 in result["rejected_mask"]
assert 1 in result["rejected_mask"]
def test_dpo_missing_field_is_none(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None
def test_grpo_basic(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "What is 2+2?"}],
"responses": ["4", "The answer is four", "Four", "2+2=4"],
"rewards": [1.0, 0.5, 0.8, 0.2],
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "prompts" in result
assert "responses" in result
assert "masks" in result
assert "rewards" in result
assert len(result["responses"]) == len(result["masks"])
assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
def test_grpo_response_tokens_all_trained(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "Q"}],
"responses": ["A", "B"],
"rewards": [0.8, 0.2],
}
result = builder.build(item, config, chat_tokenizer)
masks = result["masks"]
assert all(m == 1 for m in masks)
assert len(masks) == len(result["responses"])
def test_grpo_single_reward(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "Q"}],
"responses": ["A"],
"rewards": 0.9,
}
result = builder.build(item, config, chat_tokenizer)
assert result["rewards"] == [0.9]

View File

@ -1,77 +0,0 @@
import os
from astrai.config.preprocess_config import (
InputConfig,
PipelineConfig,
)
from tests.data.conftest import (
_INSTRUCTION_SECTIONS,
_TEXT_SECTIONS,
make_dpo_chat_config,
)
def test_default_values():
config = PipelineConfig()
assert config.version == 1
assert config.mask == {}
assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin"
assert config.input.sections is None
def test_from_dict_flat():
data = {
"version": 1,
"input": {
"sections": [{"field": "messages", "action": "$role", "template": True}]
},
"mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask",
"preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"},
}
config = PipelineConfig.from_dict(data)
assert config.input.sections == [
{"field": "messages", "action": "$role", "template": True}
]
assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5"
def test_to_dict_roundtrip():
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
)
d = config.to_dict()
config2 = PipelineConfig.from_dict(d)
assert config2.input.sections == _INSTRUCTION_SECTIONS
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(temp_dir):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.mask == {"text": "train"}
def test_dpo_config_roundtrip(temp_dir):
config = make_dpo_chat_config()
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.sources is not None
assert "chosen" in loaded.input.sources
assert "rejected" in loaded.input.sources
assert loaded.input.sections is None

View File

@ -1,349 +0,0 @@
import json
import os
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
from tests.data.conftest import (
_CHAT_SECTIONS,
_CHAT_TEMPLATE,
_INSTRUCTION_SECTIONS,
_SPECIAL_TOKENS_CONFIG,
_TEXT_SECTIONS,
make_dpo_chat_config,
make_grpo_no_template_config,
)
def test_filter_by_length():
assert filter_by_length("hello world", min_len=5)
assert not filter_by_length("hi", min_len=5)
assert not filter_by_length("x" * 100, max_len=50)
assert filter_by_length("just right", min_len=5, max_len=20)
def test_full_chat_pipeline(temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
]
}
)
+ "\n"
)
f.write(
json.dumps(
{
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", domain_key=None),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_full_text_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "text.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
}
)
+ "\n"
)
f.write(
json.dumps(
{
"text": "Another document for testing purposes with sufficient length to be processed."
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" not in meta
assert meta["sequence"]["dtype"] == "int32"
def test_full_instruction_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Tell me a joke",
"response": "Why did the chicken cross the road?",
}
)
+ "\n"
)
f.write(
json.dumps(
{
"prompt": "What is AI?",
"response": "Artificial Intelligence is a field of computer science.",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_dtype_override(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n")
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
with open(meta_path, "r") as f:
meta = json.load(f)
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "bool"
def test_dpo_pipeline(temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "dpo.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"chosen": [
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
],
"rejected": [
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Go away."},
],
}
)
+ "\n"
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=make_dpo_chat_config(),
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "chosen" in meta
assert "rejected" in meta
assert "chosen_mask" in meta
assert "rejected_mask" in meta
assert "sequence" not in meta
def test_grpo_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "grpo.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Question?",
"responses": ["Answer A", "Answer B"],
"rewards": [0.8, 0.3],
}
)
+ "\n"
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=make_grpo_no_template_config(),
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "prompts" in meta
assert "responses" in meta
assert "masks" in meta
assert "rewards" in meta
assert "sequence" not in meta

View File

@ -5,50 +5,46 @@ from unittest.mock import MagicMock
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from astrai.inference import get_app from astrai.inference.server import app
@pytest.fixture @pytest.fixture
def client(): def client():
"""Provide a test client for the FastAPI app.""" """Provide a test client for the FastAPI app."""
_app = get_app() return TestClient(app)
_app.state.server_config = {
"device": "cpu",
"dtype": "bfloat16", @pytest.fixture
"param_path": None, def mock_model_param():
"max_batch_size": 1, """Create a mock ModelParameter."""
"_test": True, mock_param = MagicMock()
} mock_param.model = MagicMock()
_app.state.engine = None mock_param.tokenizer = MagicMock()
return TestClient(_app) mock_param.config = MagicMock()
mock_param.config.max_len = 100
mock_param.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
mock_param.tokenizer.decode = MagicMock(return_value="mock response")
mock_param.tokenizer.stop_ids = []
mock_param.tokenizer.pad_id = 0
return mock_param
@pytest.fixture @pytest.fixture
def mock_engine(): def mock_engine():
"""Create a mock InferenceEngine.""" """Create a mock InferenceEngine."""
async def _async_gen():
yield "chunk1"
yield "chunk2"
yield "[DONE]"
mock = MagicMock() mock = MagicMock()
mock.generate.return_value = "mock response" mock.generate.return_value = "mock response"
mock.generate_async.return_value = _async_gen()
mock.get_stats.return_value = { mock.get_stats.return_value = {
"total_tasks": 0, "total_tasks": 0,
"total_tokens": 0, "total_tokens": 0,
"active_tasks": 0, "active_tasks": 0,
"waiting_queue": 0, "waiting_queue": 0,
} }
mock.tokenizer.encode.return_value = [1, 2, 3]
mock.tokenizer.decode.return_value = "mock response"
mock.tokenizer.apply_chat_template.return_value = "mock prompt"
return mock return mock
@pytest.fixture @pytest.fixture
def loaded_model(client, mock_engine): def loaded_model(mock_model_param, monkeypatch):
"""Simulate that the engine is loaded.""" """Simulate that the model is loaded."""
get_app().state.engine = mock_engine monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param)
return mock_engine return mock_model_param

View File

@ -1,279 +0,0 @@
"""Unit tests for inference cache components."""
import torch
from astrai.inference import (
Allocator,
KVCache,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
def make_pool(n_pages: int, page_size: int) -> PagePool:
return PagePool(Allocator(n_pages), PrefixCache(page_size))
def test_page_hash_full_page():
token_ids = list(range(256))
h = page_hash(token_ids, 0, 64)
assert isinstance(h, int)
assert h >= 0
def test_page_hash_different_page_differs():
token_ids = list(range(256))
assert page_hash(token_ids, 0, 64) != page_hash(token_ids, 1, 64)
def test_page_pool_alloc_free_cycle():
pool = make_pool(4, 64)
a = pool.alloc()
b = pool.alloc()
assert a != b
pool.free(a)
pool.free(b)
c = pool.alloc()
assert c in (a, b)
def test_page_pool_alloc_when_full():
pool = make_pool(2, 64)
pool.alloc()
pool.alloc()
assert pool.alloc() == -1
def test_page_pool_lru_eviction():
pool = make_pool(2, 64)
p0 = pool.alloc()
p1 = pool.alloc()
pool.record(p0, list(range(64)), 0)
pool.record(p1, list(range(64, 128)), 0)
pool.free(p0)
pool.free(p1)
pool.alloc()
assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
def test_page_pool_inc_ref_and_free():
pool = make_pool(2, 64)
p = pool.alloc()
pool.inc_ref(p)
assert pool._alloc._refs[p] == 2
pool.free(p)
assert pool._alloc._refs[p] == 1
pool.free(p)
assert pool._alloc._refs[p] == 0
def test_page_pool_keep_cached_realloc():
"""Free mask has priority over LRU; cached page returned only when no free pages."""
pool = make_pool(3, 64)
p0 = pool.alloc()
p1 = pool.alloc()
p2 = pool.alloc()
for p in (p0, p1, p2):
pool.record(p, [p] * 64, 0)
pool.free(p0)
pool.free(p1)
pool.free(p2)
assert pool.alloc() == p0
def test_prefix_cache_lookup_returns_hits():
token_ids = list(range(256))
pool = make_pool(16, 64)
pages = [pool.alloc() for _ in range(4)]
for i, p in enumerate(pages):
pool.record(p, token_ids, i)
pool.free(p)
hits = pool.lookup(token_ids)
assert hits == pages
def test_prefix_cache_lookup_stops_at_first_miss():
token_ids = list(range(256))
pool = make_pool(16, 64)
p0 = pool.alloc()
pool.record(p0, token_ids, 0)
pool.free(p0)
p1 = pool.alloc()
pool.record(p1, [99] * 64, 1)
pool.free(p1)
hits = pool.lookup(token_ids)
assert len(hits) == 1
assert hits[0] == p0
def test_prefix_cache_ignores_partial_last_page():
token_ids = list(range(100))
pool = make_pool(16, 64)
p = pool.alloc()
pool.record(p, token_ids, 0)
pool.free(p)
hits = pool.lookup(token_ids)
assert len(hits) == 1
def test_prefix_cache_on_evict_clears_mappings():
pool = make_pool(4, 64)
p = pool.alloc()
pool.record(p, list(range(64)), 0)
pool.free(p)
assert p in pool._prefix._page_to_hash
pool._prefix.evict(p)
assert p not in pool._prefix._page_to_hash
def test_prefix_cache_has_page():
pool = make_pool(4, 64)
p = pool.alloc()
assert p not in pool._prefix._page_to_hash
pool.record(p, list(range(64)), 0)
pool.free(p)
assert p in pool._prefix._page_to_hash
def test_task_table_set_get():
table = TaskTable(page_size=64)
table.set("task1", [0, 1, 2], 128)
assert table.get("task1") == [0, 1, 2]
assert table.get_cached("task1") == 128
def test_task_table_get_missing():
table = TaskTable(page_size=64)
assert table.get("nonexistent") == []
assert table.get_cached("nonexistent") == 0
def test_task_table_pop():
table = TaskTable(page_size=64)
table.set("task1", [0, 1], 64)
pages, cached = table.pop("task1")
assert pages == [0, 1]
assert cached == 64
assert table.get("task1") == []
def test_kv_cache_task_extend_allocates():
cache = KVCache(
n_layers=1,
n_pages=8,
page_size=64,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
cache._table.set("task1", [], 0)
ok = cache.task_extend("task1", 200)
assert ok
assert len(cache._table.get("task1")) == 4
def test_kv_cache_task_extend_fails_when_pool_full():
cache = KVCache(
n_layers=1,
n_pages=2,
page_size=64,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
cache._table.set("task1", [0, 1], 0)
ok = cache.task_extend("task1", 300)
assert not ok
def test_task_table_table_tensor():
table = TaskTable(page_size=64)
table.set("a", [0, 1], 0)
table.set("b", [2, 3, 4], 0)
t = table.table_tensor(["a", "b"], torch.device("cpu"))
assert t.shape == (2, 3)
assert t[0].tolist() == [0, 1, -1]
assert t[1].tolist() == [2, 3, 4]
def test_task_table_table_tensor_empty_input():
table = TaskTable(page_size=64)
t = table.table_tensor([], torch.device("cpu"))
assert t.numel() == 0
def test_storage_write_gather_single_page():
storage = Storage(
n_layers=2,
n_pages=8,
page_size=4,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
page_table = torch.tensor([[0]], dtype=torch.long)
k = torch.randn(1, 2, 2, 8)
v = torch.randn(1, 2, 2, 8)
storage.write(0, page_table, 0, k, v)
gk, gv = storage.gather(0, page_table, 2)
assert torch.allclose(gk, k)
def test_storage_write_cross_page():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 8, 2, 8)
v = torch.randn(1, 8, 2, 8)
storage.write(0, page_table, 0, k, v)
gk, gv = storage.gather(0, page_table, 8)
assert torch.allclose(gk, k)
def test_storage_gather_truncates_to_total_len():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 6, 2, 8)
v = torch.randn(1, 6, 2, 8)
storage.write(0, page_table, 0, k, v)
gk, gv = storage.gather(0, page_table, 5)
assert gk.shape == (1, 5, 2, 8)
def test_storage_gather_clamps_negative_padding():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
page_table = torch.tensor([[0, -1]], dtype=torch.long)
gk, gv = storage.gather(0, page_table, 4)
assert gk.shape == (1, 4, 2, 8)

View File

@ -1,181 +0,0 @@
"""Unit tests for GenerateResult accumulator and InferenceEngine.generate()."""
import threading
from unittest.mock import MagicMock, patch
from astrai.inference import STOP
from astrai.inference.engine import GenerateResult
def test_result_append_single():
r = GenerateResult(count=1)
r.append("hello", 0)
assert r.results[0] == "hello"
def test_result_append_multiple_tasks():
r = GenerateResult(count=3)
r.append("a", 0)
r.append("b", 1)
r.append("c", 2)
assert r.results[0] == "a"
assert r.results[1] == "b"
assert r.results[2] == "c"
def test_result_stop_marks_complete():
r = GenerateResult(count=2)
r.append("text", 0)
r.append(STOP, 0)
r.append("more", 1)
assert r._done[0] is True
assert r._done[1] is False
assert r._completed == 1
def test_result_stop_does_not_double_count():
r = GenerateResult(count=1)
r.append(STOP, 0)
r.append(STOP, 0)
assert r._completed == 1
def test_result_pop_all_returns_and_clears():
r = GenerateResult(count=2)
r.append("a", 0)
r.append("b", 1)
out = r.pop_all()
assert len(out) == 2
assert out[0] == (0, "a")
assert out[1] == (1, "b")
assert r.pop_all() == []
def test_result_wait_blocks_until_data():
r = GenerateResult(count=1)
def delayed_append():
import time
time.sleep(0.05)
r.append("delayed", 0)
t = threading.Thread(target=delayed_append)
t.start()
ok = r.wait(timeout=5.0)
t.join()
assert ok
assert r.results[0] == "delayed"
def test_result_wait_timeout():
r = GenerateResult(count=1)
ok = r.wait(timeout=0.01)
assert not ok
def test_result_wait_completion_non_streaming():
r = GenerateResult(count=2)
def finish_later():
import time
time.sleep(0.05)
r.append(STOP, 0)
time.sleep(0.05)
r.append(STOP, 1)
t = threading.Thread(target=finish_later)
t.start()
r.wait_completion()
t.join()
assert r._completed == 2
def test_result_get_results():
r = GenerateResult(count=2)
r.append("hello", 0)
r.append("world", 1)
results = r.get_results()
assert results == ["hello", "world"]
def test_engine_generate_non_streaming_single():
from astrai.inference.engine import InferenceEngine
mock_model = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3]
mock_tokenizer.decode.return_value = "response"
mock_tokenizer.stop_ids = [0]
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
instance = MockSched.return_value
def fake_add(prompt, **kw):
cb = kw["stream_callback"]
cb("response")
cb(STOP)
instance.add_task.side_effect = fake_add
instance.remove_task.return_value = []
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
result = eng.generate("hello")
assert result == "response"
def test_engine_generate_streaming_yields_tokens():
from astrai.inference.engine import InferenceEngine
mock_model = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3]
mock_tokenizer.decode.return_value = "tok"
mock_tokenizer.stop_ids = [0]
callbacks_saved = []
def capture_cb(prompt, **kw):
callbacks_saved.append(kw.get("stream_callback"))
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
instance = MockSched.return_value
instance.add_task.side_effect = capture_cb
instance.remove_task.return_value = []
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
gen = eng.generate("hello", stream=True)
cb = callbacks_saved[0]
cb("t1")
cb("t2")
cb(STOP)
tokens = list(gen)
assert tokens == ["t1", "t2"]
def test_engine_generate_non_streaming_batch():
from astrai.inference.engine import InferenceEngine
mock_model = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3]
mock_tokenizer.decode.return_value = "r"
mock_tokenizer.stop_ids = [0]
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
instance = MockSched.return_value
def fake_add(prompt, **kw):
cb = kw["stream_callback"]
cb("r")
cb(STOP)
instance.add_task.side_effect = fake_add
instance.remove_task.return_value = []
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=2)
results = eng.generate(["hello", "world"])
assert results == ["r", "r"]

View File

@ -1,286 +0,0 @@
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
import json
from unittest.mock import MagicMock
import pytest
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
from astrai.inference.engine import GenerationRequest
def _make_ctx(**kwargs):
defaults = {
"resp_id": "test-123",
"created": 1000,
"model": "test-model",
"prompt_tokens": 10,
"completion_tokens": 5,
}
defaults.update(kwargs)
return GenContext(**defaults)
def _sse_payloads(events):
payloads = []
for chunk in events:
for line in chunk.strip().split("\n"):
if line.startswith("data: "):
try:
payloads.append(json.loads(line[6:]))
except json.JSONDecodeError:
pass
return payloads
class TestStopChecker:
def test_check_finds_match(self):
sc = StopChecker(["stop", "end"])
assert sc.check("hello stop world") == "stop"
def test_check_returns_none_when_no_match(self):
sc = StopChecker(["stop"])
assert sc.check("hello world") is None
def test_check_empty_sequences(self):
sc = StopChecker([])
assert sc.check("hello") is None
class TestGenContext:
def test_defaults(self):
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
assert ctx.completion_tokens == 0
def test_fields_mutable(self):
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
ctx.completion_tokens = 42
assert ctx.completion_tokens == 42
class TestStopInfo:
def test_defaults(self):
s = StopInfo()
assert s.matched is None
assert s.body == ""
assert s.yielded == ""
def test_with_values(self):
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
assert s.matched == "stop"
assert s.body == "hello stop"
assert s.yielded == "hello "
class TestOpenAIResponseBuilder:
@pytest.fixture
def builder(self):
builder = OpenAIResponseBuilder()
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hello")]
req.stop = None
req.model = "astrai"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hello"
builder.prepare(req, engine)
return builder
def test_prepare_returns_prompt_ctx_stops(self, builder):
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hi")]
req.stop = ["END"]
req.model = "gpt"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hi"
prompt, ctx, stops = builder.prepare(req, engine)
assert prompt == "Hi"
assert ctx.model == "gpt"
assert ctx.prompt_tokens == 0
assert stops == ["END"]
def test_prepare_no_stop_returns_empty_list(self, builder):
req = MagicMock()
req.messages = []
req.stop = None
req.model = "x"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = ""
_, _, stops = builder.prepare(req, engine)
assert stops == []
def test_format_stream_start(self, builder):
ctx = _make_ctx()
events = builder.format_stream_start(ctx)
payloads = _sse_payloads(events)
assert len(payloads) == 1
p = payloads[0]
assert p["object"] == "chat.completion.chunk"
assert p["choices"][0]["delta"]["role"] == "assistant"
assert p["choices"][0]["finish_reason"] is None
def test_format_chunk(self, builder):
event = builder.format_chunk("hello")
payload = json.loads(event.split("data: ", 1)[1])
assert payload["choices"][0]["delta"]["content"] == "hello"
assert payload["choices"][0]["finish_reason"] is None
def test_format_stream_end(self, builder):
ctx = _make_ctx(completion_tokens=5)
stop = StopInfo(matched="stop")
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
finish = payloads[0]
assert finish["choices"][0]["finish_reason"] == "stop"
usage = payloads[1]
assert usage["completion_tokens"] == 5
assert usage["total_tokens"] == 15
def test_format_response(self, builder):
ctx = _make_ctx()
stop = StopInfo()
resp = builder.format_response(ctx, "hello", stop)
assert resp["object"] == "chat.completion"
assert resp["choices"][0]["message"]["content"] == "hello"
assert resp["usage"]["prompt_tokens"] == 10
class TestAnthropicResponseBuilder:
@pytest.fixture
def builder(self):
builder = AnthropicResponseBuilder()
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hello")]
req.model = "claude"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hello"
req.system = None
builder.prepare(req, engine)
return builder
def test_prepare_messages(self, builder):
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hi")]
req.model = "claude"
req.system = None
req.stop_sequences = None
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hi"
prompt, ctx, stops = builder.prepare(req, engine)
assert prompt == "Hi"
assert stops == []
def test_prepare_with_stop_sequences(self, builder):
req = MagicMock()
req.messages = []
req.model = "x"
req.stop_sequences = ["stop", "end"]
req.system = None
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = ""
_, _, stops = builder.prepare(req, engine)
assert stops == ["stop", "end"]
def test_format_stream_start(self, builder):
ctx = _make_ctx(prompt_tokens=3)
events = builder.format_stream_start(ctx)
payloads = _sse_payloads(events)
assert len(payloads) == 2
assert payloads[0]["type"] == "message_start"
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
assert payloads[1]["type"] == "content_block_start"
def test_format_chunk(self, builder):
event = builder.format_chunk("tok")
payload = json.loads(event.split("data: ", 1)[1])
assert payload["type"] == "content_block_delta"
assert payload["delta"]["text"] == "tok"
def test_format_stream_end_no_stop(self, builder):
ctx = _make_ctx(completion_tokens=3)
stop = StopInfo()
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# content_block_stop, message_delta, message_stop
types = [p["type"] for p in payloads]
assert types == ["content_block_stop", "message_delta", "message_stop"]
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
ctx = _make_ctx(completion_tokens=7)
stop = StopInfo(
matched="END",
body="Hello world END extra",
yielded="Hello ",
)
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# unyielded delta, content_block_stop, message_delta, message_stop
types = [p["type"] for p in payloads]
assert types == [
"content_block_delta",
"content_block_stop",
"message_delta",
"message_stop",
]
assert payloads[0]["delta"]["text"] == "world "
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
assert payloads[2]["delta"]["stop_sequence"] == "END"
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
ctx = _make_ctx()
stop = StopInfo(
matched="END",
body="Hello END",
yielded="Hello ",
)
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# No unyielded delta (everything already sent)
types = [p["type"] for p in payloads]
assert types == ["content_block_stop", "message_delta", "message_stop"]
def test_format_response_with_stop_trims_content(self, builder):
ctx = _make_ctx()
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
resp = builder.format_response(ctx, "text STOP extra", stop)
assert resp["content"][0]["text"] == "text "
assert resp["stop_reason"] == "stop_sequence"
assert resp["stop_sequence"] == "STOP"
def test_format_response_no_stop(self, builder):
ctx = _make_ctx()
stop = StopInfo()
resp = builder.format_response(ctx, "full text", stop)
assert resp["content"][0]["text"] == "full text"
assert resp["stop_reason"] == "end_turn"
class TestGenerationRequestValidation:
def test_valid_params(self):
gr = GenerationRequest(
messages=[{"role": "user", "content": "hi"}],
top_k=50,
top_p=0.9,
temperature=0.7,
)
assert gr.top_k == 50
def test_invalid_top_p_raises(self):
with pytest.raises(ValueError, match="top_p"):
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
def test_invalid_top_k_raises(self):
with pytest.raises(ValueError, match="top_k"):
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
def test_invalid_temperature_raises(self):
with pytest.raises(ValueError, match="temperature"):
GenerationRequest(
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
)
def test_top_k_zero_valid(self):
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
assert gr.top_k == 0

View File

@ -1,127 +0,0 @@
"""Unit tests for inference sampling strategies."""
import torch
from astrai.inference.sample import (
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
sample,
)
def test_temperature_scalar():
logits = torch.tensor([[1.0, 2.0, 3.0]])
s = TemperatureStrategy(0.5)
result = s.apply(logits.clone())
assert torch.allclose(result, logits / 0.5)
def test_temperature_skip_when_one():
logits = torch.tensor([[1.0, 2.0, 3.0]])
s = TemperatureStrategy(1.0)
result = s.apply(logits.clone())
assert torch.equal(result, logits)
def test_temperature_per_sample_tensor():
logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
s = TemperatureStrategy(torch.tensor([0.5, 0.5]))
result = s.apply(logits.clone())
assert torch.allclose(result, logits / 0.5)
def test_top_k_keeps_top():
logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
s = TopKStrategy(top_k=2)
result = s.apply(logits.clone(), filter_value=-1e9)
kept = (result > -1e9).sum().item()
assert kept == 2
def test_top_k_skip_when_zero():
logits = torch.tensor([[1.0, 2.0, 3.0]])
s = TopKStrategy(top_k=0)
result = s.apply(logits.clone())
assert torch.equal(result, logits)
def test_top_k_batch_tensor():
"""Each row respects its own top_k."""
logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]])
s = TopKStrategy(top_k=torch.tensor([2, 1]))
result = s.apply(logits.clone(), filter_value=-1e9)
assert (result[0] > -1e9).sum() == 2
assert (result[1] > -1e9).sum() == 1
def test_top_p_nucleus_filtering():
logits = torch.tensor([[10.0, 1.0, 1.0, 1.0, 1.0]])
s = TopPStrategy(top_p=0.5)
result = s.apply(logits.clone(), filter_value=-1e9)
kept = (result > -1e9).sum().item()
assert kept >= 1
def test_top_p_skip_when_one():
logits = torch.tensor([[1.0, 2.0, 3.0]])
s = TopPStrategy(top_p=1.0)
result = s.apply(logits.clone())
assert torch.equal(result, logits)
def test_top_p_filter_all_except_max_when_zero():
logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
s = TopPStrategy(top_p=0.0)
result = s.apply(logits.clone(), filter_value=-1e9)
kept = (result > -1e9).sum().item()
assert kept == 1
def test_sampling_pipeline_composes_strategies():
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
pipeline = SamplingPipeline(
[
TemperatureStrategy(0.8),
TopKStrategy(3),
TopPStrategy(0.95),
]
)
result = pipeline.apply(logits.clone(), filter_value=-1e9)
kept = (result > -1e9).sum().item()
assert 1 <= kept <= 3
def test_sampling_pipeline_sample_returns_valid_token():
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
pipeline = SamplingPipeline(
[
TemperatureStrategy(0.8),
TopKStrategy(3),
TopPStrategy(0.95),
]
)
tokens = pipeline.sample(logits)
assert tokens.shape == (1,)
assert 0 <= tokens[0] < logits.size(-1)
def test_module_sample_shortcut():
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
assert tokens.shape == (1,)
assert 0 <= tokens[0] < logits.size(-1)
def test_module_sample_batch():
logits = torch.tensor(
[
[1.0, 2.0, 3.0, 4.0, 5.0],
[5.0, 4.0, 3.0, 2.0, 1.0],
]
)
tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
assert tokens.shape == (2,)
for t in tokens:
assert 0 <= t < logits.size(-1)

View File

@ -1,193 +0,0 @@
"""Tests for scheduler concurrency."""
import threading
from unittest.mock import MagicMock, patch
import pytest
import torch
from astrai.inference import InferenceScheduler
@pytest.fixture
def mock_model_and_tokenizer():
"""Create mock model and tokenizer."""
mock_model = MagicMock()
mock_model.config = MagicMock()
mock_model.config.n_kv_heads = 8
mock_model.config.n_heads = 8
mock_model.config.dim = 128
mock_model.config.n_layers = 2
mock_model.config.max_len = 100
mock_model.parameters.return_value = iter(
[MagicMock(dtype=torch.float32, device=torch.device("cpu"))]
)
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
mock_tokenizer.decode.return_value = "token"
mock_tokenizer.stop_ids = [0]
mock_tokenizer.pad_id = None
return mock_model, mock_tokenizer
def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
"""Test concurrent add_task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"task_ids": [], "errors": []}
lock = threading.Lock()
def add_task_worker(worker_id):
try:
for i in range(10):
task_id = scheduler.add_task(f"prompt from worker {worker_id}-{i}")
with lock:
results["task_ids"].append(task_id)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=add_task_worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["task_ids"]) == 50
def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
"""Test concurrent add and remove task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"added": [], "removed": [], "errors": []}
add_ready = threading.Event()
def add_worker():
try:
for i in range(20):
task_id = scheduler.add_task(f"prompt {i}")
results["added"].append(task_id)
if len(results["added"]) >= 10:
add_ready.set()
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def remove_worker():
try:
add_ready.wait(timeout=5.0)
for task_id in results["added"][:10]:
scheduler.remove_task(task_id)
results["removed"].append(task_id)
except Exception as e:
results["errors"].append(f"Remove: {str(e)}")
add_thread = threading.Thread(target=add_worker)
remove_thread = threading.Thread(target=remove_worker)
add_thread.start()
remove_thread.start()
add_thread.join()
remove_thread.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["added"]) == 20
def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
"""Test concurrent get_stats operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"stats": [], "errors": []}
started = threading.Event()
stats_done = threading.Event()
def add_tasks():
try:
for i in range(20):
scheduler.add_task(f"prompt {i}")
started.set()
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def get_stats():
try:
started.wait(timeout=5.0)
for _ in range(50):
stats = scheduler.get_stats()
results["stats"].append(stats)
stats_done.set()
except Exception as e:
results["errors"].append(f"Get stats: {str(e)}")
add_thread = threading.Thread(target=add_tasks)
stats_thread = threading.Thread(target=get_stats)
add_thread.start()
stats_thread.start()
add_thread.join()
stats_done.wait(timeout=5.0)
scheduler.stop()
stats_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["stats"]) == 50
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
"""Tasks whose entire prompt is cached skip the prefill phase."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
scheduler.stop()
assert task_id.startswith("task_")

View File

@ -0,0 +1,320 @@
"""Tests for scheduler concurrency."""
import threading
import time
from unittest.mock import MagicMock, patch
import pytest
from astrai.inference.scheduler import (
InferenceScheduler,
PrefixCacheManager,
)
def test_prefix_cache_concurrent_insert_find():
"""Test concurrent insert and find operations."""
cache = PrefixCacheManager(max_capacity=100)
results = {"errors": [], "inserts": 0, "finds": 0}
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
def find_worker():
try:
for i in range(50):
cache.find_longest_prefix([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
threads += [threading.Thread(target=find_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert results["inserts"] == 150
assert results["finds"] == 150
def test_prefix_cache_concurrent_release():
"""Test concurrent release operations."""
cache = PrefixCacheManager(max_capacity=100)
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i)
results = {"errors": []}
def release_worker():
try:
for i in range(10):
cache.release((i,))
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=release_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
def test_prefix_cache_concurrent_insert_release_find():
"""Test mixed concurrent operations."""
cache = PrefixCacheManager(max_capacity=50)
results = {"errors": []}
def worker(worker_id):
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id)
# Find after insert
cache.find_longest_prefix(list(token_ids))
# Release
cache.release(token_ids)
except Exception as e:
results["errors"].append(f"Worker {worker_id}: {str(e)}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
@pytest.fixture
def mock_model_and_tokenizer():
"""Create mock model and tokenizer."""
mock_model = MagicMock()
mock_model.config = MagicMock()
mock_model.config.n_kv_heads = 8
mock_model.config.n_heads = 8
mock_model.config.dim = 128
mock_model.config.n_layers = 2
mock_model.config.max_len = 100
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
mock_tokenizer.decode.return_value = "token"
mock_tokenizer.stop_ids = [0]
mock_tokenizer.pad_id = None
return mock_model, mock_tokenizer
def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
"""Test concurrent add_task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"task_ids": [], "errors": []}
lock = threading.Lock()
def add_task_worker(worker_id):
try:
for i in range(10):
task_id = scheduler.add_task(f"prompt from worker {worker_id}-{i}")
with lock:
results["task_ids"].append(task_id)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=add_task_worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
# Let some tasks be processed
time.sleep(0.1)
scheduler.stop()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["task_ids"]) == 50
def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
"""Test concurrent add and remove task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"added": [], "removed": [], "errors": []}
def add_worker():
try:
for i in range(20):
task_id = scheduler.add_task(f"prompt {i}")
results["added"].append(task_id)
time.sleep(0.001)
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def remove_worker():
try:
time.sleep(0.05) # Wait for some tasks to be added
for task_id in results["added"][:10]:
scheduler.remove_task(task_id)
results["removed"].append(task_id)
except Exception as e:
results["errors"].append(f"Remove: {str(e)}")
add_thread = threading.Thread(target=add_worker)
remove_thread = threading.Thread(target=remove_worker)
add_thread.start()
remove_thread.start()
time.sleep(0.2)
scheduler.stop()
add_thread.join()
remove_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["added"]) == 20
def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
"""Test concurrent get_stats operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"stats": [], "errors": []}
def add_tasks():
try:
for i in range(20):
scheduler.add_task(f"prompt {i}")
time.sleep(0.001)
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def get_stats():
try:
for _ in range(50):
stats = scheduler.get_stats()
results["stats"].append(stats)
time.sleep(0.001)
except Exception as e:
results["errors"].append(f"Get stats: {str(e)}")
add_thread = threading.Thread(target=add_tasks)
stats_thread = threading.Thread(target=get_stats)
add_thread.start()
stats_thread.start()
time.sleep(0.3)
scheduler.stop()
add_thread.join()
stats_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["stats"]) == 50
# Verify stats are consistent
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0
def test_prefix_cache_insert_same_prefix_concurrently():
"""Test inserting the same prefix concurrently."""
cache = PrefixCacheManager(max_capacity=100)
results = {"slot_values": [], "errors": []}
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=threading.current_thread().name)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
if node:
node = node.children.get(3)
if node:
results["slot_values"].append(node.slot)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All inserts should succeed, final slot should be one of the values
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
# Check ref_count is correct (should be 10)
node = cache.root.children.get(1).children.get(2).children.get(3)
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
# Insert a prefix
cache.insert((1, 2, 3), slot=0)
# Release multiple times
for _ in range(5):
cache.release((1, 2, 3))
# Try to find it - should return None since ref_count would be negative
# or handle it gracefully
node = cache.root.children.get(1).children.get(2).children.get(3)
# The ref_count should be 0, not negative
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"

Some files were not shown because too many files have changed in this diff Show More