Compare commits
90 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
3057741de9 | |
|
|
acd1103bd0 | |
|
|
dc7d2cfbca | |
|
|
b36a78c612 | |
|
|
985d940db6 | |
|
|
5e73ca20aa | |
|
|
438dc10391 | |
|
|
615ba5d8ef | |
|
|
02a7cb9fa0 | |
|
|
9fe2121743 | |
|
|
0422d6d38e | |
|
|
9b416c1bbb | |
|
|
d6899100ac | |
|
|
0deee48602 | |
|
|
746a1475b2 | |
|
|
01ce1fb9e3 | |
|
|
14f83cbdac | |
|
|
dbe5891201 | |
|
|
2a65c3314c | |
|
|
1c2ff05a6d | |
|
|
31ae2deeba | |
|
|
69207e2c57 | |
|
|
138c5bcc08 | |
|
|
a923e0a23a | |
|
|
f521a30b22 | |
|
|
d4451f6afb | |
|
|
a3275423a4 | |
|
|
b37c3d000c | |
|
|
6031020e37 | |
|
|
c424dfc293 | |
|
|
3a28e52e98 | |
|
|
e371908b54 | |
|
|
7c99da155c | |
|
|
629e72385b | |
|
|
0a708fff24 | |
|
|
6e150ea6d0 | |
|
|
cb8dcb97ea | |
|
|
2d5dc93b3d | |
|
|
4145d35e3c | |
|
|
34c6c45bd6 | |
|
|
e9def84ce7 | |
|
|
836e02a166 | |
|
|
b558e61f63 | |
|
|
65ab69543b | |
|
|
1d26aa2e93 | |
|
|
a548d4553e | |
|
|
dd1b39f435 | |
|
|
94d6e713e9 | |
|
|
47c37e4876 | |
|
|
737585a32a | |
|
|
a4688021bf | |
|
|
7df6eb9211 | |
|
|
82a3f2626f | |
|
|
7fa69572c0 | |
|
|
3ab4f237e5 | |
|
|
8cbf3f36e2 | |
|
|
0594ce1017 | |
|
|
ff509ff39f | |
|
|
785d65436c | |
|
|
64be81b7b3 | |
|
|
45479b5731 | |
|
|
e0a3337c22 | |
|
|
812238060b | |
|
|
14b0d56197 | |
|
|
6c8533f1d2 | |
|
|
2c2697390d | |
|
|
7621f05d3f | |
|
|
10ebd7211f | |
|
|
42a391f0fb | |
|
|
97c7ac0f4f | |
|
|
8f1b32f2b6 | |
|
|
c241a5dcef | |
|
|
44dab27fdc | |
|
|
a44fd22a99 | |
|
|
8a11a7d444 | |
|
|
1d54491809 | |
|
|
ad9f4d9cf6 | |
|
|
e1638a7ade | |
|
|
f91bfee33e | |
|
|
d7a7f570ed | |
|
|
7dea929788 | |
|
|
026d1fc33d | |
|
|
7242eedbf4 | |
|
|
04c0dc7a47 | |
|
|
48a53121ba | |
|
|
0ba8c70ce1 | |
|
|
3d12a03909 | |
|
|
c169659611 | |
|
|
e12f1a7ee5 | |
|
|
ef25efffa2 |
|
|
@ -2,7 +2,7 @@
|
||||||
name: Bug report
|
name: Bug report
|
||||||
about: Create a report to help us improve
|
about: Create a report to help us improve
|
||||||
title: "[BUG]"
|
title: "[BUG]"
|
||||||
labels: enhancement
|
labels: bug
|
||||||
assignees: ''
|
assignees: ''
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,9 @@ Please delete options that are not relevant.
|
||||||
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
|
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
|
||||||
|
|
||||||
## Checklist:
|
## Checklist:
|
||||||
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check --fix .`)
|
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
|
||||||
- [ ] I have performed a self-review of my own code
|
- [ ] I have performed a self-review of my own code
|
||||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
- [ ] Code is self-documenting (no unnecessary comments)
|
||||||
- [ ] I have made corresponding changes to the documentation
|
- [ ] I have made corresponding changes to the documentation
|
||||||
- [ ] My changes generate no new warnings
|
- [ ] My changes generate no new warnings
|
||||||
- [ ] I have added tests that prove my fix is effective or that my feature works
|
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||||
|
|
|
||||||
128
CONTRIBUTING.md
128
CONTRIBUTING.md
|
|
@ -1,68 +1,100 @@
|
||||||
# Contributing to AstrAI
|
# Contributing to AstrAI
|
||||||
|
|
||||||
Thank you for your interest in contributing to AstrAI! This document provides guidelines and steps for contributing.
|
Thank you for your interest in contributing! This document provides step-by-step guidelines.
|
||||||
|
|
||||||
## How to Contribute
|
## Quick Start
|
||||||
|
|
||||||
### Reporting Issues
|
```bash
|
||||||
If you encounter a bug or have a feature request, please open an issue on GitHub. Include as much detail as possible:
|
git clone https://github.com/your-username/AstrAI.git
|
||||||
- A clear description of the problem or request.
|
cd AstrAI
|
||||||
- Steps to reproduce (for bugs).
|
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
|
||||||
- Your environment (Python version, OS, etc.).
|
```
|
||||||
|
|
||||||
### Submitting Changes
|
## Before You Commit
|
||||||
1. **Fork** the repository.
|
|
||||||
2. **Clone** your fork:
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/your-username/AstrAI.git
|
|
||||||
cd AstrAI
|
|
||||||
```
|
|
||||||
3. **Create a feature branch**:
|
|
||||||
```bash
|
|
||||||
git checkout -b feature/your-feature-name
|
|
||||||
```
|
|
||||||
4. **Make your changes**. Follow the code style guidelines below.
|
|
||||||
5. **Commit your changes** with a descriptive commit message:
|
|
||||||
```bash
|
|
||||||
git commit -m "Add: brief description of the change"
|
|
||||||
```
|
|
||||||
6. **Push** to your fork:
|
|
||||||
```bash
|
|
||||||
git push origin feature/your-feature-name
|
|
||||||
```
|
|
||||||
7. **Open a Pull Request** (PR) against the `main` branch of the upstream repository.
|
|
||||||
|
|
||||||
## Code Style
|
Run the following checks **in order** — CI will reject if any fail.
|
||||||
|
|
||||||
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
|
### 1. Format
|
||||||
|
|
||||||
- Run Ruff to format and lint (requires conda environment `nlp`):
|
```bash
|
||||||
```bash
|
ruff format .
|
||||||
conda run -n nlp ruff format .
|
```
|
||||||
conda run -n nlp ruff check --fix .
|
|
||||||
```
|
|
||||||
- The project uses **double quotes** for strings and **4‑space indentation** (as configured in `pyproject.toml`).
|
|
||||||
|
|
||||||
## Testing
|
> **Note**: `ruff format` may rename parameters (e.g. `mask` → `attn_mask`).
|
||||||
|
> Always review the diff after formatting.
|
||||||
|
|
||||||
If you add or modify functionality, please include appropriate tests.
|
### 2. Import sorting
|
||||||
|
|
||||||
- Run the test suite with:
|
```bash
|
||||||
```bash
|
ruff check . --select I
|
||||||
conda run -n nlp python -u -m pytest
|
```
|
||||||
```
|
|
||||||
- Ensure all tests pass before submitting your PR.
|
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ruff check . --select I --fix .
|
||||||
|
ruff format . # re-format after fix
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Run tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -u -m pytest tests/ -v
|
||||||
|
```
|
||||||
|
|
||||||
|
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
|
||||||
|
|
||||||
|
### 4. (Optional) Full pre-commit check
|
||||||
|
|
||||||
|
If you have Git Bash available:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash scripts/pre_commit.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
This runs format check, import sort check, and tests in one go.
|
||||||
|
|
||||||
|
## Commit Style
|
||||||
|
|
||||||
|
```
|
||||||
|
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
|
||||||
|
|
||||||
|
- bullet point body (each ~60 chars)
|
||||||
|
```
|
||||||
|
|
||||||
|
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
|
||||||
|
- **Subject line** ends with no period.
|
||||||
|
- **Body** uses bullet points starting with `-`.
|
||||||
|
- No `(scope)` parentheses.
|
||||||
|
|
||||||
|
## Common Issues
|
||||||
|
|
||||||
|
| Problem | Cause | Fix |
|
||||||
|
|---------|-------|-----|
|
||||||
|
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
|
||||||
|
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
|
||||||
|
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
|
||||||
|
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
|
||||||
|
|
||||||
|
## Submitting Changes
|
||||||
|
|
||||||
|
1. Fork the repo.
|
||||||
|
2. Create a feature branch: `git checkout -b feat/my-feature`
|
||||||
|
3. Make changes following the steps above.
|
||||||
|
4. Commit with the commit style above.
|
||||||
|
5. Push: `git push origin feat/my-feature`
|
||||||
|
6. Open a Pull Request against `main`.
|
||||||
|
|
||||||
## Code Review
|
## Code Review
|
||||||
|
|
||||||
All submissions will be reviewed. We may request changes or discuss alternatives. Please be responsive to feedback.
|
- All PRs are reviewed. We may request changes.
|
||||||
|
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
|
||||||
|
- Ensure all tests pass.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
By contributing, you agree that your contributions will be licensed under the same [GPL-3.0 License](LICENSE) that covers the project.
|
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
If you have any questions, feel free to ask in the [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
|
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
|
||||||
|
|
||||||
Happy contributing!
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# AstrAI Dockerfile - Multi-stage Build (Optimized)
|
# AstrAI Dockerfile - Multi-stage Build (Optimized)
|
||||||
|
|
||||||
# Build stage - use base image with minimal build tools
|
# Build stage - use base image with minimal build tools
|
||||||
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS builder
|
FROM ubuntu:24.04 AS builder
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
|
@ -18,7 +18,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||||
RUN python3.12 -m venv --copies /opt/venv
|
RUN python3.12 -m venv --copies /opt/venv
|
||||||
ENV PATH="/opt/venv/bin:$PATH"
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
# Copy source code and install dependencies
|
# Copy source code and install (deps read from pyproject.toml)
|
||||||
COPY astrai/ ./astrai/
|
COPY astrai/ ./astrai/
|
||||||
COPY pyproject.toml .
|
COPY pyproject.toml .
|
||||||
RUN pip install --no-cache-dir --upgrade pip \
|
RUN pip install --no-cache-dir --upgrade pip \
|
||||||
|
|
@ -26,13 +26,14 @@ RUN pip install --no-cache-dir --upgrade pip \
|
||||||
--extra-index-url https://download.pytorch.org/whl/cu126
|
--extra-index-url https://download.pytorch.org/whl/cu126
|
||||||
|
|
||||||
# Production stage
|
# Production stage
|
||||||
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS production
|
FROM ubuntu:24.04 AS production
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Install Python 3.12 runtime
|
# Install Python 3.12 runtime and healthcheck dependency
|
||||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||||
python3.12 \
|
python3.12 \
|
||||||
|
curl \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy virtual environment from builder
|
# Copy virtual environment from builder
|
||||||
|
|
|
||||||
45
README.md
45
README.md
|
|
@ -78,15 +78,28 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
|
||||||
#### Train a Model
|
#### Train a Model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
--train_type seq \
|
|
||||||
--data_root_path /path/to/dataset \
|
nohup python scripts/tools/train.py \
|
||||||
--param_path /path/to/model \
|
--nprocs=4 \
|
||||||
--batch_size 4 \
|
--parallel_mode=ddp \
|
||||||
--accumulation_steps 8 \
|
--train_type=seq \
|
||||||
--max_lr 3e-4 \
|
--data_root_path=/path/to/dataset \
|
||||||
--warmup_steps 1000 \
|
--param_path=/path/to/model \
|
||||||
--n_epoch 1
|
--batch_per_device=4 \
|
||||||
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.9 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=0.01 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.05 \
|
||||||
|
> out.log 2> err.log &
|
||||||
```
|
```
|
||||||
|
|
||||||
Full reference at [Parameter Guide](assets/docs/params.md).
|
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||||
|
|
@ -96,8 +109,8 @@ Full reference at [Parameter Guide](assets/docs/params.md).
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/generate.py \
|
python scripts/tools/generate.py \
|
||||||
--param_path /path/to/model \
|
--param_path /path/to/model \
|
||||||
--input_json_file /path/to/input.json \
|
--input_json_file /path/to/input.jsonl \
|
||||||
--output_json_file /path/to/output.json
|
--output_json_file /path/to/output.jsonl
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
@ -201,16 +214,18 @@ python scripts/demo/generate_batch.py
|
||||||
python scripts/demo/generate_ar.py
|
python scripts/demo/generate_ar.py
|
||||||
```
|
```
|
||||||
|
|
||||||
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd).
|
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
|
||||||
|
|
||||||
### Documentation
|
### Documentation
|
||||||
|
|
||||||
| Document | Description |
|
| Document | Description |
|
||||||
|----------|-------------|
|
|----------|-------------|
|
||||||
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
|
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
|
||||||
| [Design Document](./assets/docs/design.md) | Framework architecture & module design |
|
| [Architecture](./assets/docs/architecture.md) | System architecture, class diagram & design patterns |
|
||||||
| [Data Flow](./assets/docs/dataflow.md) | Data processing pipeline details |
|
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
||||||
| [Model Introduction](./assets/docs/introduction.md) | Model architecture & technical details |
|
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
||||||
|
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
||||||
|
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
|
||||||
|
|
||||||
### Contributing
|
### Contributing
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -84,15 +84,28 @@ python scripts/demo/download.py
|
||||||
#### 训练模型
|
#### 训练模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
--train_type seq \
|
|
||||||
--data_root_path /path/to/dataset \
|
nohup python scripts/tools/train.py \
|
||||||
--param_path /path/to/model \
|
--nprocs=4 \
|
||||||
--batch_size 4 \
|
--parallel_mode=ddp \
|
||||||
--accumulation_steps 8 \
|
--train_type=seq \
|
||||||
--max_lr 3e-4 \
|
--data_root_path=/path/to/dataset \
|
||||||
--warmup_steps 1000 \
|
--param_path=/path/to/model \
|
||||||
--n_epoch 1
|
--batch_per_device=4 \
|
||||||
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.9 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=0.01 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.05 \
|
||||||
|
> out.log 2> err.log &
|
||||||
```
|
```
|
||||||
|
|
||||||
完整参数列表见[参数说明](./params.md)。
|
完整参数列表见[参数说明](./params.md)。
|
||||||
|
|
@ -102,8 +115,8 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/generate.py \
|
python scripts/tools/generate.py \
|
||||||
--param_path /path/to/model \
|
--param_path /path/to/model \
|
||||||
--input_json_file /path/to/input.json \
|
--input_json_file /path/to/input.jsonl \
|
||||||
--output_json_file /path/to/output.json
|
--output_json_file /path/to/output.jsonl
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
@ -207,16 +220,18 @@ python scripts/demo/generate_batch.py
|
||||||
python scripts/demo/generate_ar.py
|
python scripts/demo/generate_ar.py
|
||||||
```
|
```
|
||||||
|
|
||||||
观看 [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) 上的视频演示。
|
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
|
||||||
|
|
||||||
### 文档
|
### 文档
|
||||||
|
|
||||||
| 文档 | 说明 |
|
| 文档 | 说明 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| [参数说明](./params.md) | 训练与推理参数配置 |
|
| [参数说明](./params.md) | 训练与推理参数配置 |
|
||||||
| [设计文档](./design.md) | 系统架构与模块设计 |
|
| [架构文档](./architecture.md) | 系统架构、类图与设计模式 |
|
||||||
| [数据流程](./dataflow.md) | 数据处理管道详解 |
|
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
||||||
| [模型介绍](./introduction.md) | 模型架构与技术细节 |
|
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
||||||
|
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
||||||
|
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
|
||||||
|
|
||||||
### 贡献
|
### 贡献
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,237 +1,64 @@
|
||||||
# AstrAI Data Flow Documentation
|
# Data Flow
|
||||||
|
|
||||||
This document describes the data flow of the AstrAI project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
|
This document describes the data pipeline: from raw text to model input tensors.
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
AstrAI adopts a modular design with the following main components:
|
```
|
||||||
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
|
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
|
||||||
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
|
||||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
|
||||||
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
|
||||||
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
|
|
||||||
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
|
||||||
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
|
||||||
- **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors
|
|
||||||
|
|
||||||
## Data Flow Diagram
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
flowchart LR
|
|
||||||
subgraph A[Data Preparation]
|
|
||||||
direction TB
|
|
||||||
A1[Raw Text] --> A2[AutoTokenizer]
|
|
||||||
A2 --> A3[Tokenized .h5 files]
|
|
||||||
A3 --> A4[BaseDataset]
|
|
||||||
A4 --> A5[ResumableDistributedSampler]
|
|
||||||
A5 --> A6[DataLoader]
|
|
||||||
end
|
|
||||||
|
|
||||||
subgraph B[Training]
|
|
||||||
direction TB
|
|
||||||
B1[DataLoader] --> B2[BaseStrategy]
|
|
||||||
B2 --> B3[Transformer Forward]
|
|
||||||
B3 --> B4[Loss + Backward]
|
|
||||||
B4 --> B5[Gradient Accumulation]
|
|
||||||
B5 -->|every accum_steps| B6[Optimizer Step]
|
|
||||||
B6 --> B7[LR Scheduler]
|
|
||||||
B7 -->|next batch| B2
|
|
||||||
B6 --> B8[CheckpointCallback]
|
|
||||||
end
|
|
||||||
|
|
||||||
subgraph C[Inference]
|
|
||||||
direction TB
|
|
||||||
C1[Checkpoint] --> C2[AutoModel]
|
|
||||||
C1 --> C3[AutoTokenizer]
|
|
||||||
C2 --> C4[InferenceEngine]
|
|
||||||
C3 --> C4
|
|
||||||
C4 --> C5[InferenceScheduler]
|
|
||||||
C5 --> C6[Transformer Forward]
|
|
||||||
C6 --> C7[sample]
|
|
||||||
C7 --> C8{End?}
|
|
||||||
C8 -->|No| C6
|
|
||||||
C8 -->|Yes| C9[Generated Text]
|
|
||||||
end
|
|
||||||
|
|
||||||
A --> B
|
|
||||||
B --> C
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Detailed Module Descriptions
|
## Data Preparation
|
||||||
|
|
||||||
### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`)
|
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||||
|
|
||||||
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
|
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||||
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
|
|
||||||
- **`Checkpoint`**: Encapsulates model state dict + epoch + iteration; uses safetensors
|
|
||||||
|
|
||||||
### 2. Dataset Module
|
|
||||||
|
|
||||||
#### 2.1 Dataset (`dataset.py`)
|
|
||||||
- **`BaseDataset`**: Abstract base class for windowed sequence sampling
|
|
||||||
- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range
|
|
||||||
- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
|
|
||||||
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen_mask"/"rejected_mask"` (DPO), `"masks"` (GRPO)
|
|
||||||
|
|
||||||
#### 2.2 Sampler (`sampler.py`)
|
|
||||||
- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last
|
|
||||||
|
|
||||||
### 3. Model Module
|
|
||||||
|
|
||||||
#### 3.1 Transformer / AutoModel
|
|
||||||
- **`AutoModel`**: Base class with `from_pretrained()` / `save_pretrained()`
|
|
||||||
- **`Transformer`**: Decoder-only architecture, registered via `@AutoModel.register('transformer')`
|
|
||||||
- Embedding → N×DecoderBlock → RMSNorm → Linear lm_head
|
|
||||||
- RoPE position encoding, optional weight tying
|
|
||||||
|
|
||||||
#### 3.2 Submodules (`module.py`)
|
|
||||||
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
|
||||||
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
|
||||||
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
|
||||||
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
|
|
||||||
- **`RMSNorm`**: Layer normalization
|
|
||||||
|
|
||||||
### 4. Training Module
|
|
||||||
|
|
||||||
#### 4.1 Training Context (`train_context.py`)
|
|
||||||
- **`TrainContext`**: Dataclass holding model, optimizer, dataloader, strategy, scheduler, checkpoint state
|
|
||||||
- **`TrainContextBuilder`**: Builder pattern — takes checkpoint for resume, builds all components
|
|
||||||
|
|
||||||
#### 4.2 Trainer (`trainer.py`)
|
|
||||||
|
|
||||||
The training loop is nested: **epoch** → **batch** (with step phase interspersed):
|
|
||||||
|
|
||||||
```
|
```
|
||||||
on_train_begin
|
StoreFactory.create("h5") → H5Store
|
||||||
on_epoch_begin
|
StoreFactory.create("bin") → MmapStore
|
||||||
for each accumulation window of batches: ← step phase
|
|
||||||
on_step_begin
|
|
||||||
for each batch in window: ← batch phase
|
|
||||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
|
||||||
iteration += 1
|
|
||||||
on_step_end
|
|
||||||
optimizer.step() → zero_grad
|
|
||||||
|
|
||||||
on_epoch_end
|
|
||||||
on_train_end
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Key points:
|
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||||
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
|
|
||||||
- `on_batch_*` fires every batch, wrapping loss computation
|
|
||||||
- `GradientClippingCallback` fires on `on_step_end`
|
|
||||||
- LR scheduler steps inline (no `SchedulerCallback` class)
|
|
||||||
|
|
||||||
#### 4.3 Strategy (`strategy.py`)
|
## Data Keys by Training Type
|
||||||
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
|
||||||
- **`SFTStrategy`**: Supervised fine-tuning with loss masking
|
|
||||||
- **`DPOStrategy`**: Direct Preference Optimization with reference model
|
|
||||||
- **`GRPOStrategy`**: Group Relative Policy Optimization with clipped ratio
|
|
||||||
|
|
||||||
#### 4.4 Scheduler (`schedule.py`)
|
| Type | Storage Keys |
|
||||||
- **`CosineScheduler`**: Cosine decay + linear warmup
|
|------|-------------|
|
||||||
- **`SGDRScheduler`**: Cosine annealing with warm restarts
|
| `seq` | `sequence` (→ input_ids, target_ids via offset-by-1) |
|
||||||
- Created by `SchedulerFactory` and bound to optimizer
|
| `sft` | `sequence`, `loss_mask` |
|
||||||
|
| `dpo` | `chosen`, `rejected`, `chosen_mask`, `rejected_mask` |
|
||||||
|
| `grpo` | `prompts`, `responses`, `masks`, `rewards` |
|
||||||
|
|
||||||
#### 4.5 Callbacks
|
## Dataset Architecture
|
||||||
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
|
||||||
- **`ProgressBarCallback`**: tqdm progress display
|
|
||||||
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
|
||||||
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end`
|
|
||||||
|
|
||||||
### 5. Inference Module
|
|
||||||
|
|
||||||
#### 5.1 Inference Engine (`engine.py`)
|
|
||||||
- **`InferenceEngine`**: Facade over scheduler; provides `generate()`, `generate_with_request()`, `generate_async()`
|
|
||||||
- Accepts `prompt: str | List[str]`, returns generator (stream) or string (non-stream)
|
|
||||||
|
|
||||||
#### 5.2 Scheduler 4-Phase Loop (`scheduler.py`)
|
|
||||||
|
|
||||||
Background thread runs continuously:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
1. Cleanup → Remove finished tasks, free KV cache pages
|
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
|
||||||
2. Refill → Pop from waiting_queue, alloc pages, add to active
|
→ BaseDataset.load(load_path, storage_type=None)
|
||||||
3. Prefill → Group active tasks by prompt_len, run full forward pass
|
→ detect_format(load_path)
|
||||||
4. Decode → Pick largest same-position group, run single-token forward
|
→ StoreFactory.create(storage_type)
|
||||||
|
→ Store.load(load_path)
|
||||||
|
→ H5Store._normalize() / MmapStore._normalize()
|
||||||
|
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||||
|
→ BaseDataset.__getitem__(idx)
|
||||||
|
→ get_index(idx) → [begin, end)
|
||||||
|
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
|
||||||
```
|
```
|
||||||
|
|
||||||
- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED)
|
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
|
||||||
- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache
|
|
||||||
- **`KvcacheView`**: Batch view bundling cache + page table for attention layers
|
|
||||||
- **`sample()`**: Temperature → top-k → top-p → multinomial
|
|
||||||
|
|
||||||
#### 5.3 Server (`server.py`)
|
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
|
||||||
- FastAPI with OpenAI `/v1/chat/completions` and Anthropic `/v1/messages` endpoints
|
|
||||||
- Streaming via SSE, health check at `/health`, stats at `/stats`
|
|
||||||
|
|
||||||
### 6. Tokenizer Module
|
## Sampler
|
||||||
|
|
||||||
- **`AutoTokenizer`**: Wraps HuggingFace tokenizers (BBPE); `encode`/`decode`/`apply_chat_template`
|
`ResumableDistributedSampler` supports checkpoint-aware distributed sampling:
|
||||||
- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat
|
|
||||||
|
|
||||||
### 7. Factory & Parallel
|
- Tracks `start_epoch` / `start_iter` for resume
|
||||||
|
- Shuffle via `torch.Generator(seed + epoch)`
|
||||||
|
- Per-replica index slicing for DDP
|
||||||
|
|
||||||
- **`Registry` / `BaseFactory`**: Decorator-based component registration
|
## DataLoader
|
||||||
- **`spawn_parallel_fn`**: Multi-process DDP launcher with NCCL backend
|
|
||||||
- **`ParallelModel` / `ColumnParallelLinear` / `RowParallelLinear`**: Tensor model parallelism
|
|
||||||
|
|
||||||
## Training Data Flow — Detailed Steps
|
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||||
|
|
||||||
1. **Data Preparation**
|
> Document Update Time: 2026-05-30
|
||||||
- Raw text → token IDs via `AutoTokenizer.encode()`
|
|
||||||
- Save as `.h5` files (groups of tensor lists per data key)
|
|
||||||
|
|
||||||
2. **Dataset Loading**
|
|
||||||
- `BaseDataset.load()` calls `load_h5()`, builds `MultiSegmentFetcher`
|
|
||||||
- Sliding window of `window_size` with `stride` determines sample boundaries
|
|
||||||
|
|
||||||
3. **Sampling & Batching**
|
|
||||||
- `ResumableDistributedSampler` produces shuffled index sequences
|
|
||||||
- `DataLoader` fetches `[batch_size, window_size]` tensors via `__getitem__`
|
|
||||||
|
|
||||||
4. **Strategy Forward**
|
|
||||||
- Strategy receives batch, calls `Transformer.forward()` for logits
|
|
||||||
- Computes task-specific loss (cross-entropy, DPO, GRPO)
|
|
||||||
|
|
||||||
5. **Backward & Accumulation**
|
|
||||||
- `loss = raw_loss / accumulation_steps`
|
|
||||||
- `loss.backward()` accumulates gradients
|
|
||||||
- Every `accumulation_steps` batches: `optimizer.step()` → `zero_grad()`
|
|
||||||
- Every batch: `scheduler.step()` updates learning rate
|
|
||||||
|
|
||||||
6. **Checkpoint**
|
|
||||||
- `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations
|
|
||||||
- Does NOT save optimizer/scheduler state (resume resets those)
|
|
||||||
|
|
||||||
## Inference Data Flow — Detailed Steps
|
|
||||||
|
|
||||||
1. **Model Loading**
|
|
||||||
- `AutoModel.from_pretrained(path)` loads weights from safetensors
|
|
||||||
- `torch.inference_mode()` wraps generation
|
|
||||||
|
|
||||||
2. **Prompt Construction**
|
|
||||||
- Messages → `apply_chat_template(messages, tokenize=False)` → prompt string
|
|
||||||
- `tokenizer.encode(prompt)` → token IDs (truncated to `max_prompt_len`)
|
|
||||||
|
|
||||||
3. **Continuous Batching Loop**
|
|
||||||
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
|
|
||||||
- **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages
|
|
||||||
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
|
|
||||||
- **Decode**: Pick position group with most tasks, single-token forward:
|
|
||||||
- Model forward → `logits` → `sample()` → next token ID
|
|
||||||
- Append to `output_ids`, update `output_tokens`
|
|
||||||
- `PagePool.task_alloc()` allocates pages as needed
|
|
||||||
- `stream_callback(token)` for streaming clients
|
|
||||||
|
|
||||||
4. **Output**
|
|
||||||
- `tokenizer.decode(output_ids)` → text
|
|
||||||
- Return to caller (streaming: token-by-token; non-streaming: complete string)
|
|
||||||
|
|
||||||
## Checkpoint & Serialization
|
|
||||||
|
|
||||||
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata. Optimizer/scheduler state is NOT persisted.
|
|
||||||
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
|
|
||||||
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-14
|
|
||||||
|
|
|
||||||
|
|
@ -1,779 +0,0 @@
|
||||||
## 1. Why I Created This Project
|
|
||||||
|
|
||||||
There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible for ordinary developers. I thought: **Can we create a model that is both useful and can run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that achieves complete privatization while maintaining some level of intelligence.
|
|
||||||
|
|
||||||
Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, and the training code is open source!
|
|
||||||
|
|
||||||
## 2. System Architecture
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
classDiagram
|
|
||||||
namespace config {
|
|
||||||
class ModelConfig {
|
|
||||||
+int vocab_size
|
|
||||||
+int dim
|
|
||||||
+int n_layers
|
|
||||||
+float norm_eps
|
|
||||||
+int dim_ffn
|
|
||||||
+bool tie_weight
|
|
||||||
+int max_len
|
|
||||||
+float rope_theta
|
|
||||||
+int n_heads
|
|
||||||
+int n_kv_heads
|
|
||||||
+bool use_qk_norm
|
|
||||||
+bool use_gated_attention
|
|
||||||
+load(config_path) ModelConfig
|
|
||||||
+save(config_path)
|
|
||||||
}
|
|
||||||
|
|
||||||
class TrainConfig {
|
|
||||||
+nn.Module model
|
|
||||||
+str strategy
|
|
||||||
+Dataset dataset
|
|
||||||
+Callable optimizer_fn
|
|
||||||
+Callable scheduler_fn
|
|
||||||
+int n_epoch
|
|
||||||
+int batch_size
|
|
||||||
+int accumulation_steps
|
|
||||||
+float max_grad_norm
|
|
||||||
+int start_epoch
|
|
||||||
+int start_batch
|
|
||||||
+str ckpt_dir
|
|
||||||
+int ckpt_interval
|
|
||||||
+int random_seed
|
|
||||||
+int num_workers
|
|
||||||
+int prefetch_factor
|
|
||||||
+bool pin_memory
|
|
||||||
+int nprocs
|
|
||||||
+str backend
|
|
||||||
+str master_addr
|
|
||||||
+str master_port
|
|
||||||
+Callable parallel_wrapper
|
|
||||||
+Callable state_dict_fn
|
|
||||||
+str device_type
|
|
||||||
+dict extra_kwargs
|
|
||||||
+validate()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace dataset {
|
|
||||||
class BaseDataset {
|
|
||||||
+int window_size
|
|
||||||
+int stride
|
|
||||||
+BaseStorage storage
|
|
||||||
+load(load_path, storage_type, tokenizer)
|
|
||||||
+__getitem__(index)
|
|
||||||
+__len__()
|
|
||||||
}
|
|
||||||
|
|
||||||
class SEQDataset {
|
|
||||||
+__getitem__(index) Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class SFTDataset {
|
|
||||||
+__getitem__(index) Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class DPODataset {
|
|
||||||
+__getitem__(index) Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class GRPODataset {
|
|
||||||
+__getitem__(index) Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseSegmentFetcher {
|
|
||||||
+List[Tensor] segments
|
|
||||||
+List[int] cum_lengths
|
|
||||||
+int total_length
|
|
||||||
+fetch_data(begin_idx, end_idx) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseStorage {
|
|
||||||
+MultiSegmentFetcher _fetcher
|
|
||||||
+keys (property)
|
|
||||||
+load(load_path, tokenizer)
|
|
||||||
+fetch(begin, end, keys)
|
|
||||||
+__len__()
|
|
||||||
}
|
|
||||||
|
|
||||||
class H5Storage {
|
|
||||||
+load(load_path, tokenizer)
|
|
||||||
+fetch(begin, end, keys) Dict
|
|
||||||
+keys() List
|
|
||||||
}
|
|
||||||
|
|
||||||
class JSONStorage {
|
|
||||||
+load(load_path, tokenizer)
|
|
||||||
+fetch(begin, end, keys) Dict
|
|
||||||
+keys() List
|
|
||||||
}
|
|
||||||
|
|
||||||
class MultiSegmentFetcher {
|
|
||||||
+Dict multi_fetchers
|
|
||||||
+List multi_keys
|
|
||||||
+key_fetch(begin_idx, end_idx, keys) Dict
|
|
||||||
+fetch_data(begin_idx, end_idx) Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
|
||||||
+int epoch
|
|
||||||
+int iter
|
|
||||||
}
|
|
||||||
|
|
||||||
class DatasetFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name) decorator
|
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
|
||||||
+load(train_type, load_path, window_size, stride) BaseDataset
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace serialization {
|
|
||||||
class Checkpoint {
|
|
||||||
+dict state_dict
|
|
||||||
+int epoch
|
|
||||||
+int iteration
|
|
||||||
+save(save_dir)
|
|
||||||
+load(save_dir) Checkpoint
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace model {
|
|
||||||
class AutoModel {
|
|
||||||
+ModelConfig config
|
|
||||||
+Registry _registry
|
|
||||||
+register(model_type) decorator
|
|
||||||
+get_component_class(model_type) Type
|
|
||||||
+from_pretrained(path, disable_random_init) nn.Module
|
|
||||||
+save_pretrained(save_directory)
|
|
||||||
+to(*args, **kwargs) Self
|
|
||||||
}
|
|
||||||
|
|
||||||
class Transformer {
|
|
||||||
+ModelConfig config
|
|
||||||
+RotaryEmbedding rotary_embedding
|
|
||||||
+Embedding embed_tokens
|
|
||||||
+ModuleList layers
|
|
||||||
+RMSNorm norm
|
|
||||||
+Linear lm_head
|
|
||||||
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor
|
|
||||||
+load_state_dict(state_dict)
|
|
||||||
+state_dict()
|
|
||||||
}
|
|
||||||
|
|
||||||
class DecoderBlock {
|
|
||||||
+GQA attention
|
|
||||||
+RMSNorm input_norm
|
|
||||||
+MLP mlp
|
|
||||||
+RMSNorm post_attention_norm
|
|
||||||
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class GQA {
|
|
||||||
+int n_heads
|
|
||||||
+int n_kv_heads
|
|
||||||
+int head_dim
|
|
||||||
+Linear q_proj, k_proj, v_proj, o_proj
|
|
||||||
+RMSNorm q_norm, k_norm
|
|
||||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class MLA {
|
|
||||||
+int n_heads
|
|
||||||
+int n_kv_heads
|
|
||||||
+int head_dim
|
|
||||||
+int kv_lora_rank
|
|
||||||
+int qk_nope_head_dim
|
|
||||||
+int qk_rope_head_dim
|
|
||||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
|
||||||
+Linear o_proj
|
|
||||||
+RMSNorm kv_norm
|
|
||||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class MLP {
|
|
||||||
+Linear up, gate, down
|
|
||||||
+forward(x) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class RMSNorm {
|
|
||||||
+Parameter weight
|
|
||||||
+float norm_eps
|
|
||||||
+forward(x) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class Linear {
|
|
||||||
+Parameter weight
|
|
||||||
+Parameter bias
|
|
||||||
+forward(x) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class RotaryEmbedding {
|
|
||||||
+int dim
|
|
||||||
+int max_len
|
|
||||||
+float base
|
|
||||||
+forward(x, position_ids=None) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class Embedding {
|
|
||||||
+Parameter weight
|
|
||||||
+forward(x) Tensor
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace tokenize {
|
|
||||||
class AutoTokenizer {
|
|
||||||
+vocab_size int
|
|
||||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
|
||||||
+decode(tokens, skip_special_tokens) str
|
|
||||||
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
|
||||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
|
||||||
+set_chat_template(template)
|
|
||||||
+load(path)
|
|
||||||
+from_pretrained(path) AutoTokenizer
|
|
||||||
+save_pretrained(save_path)
|
|
||||||
}
|
|
||||||
|
|
||||||
class ChatTemplate {
|
|
||||||
+String template_str
|
|
||||||
+render(messages, system_prompt, **extra_variables) str
|
|
||||||
+from_string(template) ChatTemplate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace factory {
|
|
||||||
class Registry {
|
|
||||||
+Dict _entries
|
|
||||||
+register(name, component_cls, category, priority)
|
|
||||||
+get(name) Type
|
|
||||||
+list_names() List[str]
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name, category, priority) decorator
|
|
||||||
+create(name, *args, **kwargs) T
|
|
||||||
+list_registered() list
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace trainer {
|
|
||||||
class Trainer {
|
|
||||||
+TrainConfig train_config
|
|
||||||
+List[TrainCallback] callbacks
|
|
||||||
+train(checkpoint)
|
|
||||||
+_build_context(checkpoint) TrainContext
|
|
||||||
+_get_default_callbacks() List[TrainCallback]
|
|
||||||
}
|
|
||||||
|
|
||||||
class TrainContext {
|
|
||||||
+nn.Module model
|
|
||||||
+BaseStrategy strategy
|
|
||||||
+DataLoader dataloader
|
|
||||||
+Optimizer optimizer
|
|
||||||
+LRScheduler scheduler
|
|
||||||
+Checkpoint checkpoint
|
|
||||||
+int epoch
|
|
||||||
+int iteration
|
|
||||||
+float loss
|
|
||||||
+int world_size
|
|
||||||
+int rank
|
|
||||||
}
|
|
||||||
|
|
||||||
class TrainContextBuilder {
|
|
||||||
+TrainConfig config
|
|
||||||
+with_checkpoint(checkpoint) TrainContextBuilder
|
|
||||||
+build() TrainContext
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseStrategy {
|
|
||||||
+nn.Module model
|
|
||||||
+str device
|
|
||||||
+compute_loss(batch) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class StrategyFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name) decorator
|
|
||||||
+create(model, train_type, device, **kwargs) BaseStrategy
|
|
||||||
}
|
|
||||||
|
|
||||||
class SEQStrategy {
|
|
||||||
+float label_smoothing
|
|
||||||
+compute_loss(batch) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class SFTStrategy {
|
|
||||||
+float label_smoothing
|
|
||||||
+compute_loss(batch) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class DPOStrategy {
|
|
||||||
+nn.Module ref_model
|
|
||||||
+float beta
|
|
||||||
+str reduction
|
|
||||||
+compute_loss(batch) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class GRPOStrategy {
|
|
||||||
+nn.Module ref_model
|
|
||||||
+float clip_eps
|
|
||||||
+float kl_coef
|
|
||||||
+int group_size
|
|
||||||
+str reduction
|
|
||||||
+int sync_interval
|
|
||||||
+compute_loss(batch) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseScheduler {
|
|
||||||
+get_lr() List[float]
|
|
||||||
+step()
|
|
||||||
}
|
|
||||||
|
|
||||||
class SchedulerFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name) decorator
|
|
||||||
+create(optimizer, schedule_type, **kwargs) BaseScheduler
|
|
||||||
}
|
|
||||||
|
|
||||||
class CosineScheduler {
|
|
||||||
+int warmup_steps
|
|
||||||
+int lr_decay_steps
|
|
||||||
+float min_rate
|
|
||||||
}
|
|
||||||
|
|
||||||
class SGDRScheduler {
|
|
||||||
+int warmup_steps
|
|
||||||
+int cycle_length
|
|
||||||
+float min_rate
|
|
||||||
+int t_mult
|
|
||||||
}
|
|
||||||
|
|
||||||
class TrainCallback {
|
|
||||||
+on_train_begin(context)
|
|
||||||
+on_train_end(context)
|
|
||||||
+on_epoch_begin(context)
|
|
||||||
+on_epoch_end(context)
|
|
||||||
+on_step_begin(context)
|
|
||||||
+on_step_end(context)
|
|
||||||
+on_batch_begin(context)
|
|
||||||
+on_batch_end(context)
|
|
||||||
+on_error(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class GradientClippingCallback {
|
|
||||||
+float max_grad_norm
|
|
||||||
+on_step_begin(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class CheckpointCallback {
|
|
||||||
+str save_dir
|
|
||||||
+int interval
|
|
||||||
+_save_checkpoint(context)
|
|
||||||
+on_batch_end(context)
|
|
||||||
+on_train_end(context)
|
|
||||||
+on_error(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class ProgressBarCallback {
|
|
||||||
+int num_epoch
|
|
||||||
+on_epoch_begin(context)
|
|
||||||
+on_batch_end(context)
|
|
||||||
+on_epoch_end(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class MetricLoggerCallback {
|
|
||||||
+str log_dir
|
|
||||||
+int save_interval
|
|
||||||
+on_batch_end(context)
|
|
||||||
+on_train_end(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class CallbackFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name) decorator
|
|
||||||
+create(name, **kwargs) TrainCallback
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace inference {
|
|
||||||
class InferenceEngine {
|
|
||||||
+nn.Module model
|
|
||||||
+AutoTokenizer tokenizer
|
|
||||||
+InferenceScheduler scheduler
|
|
||||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
|
||||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
|
||||||
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
|
||||||
+get_stats() Dict
|
|
||||||
+shutdown()
|
|
||||||
}
|
|
||||||
|
|
||||||
class InferenceScheduler {
|
|
||||||
+nn.Module model
|
|
||||||
+AutoTokenizer tokenizer
|
|
||||||
+KVCache _page_cache
|
|
||||||
+int max_batch_size
|
|
||||||
+int max_seq_len
|
|
||||||
+int max_prompt_len
|
|
||||||
+int page_size
|
|
||||||
+TaskManager _task_mgr
|
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
|
||||||
+remove_task(task_id)
|
|
||||||
+start()
|
|
||||||
+stop()
|
|
||||||
+get_stats() Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class Allocator {
|
|
||||||
+int _free_mask
|
|
||||||
+int refs_count
|
|
||||||
+LRU _lru
|
|
||||||
+alloc() int
|
|
||||||
+free(idx, keep_cached)
|
|
||||||
+inc_ref(idx)
|
|
||||||
+touch(idx)
|
|
||||||
+ref_count(idx) int
|
|
||||||
}
|
|
||||||
|
|
||||||
class PrefixCache {
|
|
||||||
+int _page_size
|
|
||||||
+evict(page_idx)
|
|
||||||
+has_page(idx) bool
|
|
||||||
+lookup(token_ids) List[int]
|
|
||||||
+record(page_idx, token_ids, logical_page_idx)
|
|
||||||
}
|
|
||||||
|
|
||||||
class PagePool {
|
|
||||||
-Allocator _alloc
|
|
||||||
-PrefixCache _prefix
|
|
||||||
+alloc() int
|
|
||||||
+free(idx)
|
|
||||||
+inc_ref(idx)
|
|
||||||
+lookup(token_ids) List[int]
|
|
||||||
+record(page_idx, token_ids, logical_page_idx)
|
|
||||||
}
|
|
||||||
|
|
||||||
class Storage {
|
|
||||||
+int n_layers
|
|
||||||
+int page_size
|
|
||||||
+int head_dim
|
|
||||||
+int n_kv_heads
|
|
||||||
+Tensor k_cache
|
|
||||||
+Tensor v_cache
|
|
||||||
+write(layer_id, page_table, start_pos, k, v)
|
|
||||||
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
|
|
||||||
}
|
|
||||||
|
|
||||||
class KVCache {
|
|
||||||
-PagePool _pool
|
|
||||||
-Storage _storage
|
|
||||||
-TaskTable _table
|
|
||||||
+int page_size
|
|
||||||
+task_alloc(task_id, prompt_ids) bool
|
|
||||||
+task_free(task_id)
|
|
||||||
+task_extend(task_id, pos) bool
|
|
||||||
+task_cached(task_id) int
|
|
||||||
+task_record_hashes(task_id, prompt_ids, start_logical_page)
|
|
||||||
+make_table_tensor(task_ids, device) Tensor
|
|
||||||
+bind(page_table, total_len) KvcacheView
|
|
||||||
}
|
|
||||||
|
|
||||||
class KvcacheView {
|
|
||||||
-Storage _storage
|
|
||||||
+Tensor _page_table
|
|
||||||
+int _total_len
|
|
||||||
+write(layer_id, k, v)
|
|
||||||
+gather(layer_id) Tuple[Tensor, Tensor]
|
|
||||||
}
|
|
||||||
|
|
||||||
class TaskTable {
|
|
||||||
+set(task_id, page_table, cached)
|
|
||||||
+get(task_id) List[int]
|
|
||||||
+get_cached(task_id) int
|
|
||||||
+get_ref(task_id) List[int]
|
|
||||||
+pop(task_id) Tuple[List[int], int]
|
|
||||||
+table_tensor(task_ids, device) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class Task {
|
|
||||||
+str task_id
|
|
||||||
+List prompt_ids
|
|
||||||
+int max_tokens
|
|
||||||
+float temperature
|
|
||||||
+float top_p
|
|
||||||
+int top_k
|
|
||||||
+TaskStatus status
|
|
||||||
+List output_ids
|
|
||||||
+int input_tokens
|
|
||||||
+int output_tokens
|
|
||||||
+float arrival_time
|
|
||||||
+float finish_time
|
|
||||||
+Callable stream_callback
|
|
||||||
+int next_pos
|
|
||||||
+is_finished(stop_ids) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
class TaskStatus {
|
|
||||||
<<enumeration>>
|
|
||||||
PENDING
|
|
||||||
RUNNING
|
|
||||||
FINISHED
|
|
||||||
ABORTED
|
|
||||||
}
|
|
||||||
|
|
||||||
class GenerationRequest {
|
|
||||||
+List[Dict] messages
|
|
||||||
+int top_k
|
|
||||||
+float top_p
|
|
||||||
+float temperature
|
|
||||||
+Optional[int] max_tokens
|
|
||||||
+bool stream
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseSamplingStrategy {
|
|
||||||
<<abstract>>
|
|
||||||
+apply(logits, filter_value) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class TemperatureStrategy {
|
|
||||||
+float temperature
|
|
||||||
+apply(logits, filter_value) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class TopKStrategy {
|
|
||||||
+int top_k
|
|
||||||
+apply(logits, filter_value) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class TopPStrategy {
|
|
||||||
+float top_p
|
|
||||||
+apply(logits, filter_value) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class SamplingPipeline {
|
|
||||||
+List strategies
|
|
||||||
+apply(logits, filter_value) Tensor
|
|
||||||
+sample(logits, filter_value) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class GenerateResult {
|
|
||||||
+List[Tuple[int, str]] tokens
|
|
||||||
+List[str] results
|
|
||||||
+List[bool] _done
|
|
||||||
+append(token, idx)
|
|
||||||
+get_results() List[str]
|
|
||||||
+pop_all() List[str]
|
|
||||||
+wait(timeout) bool
|
|
||||||
+wait_completion()
|
|
||||||
}
|
|
||||||
|
|
||||||
class ChatMessage {
|
|
||||||
+str role
|
|
||||||
+str content
|
|
||||||
}
|
|
||||||
|
|
||||||
class ChatCompletionRequest {
|
|
||||||
+List[ChatMessage] messages
|
|
||||||
+float temperature
|
|
||||||
+float top_p
|
|
||||||
+int top_k
|
|
||||||
+int max_tokens
|
|
||||||
+bool stream
|
|
||||||
+Optional[str] stop
|
|
||||||
+Optional[int] n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace parallel {
|
|
||||||
class Functions {
|
|
||||||
+spawn_parallel_fn(fn, nprocs)
|
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
|
||||||
+get_current_device() str
|
|
||||||
+get_world_size() int
|
|
||||||
+get_rank() int
|
|
||||||
}
|
|
||||||
|
|
||||||
class ParallelModel {
|
|
||||||
+dist.ProcessGroup process_group
|
|
||||||
+int rank
|
|
||||||
+int world_size
|
|
||||||
}
|
|
||||||
|
|
||||||
class ColumnParallelLinear {
|
|
||||||
+forward(x) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class RowParallelLinear {
|
|
||||||
+forward(x) Tensor
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
%% Relationships
|
|
||||||
TrainConfig --> BaseDataset : uses
|
|
||||||
TrainConfig ..> BaseStrategy : selects
|
|
||||||
StrategyFactory ..> BaseStrategy : creates
|
|
||||||
BaseStrategy <|-- SEQStrategy
|
|
||||||
BaseStrategy <|-- SFTStrategy
|
|
||||||
BaseStrategy <|-- DPOStrategy
|
|
||||||
BaseStrategy <|-- GRPOStrategy
|
|
||||||
DPOStrategy --> Transformer : uses
|
|
||||||
GRPOStrategy --> Transformer : uses
|
|
||||||
Trainer --> TrainConfig : uses
|
|
||||||
Trainer --> TrainContextBuilder : uses
|
|
||||||
Trainer --> TrainCallback : manages
|
|
||||||
TrainContextBuilder --> TrainContext : creates
|
|
||||||
TrainContextBuilder --> StrategyFactory : uses
|
|
||||||
Checkpoint ..> Checkpoint : serializes
|
|
||||||
TrainContext --> Checkpoint : manages
|
|
||||||
TrainContext --> BaseStrategy : uses
|
|
||||||
TrainContext --> BaseScheduler : uses
|
|
||||||
SchedulerFactory ..> BaseScheduler : creates
|
|
||||||
BaseScheduler <|-- CosineScheduler
|
|
||||||
BaseScheduler <|-- SGDRScheduler
|
|
||||||
CallbackFactory ..> TrainCallback : creates
|
|
||||||
TrainCallback <|-- GradientClippingCallback
|
|
||||||
TrainCallback <|-- CheckpointCallback
|
|
||||||
TrainCallback <|-- ProgressBarCallback
|
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
|
||||||
PagePool --> Allocator : composes
|
|
||||||
PagePool --> PrefixCache : composes
|
|
||||||
KVCache --> PagePool : composes
|
|
||||||
KVCache --> Storage : composes
|
|
||||||
KVCache --> TaskTable : composes
|
|
||||||
KvcacheView --> Storage : wraps
|
|
||||||
InferenceEngine --> InferenceScheduler : uses
|
|
||||||
InferenceEngine --> GenerationRequest : uses
|
|
||||||
InferenceEngine --> GenerateResult : creates
|
|
||||||
InferenceScheduler --> Task : manages
|
|
||||||
InferenceScheduler --> TaskStatus : uses
|
|
||||||
InferenceScheduler --> KVCache : uses
|
|
||||||
InferenceScheduler --> Transformer : uses
|
|
||||||
Task --> TaskStatus : uses
|
|
||||||
InferenceEngine --> Transformer : uses
|
|
||||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
|
||||||
BaseSamplingStrategy <|-- TopKStrategy
|
|
||||||
BaseSamplingStrategy <|-- TopPStrategy
|
|
||||||
SamplingPipeline --> BaseSamplingStrategy : composes
|
|
||||||
BaseDataset <|-- SEQDataset
|
|
||||||
BaseDataset <|-- SFTDataset
|
|
||||||
BaseDataset <|-- DPODataset
|
|
||||||
BaseDataset <|-- GRPODataset
|
|
||||||
DatasetFactory ..> BaseDataset : creates
|
|
||||||
BaseStorage <|-- H5Storage
|
|
||||||
BaseStorage <|-- JSONStorage
|
|
||||||
BaseDataset --> BaseStorage : uses
|
|
||||||
MultiSegmentFetcher --> BaseSegmentFetcher : uses
|
|
||||||
AutoModel <|-- Transformer
|
|
||||||
AutoModel --> ModelConfig : contains
|
|
||||||
Transformer --> DecoderBlock : uses
|
|
||||||
Transformer --> RotaryEmbedding : uses
|
|
||||||
Transformer --> Embedding : uses
|
|
||||||
DecoderBlock --> GQA : uses
|
|
||||||
DecoderBlock --> MLP : uses
|
|
||||||
DecoderBlock --> RMSNorm : uses
|
|
||||||
TrainContextBuilder --> ResumableDistributedSampler : creates
|
|
||||||
ResumableDistributedSampler --> BaseDataset : samples
|
|
||||||
ParallelModel <|-- RowParallelLinear
|
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
|
||||||
AutoTokenizer --> ChatTemplate : uses
|
|
||||||
BaseFactory <|-- AutoModel
|
|
||||||
BaseFactory <|-- DatasetFactory
|
|
||||||
BaseFactory <|-- StrategyFactory
|
|
||||||
BaseFactory <|-- SchedulerFactory
|
|
||||||
BaseFactory <|-- CallbackFactory
|
|
||||||
```
|
|
||||||
|
|
||||||
### Module Overview
|
|
||||||
|
|
||||||
| Module | Components | Description |
|
|
||||||
|--------|------------|-------------|
|
|
||||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
|
||||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
|
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
|
|
||||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
|
||||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
|
||||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
|
||||||
|
|
||||||
### Design Patterns
|
|
||||||
|
|
||||||
| Pattern | Classes | Purpose |
|
|
||||||
|---------|---------|---------|
|
|
||||||
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
|
|
||||||
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
|
|
||||||
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
|
|
||||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
|
||||||
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
|
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
|
|
||||||
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
|
||||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
|
||||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
|
||||||
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
|
|
||||||
|
|
||||||
### Core Relationships
|
|
||||||
|
|
||||||
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references
|
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
|
||||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
|
||||||
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
|
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
|
||||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
|
||||||
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
|
|
||||||
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
|
|
||||||
|
|
||||||
## 3. Training Process
|
|
||||||
|
|
||||||
The common training process for large language models (LLM) typically includes three stages: **Pre-training (SEQ)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (DPO/GRPO)**. This system is designed to support seamless end-to-end flow, achieving efficient switching and state management of different training stages through modular strategies.
|
|
||||||
|
|
||||||
### Core Formulas
|
|
||||||
|
|
||||||
**Pre-training (SEQ):**
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{PT}} = - \sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
|
||||||
$$
|
|
||||||
|
|
||||||
**SFT:**
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{SFT}} = - \sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
|
||||||
$$
|
|
||||||
|
|
||||||
**DPO:**
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
|
|
||||||
$$
|
|
||||||
|
|
||||||
**GRPO:**
|
|
||||||
|
|
||||||
GRPO (Group Relative Policy Optimization) computes advantages from multiple responses to the same prompt, then optimizes using a PPO-style clipped objective:
|
|
||||||
|
|
||||||
$$
|
|
||||||
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
|
|
||||||
$$
|
|
||||||
|
|
||||||
Where $r_i$ is the reward for the $i$-th response, $\mu$ and $\sigma$ are the mean and standard deviation of group rewards.
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
|
|
||||||
$$
|
|
||||||
|
|
||||||
The KL divergence term uses mean squared error approximation:
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{KL} = \lambda \cdot \mathbb{E} \left[ (\log \pi_\theta - \log \pi_{\text{ref}})^2 \right]
|
|
||||||
$$
|
|
||||||
|
|
||||||
The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
|
|
||||||
|
|
||||||
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-14
|
|
||||||
|
|
@ -0,0 +1,152 @@
|
||||||
|
# Inference
|
||||||
|
|
||||||
|
## KV Cache
|
||||||
|
|
||||||
|
At decode time, only the last query token matters. All previous K/V are cached to avoid recomputation:
|
||||||
|
|
||||||
|
$$
|
||||||
|
o_n = \sum_j \text{softmax}\left(\frac{q_n k_j}{\sqrt{d_k}}\right) v_j
|
||||||
|
$$
|
||||||
|
|
||||||
|
RoPE is applied **before** KV cache write, not after — otherwise position encoding drift occurs.
|
||||||
|
|
||||||
|
## KVCache System
|
||||||
|
|
||||||
|
Six classes (plus two helpers) working together:
|
||||||
|
|
||||||
|
```
|
||||||
|
KVCache (facade)
|
||||||
|
├── PagePool orchestrates page allocation + prefix matching
|
||||||
|
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
|
||||||
|
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
|
||||||
|
├── TaskTable maps task_id → page_table + cached token count
|
||||||
|
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
||||||
|
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
|
||||||
|
```
|
||||||
|
|
||||||
|
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
||||||
|
|
||||||
|
## Continuous Batching
|
||||||
|
|
||||||
|
`InferenceScheduler` runs a daemon thread with a 4-phase loop:
|
||||||
|
|
||||||
|
```
|
||||||
|
1. Cleanup → Remove finished tasks, free KV pages
|
||||||
|
2. Refill → Pop from waiting_queue, task_alloc pages, activate
|
||||||
|
3. Prefill → Group by (prompt_len, start_pos), run full forward
|
||||||
|
4. Decode → Pick largest same-position group, single-token forward
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sampling (Strategy Pattern)
|
||||||
|
|
||||||
|
```
|
||||||
|
BaseSamplingStrategy (ABC)
|
||||||
|
├── TemperatureStrategy
|
||||||
|
├── TopKStrategy
|
||||||
|
├── TopPStrategy
|
||||||
|
└── SamplingPipeline
|
||||||
|
```
|
||||||
|
|
||||||
|
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||||
|
`sample()` is a convenience shortcut for one-shot usage.
|
||||||
|
|
||||||
|
## Protocol Handlers (Strategy Pattern)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ProtocolHandler: # concrete orchestrator
|
||||||
|
def __init__(self, request, engine, builder): ...
|
||||||
|
async def handle(self):
|
||||||
|
prompt, ctx, stops = builder.prepare(request, engine)
|
||||||
|
agen = engine.generate_async(prompt, ...)
|
||||||
|
if stream: self._handle_stream(agen, ctx, stops)
|
||||||
|
else: return await self._handle_non_stream(agen, ctx, stops)
|
||||||
|
```
|
||||||
|
|
||||||
|
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||||
|
|
||||||
|
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
||||||
|
|
||||||
|
Adding a protocol = one builder file, no handler subclassing needed.
|
||||||
|
|
||||||
|
## Engine & GenerateResult
|
||||||
|
|
||||||
|
```
|
||||||
|
InferenceEngine
|
||||||
|
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
||||||
|
├── generate_with_request(req) → same
|
||||||
|
├── generate_async(prompt, ...) → AsyncGenerator
|
||||||
|
├── get_stats() → Dict
|
||||||
|
└── shutdown()
|
||||||
|
```
|
||||||
|
|
||||||
|
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
||||||
|
|
||||||
|
## HTTP API
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /v1/chat/completions OpenAI
|
||||||
|
POST /v1/messages Anthropic
|
||||||
|
GET /health {"status":"ok","model_loaded":true}
|
||||||
|
GET /stats scheduler statistics
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAI
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Response:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-abc123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1717000000,
|
||||||
|
"model": "astrai",
|
||||||
|
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
||||||
|
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
|
||||||
|
|
||||||
|
### Anthropic
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/v1/messages \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
||||||
|
|
||||||
|
### GenerationRequest Parameters
|
||||||
|
|
||||||
|
| Param | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||||
|
| `top_k` | int | 50 | Top-k count |
|
||||||
|
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||||
|
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
|
||||||
|
| `max_tokens` | Optional[int] | None | Max generation length |
|
||||||
|
| `stream` | bool | False | Stream output |
|
||||||
|
|
||||||
|
## Engine API
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Non-streaming
|
||||||
|
engine.generate("Hello", stream=False) # -> str
|
||||||
|
engine.generate(["A", "B"], stream=False) # -> List[str]
|
||||||
|
|
||||||
|
# Streaming
|
||||||
|
engine.generate("Hello", stream=True) # -> Generator[str]
|
||||||
|
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||||
|
|
||||||
|
# Async
|
||||||
|
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
|
||||||
|
print(token)
|
||||||
|
```
|
||||||
|
|
||||||
|
> Document Update Time: 2026-05-30
|
||||||
|
|
@ -1,334 +0,0 @@
|
||||||
## Model Introduction
|
|
||||||
|
|
||||||
### 1. Model Architecture
|
|
||||||
|
|
||||||
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
|
|
||||||
|
|
||||||
The model now uses the **AutoModel** base class for flexible loading and saving:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
|
|
||||||
# Load model from checkpoint
|
|
||||||
model = AutoModel.from_pretrained("path/to/model")
|
|
||||||
|
|
||||||
# Save model to new directory
|
|
||||||
model.save_pretrained("path/to/save")
|
|
||||||
```
|
|
||||||
|
|
||||||
The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types.
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
flowchart TB
|
|
||||||
subgraph Layers["Transformer Layers"]
|
|
||||||
direction TB
|
|
||||||
A[Input Embedding] --> B[Transformer Block\nLayer 1]
|
|
||||||
B --> C[Transformer Block\nLayer ...]
|
|
||||||
C --> D[Transformer Block\nLayer ...]
|
|
||||||
D --> E[RMSNorm]
|
|
||||||
E --> F[Linear]
|
|
||||||
F --> G[SoftMax]
|
|
||||||
end
|
|
||||||
|
|
||||||
subgraph TransformerBlock["Transformer Block"]
|
|
||||||
direction TB
|
|
||||||
H[x] --> I[RMSNorm]
|
|
||||||
I --> J[Linear → Q/K/V]
|
|
||||||
J --> K[Q]
|
|
||||||
J --> L[K]
|
|
||||||
J --> M[V]
|
|
||||||
K --> N[RoPE]
|
|
||||||
L --> O[RoPE]
|
|
||||||
N --> P["Q @ K^T / sqrt(d)"]
|
|
||||||
O --> P
|
|
||||||
P --> Q[Masked SoftMax]
|
|
||||||
Q --> R[S @ V]
|
|
||||||
M --> R
|
|
||||||
R --> S[Linear]
|
|
||||||
S --> T[+]
|
|
||||||
H --> T
|
|
||||||
T --> U[RMSNorm]
|
|
||||||
U --> V["Linear (gate)"]
|
|
||||||
U --> W["Linear (up)"]
|
|
||||||
V --> X[SiLU]
|
|
||||||
X --> Y[×]
|
|
||||||
W --> Y
|
|
||||||
Y --> Z["Linear (down)"]
|
|
||||||
Z --> AA[+]
|
|
||||||
T --> AA
|
|
||||||
AA --> BB[x']
|
|
||||||
end
|
|
||||||
|
|
||||||
classDef main fill:#e6f3ff,stroke:#0066cc;
|
|
||||||
classDef block fill:#fff2e6,stroke:#cc6600;
|
|
||||||
class Layers main;
|
|
||||||
class TransformerBlock block;
|
|
||||||
```
|
|
||||||
|
|
||||||
What is an autoregressive model? After splitting a sentence into tokens, the model predicts the probability distribution of the next token. This means the model calculates the probability of the next possible token and its corresponding probability based on the given context (the sequence of tokens that have already appeared).
|
|
||||||
|
|
||||||
#### 1. Autoregression
|
|
||||||
|
|
||||||
In autoregressive modeling, when a sentence is tokenized into a sequence of tokens, the model learns to predict what comes next. Given a sequence of tokens as input, the model calculates a probability distribution over all possible next tokens. This distribution tells us how likely each potential next token is, given the current context.
|
|
||||||
|
|
||||||
For instance, if the input sequence contains tokens representing a question, the model might predict that certain response tokens have higher probabilities than others. The sampling process then selects one token from this distribution—controlled by parameters like top_k, top_p, and temperature—to serve as the next token in the sequence.
|
|
||||||
|
|
||||||
Once a token is selected, it is appended to the input sequence, and the model repeats this process. The updated sequence is then fed back into the model to predict the next token. This iterative process continues until either a special end-of-sequence token is generated, or the maximum sequence length is reached. These control tokens are essential because without them, the model would continue generating tokens indefinitely, eventually exhausting available memory.
|
|
||||||
|
|
||||||
#### 2. Causal Mask
|
|
||||||
|
|
||||||
Transformers use attention mechanism. The input shape is generally [bsz, seq_len], and the output is [bsz, seq_len, n_dim]. To predict the next token, the model's input and output must be offset by one position. The target predicted by the model must be offset by one position, and during training we also use the offset-by-one method:
|
|
||||||
|
|
||||||
```
|
|
||||||
sequence : [[1, 2, 3, 4, 5, 6]]
|
|
||||||
input_ids: [[1, 2, 3, 4, 5]]
|
|
||||||
target_ids: [[2, 3, 4, 5, 6]]
|
|
||||||
```
|
|
||||||
|
|
||||||
The attention score calculation formula is:
|
|
||||||
|
|
||||||
$$ s_{ij} = softmax(\frac{q_i^Tk_j}{\sqrt{d_k}}) $$
|
|
||||||
$$ s_{ij} := s_{ij} + mask_{ij} $$
|
|
||||||
|
|
||||||
Here, the attention score represents the degree to which the model attends to the similarity between two tokens.
|
|
||||||
|
|
||||||
For decoder-only structure models, to prevent the model from "stealing" information from future positions, a mask needs to be added during attention calculation. We need to apply a mask before attention score calculation. This mask is typically a lower triangular matrix, and for a sequence of length n, its shape is [n, n]. Below is an example of how to create such a causal mask matrix for a sequence of length 5:
|
|
||||||
|
|
||||||
```
|
|
||||||
[[0, -inf, -inf, -inf, -inf],
|
|
||||||
[0, 0, -inf, -inf, -inf],
|
|
||||||
[0, 0, 0, -inf, -inf],
|
|
||||||
[0, 0, 0, 0, -inf],
|
|
||||||
[0, 0, 0, 0, 0]]
|
|
||||||
```
|
|
||||||
|
|
||||||
In this matrix, 0 represents positions that can be attended to, while -inf represents positions that should be masked (i.e., should not be attended to). Because this matrix ensures that after the softmax, the parts of the attention scores where $j > i$ change from `inf` to 0, meaning the model cannot see future information.
|
|
||||||
|
|
||||||
#### 3. Rotary Position Embedding
|
|
||||||
|
|
||||||
Rotary Position Embedding (RoPE) is a position encoding method designed to solve the problem of lacking direct modeling of sequence position information in Transformer models. Unlike traditional position encodings (such as sine and cosine function position encodings), RoPE embeds position information directly into the Query (Q) and Key (K) vectors, allowing the model to more naturally handle relative position relationships in sequences.
|
|
||||||
|
|
||||||
$$ q_i = R_i W_q x_i $$
|
|
||||||
$$ k_j = R_j W_k x_j $$
|
|
||||||
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
|
|
||||||
|
|
||||||
The $R_{i-j}$ controls the attenuation of attention for different tokens at different relative distances. When the absolute value of $i - j$ is larger, the degree of attenuation is stronger. This approach allows the model to learn relative position relationships, enabling the model to scale and adapt to longer sequences.
|
|
||||||
|
|
||||||
## KV Cache Implementation
|
|
||||||
|
|
||||||
According to the attention calculation formula:
|
|
||||||
|
|
||||||
$$
|
|
||||||
\begin{align*}
|
|
||||||
o_i &= \sum_j s_{ij} v_{j} \newline
|
|
||||||
s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
|
|
||||||
\end{align*}
|
|
||||||
$$
|
|
||||||
|
|
||||||
Since the model is an autoregressive model, we only need to calculate for the last part of the sequence, meaning the index $i$ is fixed as the last element of the sequence, and we compute $o_{n}$:
|
|
||||||
|
|
||||||
$$
|
|
||||||
\begin{align*}
|
|
||||||
o_n &= \sum_j s_{j}v_{j} \newline
|
|
||||||
s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
|
|
||||||
\end{align*}
|
|
||||||
$$
|
|
||||||
|
|
||||||
If we expand the expression:
|
|
||||||
|
|
||||||
$$
|
|
||||||
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
|
|
||||||
$$
|
|
||||||
|
|
||||||
In the above expression, only k and v have length indices, while $q$ does not. Therefore, during the calculation process, the input of $q$ is fixed as the last token from the previous input, while $k$ and $v$ need to be cached for parts of different lengths. Also, when caching, note that position encoding calculation should be performed before KV cache computation, otherwise there will be position encoding calculation errors.
|
|
||||||
|
|
||||||
### 4. AutoModel Loading
|
|
||||||
|
|
||||||
The project now uses the **AutoModel** base class for flexible model loading and saving:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
|
|
||||||
# Load model from checkpoint
|
|
||||||
model = AutoModel.from_pretrained("path/to/model")
|
|
||||||
|
|
||||||
# Save model to new directory
|
|
||||||
model.save_pretrained("path/to/save")
|
|
||||||
```
|
|
||||||
|
|
||||||
The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types. The `from_pretrained` method automatically loads the `config.json` to determine the model type and uses safetensors format for weights.
|
|
||||||
|
|
||||||
### 5. Continuous Batching Inference
|
|
||||||
|
|
||||||
The inference engine supports **continuous batching** for efficient batch processing:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.inference import InferenceEngine, GenerationRequest
|
|
||||||
|
|
||||||
# Create inference engine with continuous batching
|
|
||||||
engine = InferenceEngine(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use GenerationRequest with messages format
|
|
||||||
request = GenerationRequest(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
],
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50,
|
|
||||||
max_tokens=None,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate with streaming
|
|
||||||
for token in engine.generate_with_request(request):
|
|
||||||
print(token, end="", flush=True)
|
|
||||||
```
|
|
||||||
|
|
||||||
The continuous batching feature allows dynamic batch composition where new requests can join at any time and completed requests are released immediately.
|
|
||||||
|
|
||||||
## HTTP API Usage
|
|
||||||
|
|
||||||
The inference server provides HTTP endpoints for remote inference. Start the server first:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m scripts.tools.server --port 8000
|
|
||||||
```
|
|
||||||
|
|
||||||
### OpenAI-Compatible Endpoint
|
|
||||||
|
|
||||||
The server provides an OpenAI-compatible chat completion endpoint at `/v1/chat/completions`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Hello, how are you?"}
|
|
||||||
],
|
|
||||||
"temperature": 0.8,
|
|
||||||
"max_tokens": 2048,
|
|
||||||
"stream": false
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
**Request Parameters:**
|
|
||||||
| Parameter | Type | Default | Description |
|
|
||||||
|-----------|------|---------|-------------|
|
|
||||||
| `messages` | List[dict] | Required | Chat messages with role and content |
|
|
||||||
| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) |
|
|
||||||
| `top_p` | float | 1.0 | Nucleus sampling threshold |
|
|
||||||
| `top_k` | int | 50 | Top-k sampling parameter |
|
|
||||||
| `max_tokens` | int | 1024 | Maximum tokens to generate |
|
|
||||||
| `stream` | bool | false | Enable streaming response |
|
|
||||||
|
|
||||||
**Response (non-streaming):**
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"id": "chatcmpl-1234567890",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "astrai",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": "Hello! I'm doing well..."},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 20,
|
|
||||||
"completion_tokens": 15,
|
|
||||||
"total_tokens": 35
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Streaming Response
|
|
||||||
|
|
||||||
Enable streaming for real-time token-by-token output:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"messages": [{"role": "user", "content": "Write a story"}],
|
|
||||||
"stream": true,
|
|
||||||
"max_tokens": 500
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
|
|
||||||
|
|
||||||
### Anthropic-Compatible Endpoint
|
|
||||||
|
|
||||||
The server also provides an Anthropic-compatible endpoint at `/v1/messages`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl -X POST http://localhost:8000/v1/messages \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "astrai",
|
|
||||||
"system": "You are a helpful assistant.",
|
|
||||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
|
||||||
"max_tokens": 2048
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
Response:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"id": "msg_abc123...",
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": "astrai",
|
|
||||||
"content": [{"type": "text", "text": "Hello! I am doing well..."}],
|
|
||||||
"stop_reason": "end_turn",
|
|
||||||
"stop_sequence": null,
|
|
||||||
"usage": {"input_tokens": 20, "output_tokens": 15}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Streaming:
|
|
||||||
```bash
|
|
||||||
curl -X POST http://localhost:8000/v1/messages \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "astrai",
|
|
||||||
"system": "You are a helpful assistant.",
|
|
||||||
"messages": [{"role": "user", "content": "Write a short poem"}],
|
|
||||||
"max_tokens": 500,
|
|
||||||
"stream": true
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
Supports `stop_sequences` for early termination:
|
|
||||||
```bash
|
|
||||||
curl -X POST http://localhost:8000/v1/messages \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "astrai",
|
|
||||||
"messages": [{"role": "user", "content": "Write a story"}],
|
|
||||||
"max_tokens": 500,
|
|
||||||
"stop_sequences": ["The end", "THE END"]
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
### Health Check
|
|
||||||
|
|
||||||
Monitor server and model status:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl http://localhost:8000/health
|
|
||||||
# {"status": "ok", "model_loaded": true}
|
|
||||||
|
|
||||||
curl http://localhost:8000/stats
|
|
||||||
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
|
|
||||||
```
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-14
|
|
||||||
|
|
@ -10,14 +10,14 @@
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
| `--param_path` | Model parameters or checkpoint path | required |
|
| `--param_path` | Model parameters or checkpoint path | required |
|
||||||
| `--n_epoch` | Total training epochs | 1 |
|
| `--n_epoch` | Total training epochs | 1 |
|
||||||
| `--batch_size` | Batch size | 1 |
|
| `--batch_per_device` | Batch size per device | 1 |
|
||||||
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||||
|
|
||||||
### Learning Rate Scheduling
|
### Learning Rate Scheduling
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--warmup_steps` | Warmup steps | 1000 |
|
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
|
||||||
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||||
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||||
|
|
||||||
|
|
@ -53,14 +53,16 @@
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||||
|
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
||||||
| `--device_type` | Device type | cuda |
|
| `--device_type` | Device type | cuda |
|
||||||
|
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
||||||
|
|
||||||
### Strategy-specific
|
### Strategy-specific
|
||||||
|
|
||||||
| Parameter | Description | Default | Used by |
|
| Parameter | Description | Default | Used by |
|
||||||
|-----------|-------------|---------|---------|
|
|-----------|-------------|---------|---------|
|
||||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
|
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
|
||||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||||
|
|
@ -69,90 +71,30 @@
|
||||||
### Usage Example
|
### Usage Example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/train.py \
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
--train_type seq \
|
|
||||||
--data_root_path /path/to/dataset \
|
nohup python scripts/tools/train.py \
|
||||||
--param_path /path/to/model \
|
--nprocs=4 \
|
||||||
--n_epoch 3 \
|
--parallel_mode=ddp \
|
||||||
--batch_size 4 \
|
--train_type=seq \
|
||||||
--accumulation_steps 8 \
|
--data_root_path=/path/to/dataset \
|
||||||
--max_lr 3e-4 \
|
--param_path=/path/to/model \
|
||||||
--warmup_steps 2000 \
|
--batch_per_device=4 \
|
||||||
--max_grad_norm 1.0 \
|
--grad_accum_steps=8 \
|
||||||
--ckpt_interval 5000 \
|
--warmup_ratio=0.05 \
|
||||||
--ckpt_dir ./checkpoints \
|
--max_lr=1e-4 \
|
||||||
--num_workers 4 \
|
--max_grad_norm=1.0 \
|
||||||
--nprocs 1 \
|
--adamw_beta1=0.9 \
|
||||||
--device_type cuda
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=0.01 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.05 \
|
||||||
|
> out.log 2> err.log &
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Generation Parameters
|
> Document Update Time: 2026-05-24
|
||||||
|
|
||||||
### GenerationRequest Parameters
|
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
|
||||||
|-----------|-------------|---------------|
|
|
||||||
| `messages` | List of message dictionaries (role, content) | required |
|
|
||||||
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
|
||||||
| `top_p` | Nucleus sampling threshold | 1.0 |
|
|
||||||
| `top_k` | Top-k sampling count | 50 |
|
|
||||||
| `max_tokens` | Maximum generation length | None (unlimited) |
|
|
||||||
| `stream` | Whether to stream output | False |
|
|
||||||
|
|
||||||
### Usage Example
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
from astrai.inference import InferenceEngine, GenerationRequest
|
|
||||||
|
|
||||||
# Load model using AutoModel
|
|
||||||
model = AutoModel.from_pretrained("your_model_dir")
|
|
||||||
|
|
||||||
# Load tokenizer
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
|
|
||||||
|
|
||||||
# Create engine with separate model and tokenizer
|
|
||||||
engine = InferenceEngine(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build request with messages format
|
|
||||||
request = GenerationRequest(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
],
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50,
|
|
||||||
max_tokens=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate (streaming)
|
|
||||||
for token in engine.generate_with_request(request):
|
|
||||||
print(token, end="", flush=True)
|
|
||||||
|
|
||||||
# Or use simple generate interface
|
|
||||||
result = engine.generate(
|
|
||||||
prompt="Hello",
|
|
||||||
stream=False,
|
|
||||||
max_tokens=1024,
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Generation Modes
|
|
||||||
|
|
||||||
| Mode | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `stream=True` | Streaming output, yields token by token |
|
|
||||||
| `stream=False` | Non-streaming output, returns complete result |
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-14
|
|
||||||
|
|
@ -0,0 +1,346 @@
|
||||||
|
# Preprocessing Pipeline
|
||||||
|
|
||||||
|
Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output).
|
||||||
|
|
||||||
|
## Philosophy
|
||||||
|
|
||||||
|
| Component | Responsibility |
|
||||||
|
|-----------|---------------|
|
||||||
|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
||||||
|
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
||||||
|
|
||||||
|
A single config file captures the entire pipeline, reusable and version-controllable.
|
||||||
|
|
||||||
|
## Config Structure
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {}, // sections (single) or sources (multi)
|
||||||
|
"mask": {}, // role → "train" | "mask"
|
||||||
|
"mask_default": "mask",
|
||||||
|
"preprocessing": {},
|
||||||
|
"output": {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Section Fields
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `field` | str | -- | JSONL key to read |
|
||||||
|
| `action` | str | -- | `"train"` / `"mask"` / `"$role"` |
|
||||||
|
| `template` | bool | `false` | Apply `chat_template` per message |
|
||||||
|
| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode |
|
||||||
|
|
||||||
|
### Source Fields (multi-output mode)
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `sections` | list[dict] | -- | Same as single-output section list |
|
||||||
|
| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element |
|
||||||
|
| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### SFT Chat
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "messages", "action": "$role", "template": true}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"mask": {
|
||||||
|
"system": "mask",
|
||||||
|
"user": "mask",
|
||||||
|
"assistant": "train"
|
||||||
|
},
|
||||||
|
"mask_default": "mask",
|
||||||
|
"preprocessing": {
|
||||||
|
"max_seq_len": 2048
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"storage_format": "bin",
|
||||||
|
"dtype": {"loss_mask": "bool"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Output keys: `sequence` (int32), `loss_mask` (bool)
|
||||||
|
|
||||||
|
### SFT Instruction
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
||||||
|
{"field": "response", "action": "train"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"mask_default": "mask",
|
||||||
|
"preprocessing": {
|
||||||
|
"max_seq_len": 2048
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Output keys: `sequence`, `loss_mask`
|
||||||
|
|
||||||
|
### Pretrain
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"text": "Artificial Intelligence is a field of computer science..."}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "text", "action": "train"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"preprocessing": {
|
||||||
|
"max_seq_len": 8192,
|
||||||
|
"min_chars": 100
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Output keys: `sequence` (no `loss_mask` — all tokens trained)
|
||||||
|
|
||||||
|
### DPO
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"sources": {
|
||||||
|
"chosen": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "chosen", "action": "$role", "template": true}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"rejected": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "rejected", "action": "$role", "template": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"mask": {
|
||||||
|
"user": "mask",
|
||||||
|
"assistant": "train"
|
||||||
|
},
|
||||||
|
"mask_default": "mask"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask`
|
||||||
|
|
||||||
|
### GRPO
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"sources": {
|
||||||
|
"prompts": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "prompt", "action": "mask", "template": true}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "responses", "action": "train"}
|
||||||
|
],
|
||||||
|
"list_field": true,
|
||||||
|
"mask_key": "masks"
|
||||||
|
},
|
||||||
|
"rewards": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "rewards", "action": "value"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"mask": {
|
||||||
|
"user": "mask",
|
||||||
|
"assistant": "train"
|
||||||
|
},
|
||||||
|
"mask_default": "mask"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Output keys: `prompts`, `responses`, `masks`, `rewards` (float32)
|
||||||
|
|
||||||
|
- `action: "value"` — extract raw values from JSONL without tokenisation
|
||||||
|
- `list_field: true` — tokenise each list element independently, then concatenate
|
||||||
|
- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration Reference
|
||||||
|
|
||||||
|
### `input`
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `sections` | list[dict] or null | `null` | Section specs for single-output mode |
|
||||||
|
| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) |
|
||||||
|
|
||||||
|
When `sources` is set, `sections` is ignored.
|
||||||
|
|
||||||
|
### `mask`
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `mask` | dict | `{}` | `{role: "train" \| "mask"}` |
|
||||||
|
| `mask_default` | str | `"mask"` | Default action for unlisted roles |
|
||||||
|
|
||||||
|
### `preprocessing`
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `max_seq_len` | int | `2048` | Truncate sequences to this length |
|
||||||
|
| `min_chars` | int | `50` | Skip text-mode items shorter than this |
|
||||||
|
| `max_chars` | int | `2000000` | Skip text-mode items longer than this |
|
||||||
|
| `max_items` | int or null | `null` | Stop after N documents |
|
||||||
|
|
||||||
|
### `output`
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `domain_key` | str or null | `null` | JSONL key for domain grouping |
|
||||||
|
| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` |
|
||||||
|
| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens |
|
||||||
|
| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Mask Algorithm
|
||||||
|
|
||||||
|
### Template mode (`template: true`)
|
||||||
|
|
||||||
|
For each message in the field's array:
|
||||||
|
|
||||||
|
1. Prepend BOS token (masked)
|
||||||
|
2. Render through `chat_template` for that single message
|
||||||
|
3. Encode rendered text
|
||||||
|
4. Apply mask rule for the message's role
|
||||||
|
|
||||||
|
### Non-template mode
|
||||||
|
|
||||||
|
Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`.
|
||||||
|
|
||||||
|
### Text config detection
|
||||||
|
|
||||||
|
When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Output Layout
|
||||||
|
|
||||||
|
### Single-Shard (`bin`)
|
||||||
|
|
||||||
|
```
|
||||||
|
output/
|
||||||
|
__default__/
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
wiki/
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-Shard (`bin`)
|
||||||
|
|
||||||
|
When `max_tokens_per_shard` is exceeded:
|
||||||
|
|
||||||
|
```
|
||||||
|
output/
|
||||||
|
__default__/
|
||||||
|
shard_0000/
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
shard_0001/
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CLI
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# SFT
|
||||||
|
python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json
|
||||||
|
|
||||||
|
# DPO
|
||||||
|
python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params
|
||||||
|
|
||||||
|
# GRPO
|
||||||
|
python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Python API
|
||||||
|
|
||||||
|
```python
|
||||||
|
from astrai.preprocessing.pipeline import Pipeline
|
||||||
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
|
|
||||||
|
config = PipelineConfig.from_json("sft.json")
|
||||||
|
Pipeline(
|
||||||
|
config,
|
||||||
|
["data_part1.jsonl", "data_part2.jsonl"],
|
||||||
|
output_dir="output/",
|
||||||
|
tokenizer_path="params",
|
||||||
|
).run()
|
||||||
|
```
|
||||||
|
|
||||||
|
> Document Update Time: 2026-06-03
|
||||||
|
|
@ -0,0 +1,201 @@
|
||||||
|
# Training
|
||||||
|
|
||||||
|
### Autoregression
|
||||||
|
|
||||||
|
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
||||||
|
|
||||||
|
### Causal Mask
|
||||||
|
|
||||||
|
```
|
||||||
|
sequence : [[1, 2, 3, 4, 5, 6]]
|
||||||
|
input_ids: [[1, 2, 3, 4, 5]]
|
||||||
|
target_ids: [[2, 3, 4, 5, 6]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Lower-triangular mask prevents attending to future positions:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[0, -inf, -inf, -inf, -inf],
|
||||||
|
[0, 0, -inf, -inf, -inf],
|
||||||
|
[0, 0, 0, -inf, -inf],
|
||||||
|
[0, 0, 0, 0, -inf],
|
||||||
|
[0, 0, 0, 0, 0]]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rotary Position Embedding (RoPE)
|
||||||
|
|
||||||
|
RoPE embeds position into Q/K vectors via complex rotation:
|
||||||
|
|
||||||
|
$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_{i-j} W_k x_j $$
|
||||||
|
|
||||||
|
The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` multiplies Q/K as complex numbers.
|
||||||
|
|
||||||
|
## Training Loop
|
||||||
|
|
||||||
|
Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches.
|
||||||
|
|
||||||
|
```
|
||||||
|
on_train_begin
|
||||||
|
model.train()
|
||||||
|
on_epoch_begin
|
||||||
|
for batch in dataloader:
|
||||||
|
on_batch_begin
|
||||||
|
with executor.accumulate(model):
|
||||||
|
loss = strategy.compute_loss(batch)
|
||||||
|
context.loss = loss.item()
|
||||||
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
|
executor.backward(stand_loss)
|
||||||
|
context.iteration += 1
|
||||||
|
on_batch_end
|
||||||
|
|
||||||
|
if executor.sync_gradients:
|
||||||
|
on_optimizer_step
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
if scheduler:
|
||||||
|
scheduler.step()
|
||||||
|
on_epoch_end
|
||||||
|
on_train_end
|
||||||
|
```
|
||||||
|
|
||||||
|
### Callback Lifecycle
|
||||||
|
|
||||||
|
| Hook | Fires | Default callback |
|
||||||
|
|------|-------|-----------------|
|
||||||
|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||||
|
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
|
||||||
|
| `on_batch_begin` | Every batch | — |
|
||||||
|
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
||||||
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
|
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
|
||||||
|
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
|
||||||
|
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
|
||||||
|
|
||||||
|
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||||
|
|
||||||
|
## Strategies
|
||||||
|
|
||||||
|
### SEQ (Pre-training)
|
||||||
|
|
||||||
|
Next-token cross-entropy with optional label smoothing:
|
||||||
|
|
||||||
|
$$
|
||||||
|
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
||||||
|
$$
|
||||||
|
|
||||||
|
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
|
||||||
|
|
||||||
|
### SFT (Supervised Fine-Tuning)
|
||||||
|
|
||||||
|
Masked cross-entropy (`ignore_index=-100`) over response tokens:
|
||||||
|
|
||||||
|
$$
|
||||||
|
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
||||||
|
$$
|
||||||
|
|
||||||
|
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
|
||||||
|
|
||||||
|
### DPO (Direct Preference Optimization)
|
||||||
|
|
||||||
|
Frozen reference model, preference margin via log-ratio:
|
||||||
|
|
||||||
|
$$
|
||||||
|
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
||||||
|
$$
|
||||||
|
|
||||||
|
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
||||||
|
|
||||||
|
### GRPO (Group Relative Policy Optimization)
|
||||||
|
|
||||||
|
On-policy PPO with group-normalized advantages:
|
||||||
|
|
||||||
|
$$
|
||||||
|
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
|
||||||
|
$$
|
||||||
|
|
||||||
|
$$
|
||||||
|
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
||||||
|
$$
|
||||||
|
|
||||||
|
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
|
||||||
|
|
||||||
|
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
|
|
||||||
|
## LR Schedulers
|
||||||
|
|
||||||
|
| Type | Class | Description |
|
||||||
|
|------|-------|-------------|
|
||||||
|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
||||||
|
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
||||||
|
|
||||||
|
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
||||||
|
|
||||||
|
## Gradient Checkpointing
|
||||||
|
|
||||||
|
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
|
||||||
|
```
|
||||||
|
|
||||||
|
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
|
||||||
|
|
||||||
|
## Checkpoint
|
||||||
|
|
||||||
|
```
|
||||||
|
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
||||||
|
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
||||||
|
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
|
||||||
|
```
|
||||||
|
|
||||||
|
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||||
|
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
|
||||||
|
|
||||||
|
## TrainContextBuilder (Builder Pattern)
|
||||||
|
|
||||||
|
```python
|
||||||
|
context = (
|
||||||
|
TrainContextBuilder(config)
|
||||||
|
.with_resume_dir(resume_dir)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
||||||
|
```
|
||||||
|
|
||||||
|
- Loads checkpoint weights if provided
|
||||||
|
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
||||||
|
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
||||||
|
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||||
|
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
|
||||||
|
|
||||||
|
## Training CLI
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
|
nohup python scripts/tools/train.py \
|
||||||
|
--nprocs=4 \
|
||||||
|
--parallel_mode=ddp \
|
||||||
|
--train_type=seq \
|
||||||
|
--data_root_path=/path/to/dataset \
|
||||||
|
--param_path=/path/to/model \
|
||||||
|
--batch_per_device=4 \
|
||||||
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.9 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=0.01 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.05 \
|
||||||
|
> out.log 2> err.log &
|
||||||
|
```
|
||||||
|
|
||||||
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
|
> Document Update Time: 2026-05-30
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
__version__ = "1.3.5"
|
__version__ = "1.3.7"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
ModelConfig,
|
AutoRegressiveLMConfig,
|
||||||
|
EncoderConfig,
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
|
|
@ -11,13 +12,14 @@ from astrai.inference import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
)
|
)
|
||||||
from astrai.model import AutoModel, Transformer
|
from astrai.model import AutoModel, AutoRegressiveLM
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Transformer",
|
"AutoRegressiveLM",
|
||||||
"ModelConfig",
|
"AutoRegressiveLMConfig",
|
||||||
|
"EncoderConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
"AutoTokenizer",
|
"AutoTokenizer",
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,25 @@
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import (
|
||||||
|
AutoRegressiveLMConfig,
|
||||||
|
BaseModelConfig,
|
||||||
|
ConfigFactory,
|
||||||
|
EncoderConfig,
|
||||||
|
)
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
OutputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
ProcessingConfig,
|
||||||
|
)
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Model configuration
|
"BaseModelConfig",
|
||||||
"ModelConfig",
|
"AutoRegressiveLMConfig",
|
||||||
|
"EncoderConfig",
|
||||||
|
"ConfigFactory",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
"InputConfig",
|
||||||
|
"OutputConfig",
|
||||||
|
"PipelineConfig",
|
||||||
|
"ProcessingConfig",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
import json
|
||||||
|
from dataclasses import MISSING, dataclass, fields
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Self, Union, get_type_hints
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseConfig:
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
d = {}
|
||||||
|
for fld in fields(self):
|
||||||
|
v = getattr(self, fld.name)
|
||||||
|
if isinstance(v, (str, int, float, bool)):
|
||||||
|
d[fld.name] = v
|
||||||
|
elif v is None:
|
||||||
|
d[fld.name] = None
|
||||||
|
elif isinstance(v, (dict, list, tuple)):
|
||||||
|
try:
|
||||||
|
val = list(v) if isinstance(v, tuple) else v
|
||||||
|
json.dumps(val)
|
||||||
|
d[fld.name] = val
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
elif isinstance(v, BaseConfig):
|
||||||
|
d[fld.name] = v.to_dict()
|
||||||
|
elif hasattr(v, "__dataclass_fields__"):
|
||||||
|
sub = {}
|
||||||
|
for f in fields(v):
|
||||||
|
a = getattr(v, f.name)
|
||||||
|
sub[f.name] = list(a) if isinstance(a, tuple) else a
|
||||||
|
d[fld.name] = sub
|
||||||
|
return d
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: Dict[str, Any]) -> Self:
|
||||||
|
hints = get_type_hints(cls)
|
||||||
|
inst = cls.__new__(cls)
|
||||||
|
for fld in fields(cls):
|
||||||
|
if fld.name in d:
|
||||||
|
v = d[fld.name]
|
||||||
|
target = cls._unwrap_optional(hints.get(fld.name))
|
||||||
|
if target is not None:
|
||||||
|
try:
|
||||||
|
v = cls._coerce(v, target)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
object.__setattr__(inst, fld.name, v)
|
||||||
|
elif fld.default is not MISSING:
|
||||||
|
object.__setattr__(inst, fld.name, fld.default)
|
||||||
|
elif fld.default_factory is not MISSING:
|
||||||
|
object.__setattr__(inst, fld.name, fld.default_factory())
|
||||||
|
else:
|
||||||
|
object.__setattr__(inst, fld.name, None)
|
||||||
|
return inst
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unwrap_optional(tp) -> Optional[type]:
|
||||||
|
if tp is None:
|
||||||
|
return None
|
||||||
|
origin = getattr(tp, "__origin__", None)
|
||||||
|
if origin is not None:
|
||||||
|
args = getattr(tp, "__args__", ())
|
||||||
|
non_none = [a for a in args if a is not type(None)]
|
||||||
|
return non_none[0] if non_none else None
|
||||||
|
return tp
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce(value: Any, target_type: type) -> Any:
|
||||||
|
if target_type is bool and isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if (
|
||||||
|
target_type is int
|
||||||
|
and isinstance(value, (int, float))
|
||||||
|
and not isinstance(value, bool)
|
||||||
|
):
|
||||||
|
return int(value)
|
||||||
|
if (
|
||||||
|
target_type is float
|
||||||
|
and isinstance(value, (int, float))
|
||||||
|
and not isinstance(value, bool)
|
||||||
|
):
|
||||||
|
return float(value)
|
||||||
|
if target_type is str and isinstance(value, str):
|
||||||
|
return value
|
||||||
|
if isinstance(value, target_type):
|
||||||
|
return value
|
||||||
|
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
|
||||||
|
return target_type.from_dict(value)
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, path: Union[str, Path]) -> Self:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
return cls.from_dict(json.load(f))
|
||||||
|
|
||||||
|
def to_json(self, path: Union[str, Path]):
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||||
|
|
@ -1,42 +1,92 @@
|
||||||
import json
|
import json
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Self
|
from typing import Any, Dict, Optional, Self
|
||||||
|
|
||||||
|
from astrai.config.base import BaseConfig
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigFactory(BaseFactory[BaseConfig]):
|
||||||
|
"""Factory that dispatches config classes by ``model_type``."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
|
||||||
|
model_type = raw.get("model_type") or "autoregressive_lm"
|
||||||
|
config_cls = cls.get_component_class(model_type)
|
||||||
|
return config_cls.from_dict(raw)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig:
|
class BaseModelConfig(BaseConfig):
|
||||||
# basic config
|
"""Base config with ``model_type`` dispatch and file I/O."""
|
||||||
|
|
||||||
model_type: Optional[str] = None
|
model_type: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(cls, config_path: str) -> Self:
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
raw: Dict[str, Any] = json.load(f)
|
||||||
|
return cls.from_dict(raw)
|
||||||
|
|
||||||
|
def to_file(self, config_path: str):
|
||||||
|
d = self.to_dict()
|
||||||
|
config_dict = {k: v for k, v in d.items() if v is not None}
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ConfigFactory.register("autoregressive_lm")
|
||||||
|
class AutoRegressiveLMConfig(BaseModelConfig):
|
||||||
|
"""Configuration for autoregressive language model."""
|
||||||
|
|
||||||
vocab_size: Optional[int] = None
|
vocab_size: Optional[int] = None
|
||||||
dim: Optional[int] = None
|
dim: Optional[int] = None
|
||||||
|
|
||||||
n_layers: Optional[int] = None
|
n_layers: Optional[int] = None
|
||||||
norm_eps: Optional[float] = None
|
norm_eps: Optional[float] = None
|
||||||
dim_ffn: Optional[int] = None
|
dim_ffn: Optional[int] = None
|
||||||
tie_weight: Optional[bool] = None
|
tie_weight: Optional[bool] = None
|
||||||
|
|
||||||
# RoPE
|
|
||||||
max_len: Optional[int] = None
|
max_len: Optional[int] = None
|
||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
|
rope_scaling: Optional[dict] = None
|
||||||
|
|
||||||
# GQA
|
attn_type: str = "gqa"
|
||||||
n_heads: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: Optional[int] = None
|
||||||
use_qk_norm: Optional[bool] = None
|
use_qk_norm: Optional[bool] = None
|
||||||
use_gated_attention: Optional[bool] = None
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
def load(self, config_path: str) -> Self:
|
kv_lora_rank: Optional[int] = None
|
||||||
config = {}
|
qk_nope_head_dim: Optional[int] = None
|
||||||
with open(config_path, "r") as f:
|
qk_rope_head_dim: Optional[int] = None
|
||||||
config.update(json.load(f))
|
|
||||||
|
|
||||||
for key, value in config.items():
|
ffn_type: str = "mlp"
|
||||||
if hasattr(self, key):
|
n_routed_experts: Optional[int] = None
|
||||||
setattr(self, key, value)
|
n_shared_experts: Optional[int] = None
|
||||||
|
n_activated_experts: Optional[int] = None
|
||||||
|
topk_method: Optional[str] = None
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def save(self, config_path: str):
|
@dataclass
|
||||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
@ConfigFactory.register("embedding")
|
||||||
with open(config_path, "w") as f:
|
class EncoderConfig(BaseModelConfig):
|
||||||
json.dump(config_dict, f, indent=4)
|
"""Configuration for embedding encoder model."""
|
||||||
|
|
||||||
|
vocab_size: Optional[int] = None
|
||||||
|
dim: Optional[int] = None
|
||||||
|
n_layers: Optional[int] = None
|
||||||
|
norm_eps: Optional[float] = None
|
||||||
|
dim_ffn: Optional[int] = None
|
||||||
|
|
||||||
|
max_len: Optional[int] = None
|
||||||
|
rope_theta: Optional[float] = None
|
||||||
|
rope_scaling: Optional[dict] = None
|
||||||
|
|
||||||
|
n_heads: Optional[int] = None
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
use_qk_norm: Optional[bool] = None
|
||||||
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
pooling_type: Optional[str] = None
|
||||||
|
normalize_embeddings: Optional[bool] = None
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""Pipeline configuration for JSONL preprocessing.
|
||||||
|
|
||||||
|
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
|
||||||
|
modes, both driven declaratively through ``input.sections`` or
|
||||||
|
``input.sources``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from astrai.config.base import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InputConfig(BaseConfig):
|
||||||
|
"""Declarative input mapping.
|
||||||
|
|
||||||
|
Single-output mode (backward-compatible)::
|
||||||
|
|
||||||
|
{"input": {"sections": [{"field": "messages", ...}]}}
|
||||||
|
|
||||||
|
Multi-output mode (DPO / GRPO)::
|
||||||
|
|
||||||
|
{"input": {"sources": {
|
||||||
|
"chosen": {"sections": [{"field": "chosen", ...}]},
|
||||||
|
"rejected": {"sections": [{"field": "rejected", ...}]},
|
||||||
|
}}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
sections: Optional[List[Dict]] = None
|
||||||
|
sources: Optional[Dict[str, Dict]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessingConfig(BaseConfig):
|
||||||
|
"""Processing configuration.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_seq_len : int
|
||||||
|
Maximum sequence length (default: 2048).
|
||||||
|
min_chars : int
|
||||||
|
Minimum number of characters to keep (default: 50).
|
||||||
|
max_chars : int
|
||||||
|
Maximum number of characters to keep (default: 2_000_000).
|
||||||
|
max_items : Optional[int]
|
||||||
|
Maximum number of items to process (default: None, unlimited).
|
||||||
|
packing_strategy : str
|
||||||
|
How to pack sequences into a contiguous stream.
|
||||||
|
|
||||||
|
- ``"simple"``: sequential concatenation (default, backward compatible).
|
||||||
|
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
|
||||||
|
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
|
||||||
|
max_packed_len : int
|
||||||
|
Maximum length of a packed bin. Sequences longer than this are
|
||||||
|
truncated or split depending on ``packing_strategy`` (default: 8192).
|
||||||
|
truncation_mode : str
|
||||||
|
How to truncate sequences longer than ``max_packed_len``.
|
||||||
|
|
||||||
|
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
|
||||||
|
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_seq_len: int = 2048
|
||||||
|
min_chars: int = 50
|
||||||
|
max_chars: int = 2_000_000
|
||||||
|
max_items: Optional[int] = None
|
||||||
|
packing_strategy: str = "simple"
|
||||||
|
max_packed_len: int = 8192
|
||||||
|
truncation_mode: str = "keep_start"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OutputConfig(BaseConfig):
|
||||||
|
"""Output configuration.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
domain_key : Optional[str]
|
||||||
|
Domain key for the output store (default: None).
|
||||||
|
storage_format : str
|
||||||
|
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
|
||||||
|
max_tokens_per_shard : int
|
||||||
|
Maximum tokens per shard before splitting (default: 100_000_000).
|
||||||
|
dtype : Dict[str, str]
|
||||||
|
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
|
||||||
|
position_ids_mode : Optional[str]
|
||||||
|
How to compute position_ids in packed sequences.
|
||||||
|
|
||||||
|
- ``None`` / ``"none"``: do not generate (backward compatible).
|
||||||
|
- ``"doc_reset"``: reset to 0 at each document boundary.
|
||||||
|
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
||||||
|
"""
|
||||||
|
|
||||||
|
domain_key: Optional[str] = None
|
||||||
|
storage_format: str = "bin"
|
||||||
|
max_tokens_per_shard: int = 100_000_000
|
||||||
|
dtype: Dict[str, str] = field(default_factory=dict)
|
||||||
|
position_ids_mode: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineConfig(BaseConfig):
|
||||||
|
version: int = 1
|
||||||
|
input: InputConfig = field(default_factory=InputConfig)
|
||||||
|
mask: Dict[str, str] = field(default_factory=dict)
|
||||||
|
mask_default: str = "mask"
|
||||||
|
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
||||||
|
output: OutputConfig = field(default_factory=OutputConfig)
|
||||||
|
|
@ -1,32 +1,49 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Callable, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from astrai.config.base import BaseConfig
|
||||||
|
from astrai.model.components.lora import LoRAConfig
|
||||||
|
|
||||||
|
|
||||||
|
def required(**kw):
|
||||||
|
return {"required": True, **kw}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig(BaseConfig):
|
||||||
# basic setting
|
# basic setting
|
||||||
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
model_fn: Callable[[], nn.Module] = field(
|
||||||
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
default=None, metadata=required(help="Model factory for training.")
|
||||||
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
|
)
|
||||||
|
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||||
|
dataset: Dataset = field(
|
||||||
|
default=None, metadata=required(help="Dataset for training.")
|
||||||
|
)
|
||||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||||
default=None, metadata={"help": "Optimizer factory for training."}
|
default=None, metadata=required(help="Optimizer factory for training.")
|
||||||
)
|
)
|
||||||
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||||
default=None, metadata={"help": "Scheduler factory for training."}
|
default=None, metadata=required(help="Scheduler factory for training.")
|
||||||
)
|
)
|
||||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||||
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
|
batch_per_device: int = field(
|
||||||
accumulation_steps: int = field(
|
default=4, metadata={"help": "Batch size per device."}
|
||||||
|
)
|
||||||
|
grad_accum_steps: int = field(
|
||||||
default=1, metadata={"help": "Number of iterations between steps."}
|
default=1, metadata={"help": "Number of iterations between steps."}
|
||||||
)
|
)
|
||||||
max_grad_norm: float = field(
|
max_grad_norm: float = field(
|
||||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||||
)
|
)
|
||||||
|
gradient_checkpointing_modules: list = field(
|
||||||
|
default_factory=list,
|
||||||
|
metadata={"help": "Module types to enable activation checkpointing for."},
|
||||||
|
)
|
||||||
|
|
||||||
# checkpoint setting
|
# checkpoint setting
|
||||||
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||||
|
|
@ -40,6 +57,25 @@ class TrainConfig:
|
||||||
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# lora setting
|
||||||
|
lora: Optional[LoRAConfig] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "LoRA config. None means full fine-tuning."},
|
||||||
|
)
|
||||||
|
|
||||||
|
# metric setting
|
||||||
|
log_dir: str = field(
|
||||||
|
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
||||||
|
)
|
||||||
|
log_interval: int = field(
|
||||||
|
default=100,
|
||||||
|
metadata={"help": "Number of batch iterations between metric logs."},
|
||||||
|
)
|
||||||
|
metrics: List[str] = field(
|
||||||
|
default_factory=lambda: ["loss", "lr"],
|
||||||
|
metadata={"help": "Metrics to record during training."},
|
||||||
|
)
|
||||||
|
|
||||||
# dataloader setting
|
# dataloader setting
|
||||||
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
||||||
num_workers: int = field(
|
num_workers: int = field(
|
||||||
|
|
@ -66,17 +102,37 @@ class TrainConfig:
|
||||||
master_port: str = field(
|
master_port: str = field(
|
||||||
default="29500", metadata={"help": "Master port for distributed training."}
|
default="29500", metadata={"help": "Master port for distributed training."}
|
||||||
)
|
)
|
||||||
parallel_wrapper: Optional[Callable] = field(
|
parallel_mode: str = field(
|
||||||
default=None, metadata={"help": "Parallel function for training."}
|
default="none",
|
||||||
|
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
|
||||||
)
|
)
|
||||||
state_dict_fn: Optional[Callable] = field(
|
start_method: str = field(
|
||||||
default=None, metadata={"help": "Parallel function for state dict saving."}
|
default="spawn",
|
||||||
|
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
|
||||||
)
|
)
|
||||||
|
|
||||||
# others
|
# others
|
||||||
device_type: str = field(
|
device_type: str = field(
|
||||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||||
)
|
)
|
||||||
|
val_dataset: Optional[Dataset] = field(
|
||||||
|
default=None, metadata={"help": "Dataset for validation."}
|
||||||
|
)
|
||||||
|
val_split: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
val_step: int = field(
|
||||||
|
default=1000,
|
||||||
|
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||||
|
)
|
||||||
|
|
||||||
|
executor_kwargs: dict = field(
|
||||||
|
default_factory=dict,
|
||||||
|
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
||||||
|
)
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
default_factory=dict, metadata={"help": "Other arguments."}
|
default_factory=dict, metadata={"help": "Other arguments."}
|
||||||
)
|
)
|
||||||
|
|
@ -85,14 +141,6 @@ class TrainConfig:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
required_fields = [
|
for fld in fields(self):
|
||||||
"model",
|
if fld.metadata.get("required") and getattr(self, fld.name) is None:
|
||||||
"strategy",
|
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
|
||||||
"dataset",
|
|
||||||
"optimizer_fn",
|
|
||||||
"scheduler_fn",
|
|
||||||
]
|
|
||||||
|
|
||||||
for field_name in required_fields:
|
|
||||||
if getattr(self, field_name) is None:
|
|
||||||
raise ValueError(f"{field_name} is required.")
|
|
||||||
|
|
|
||||||
|
|
@ -4,34 +4,28 @@ from astrai.dataset.dataset import (
|
||||||
)
|
)
|
||||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
BaseSegmentFetcher,
|
H5Store,
|
||||||
BaseStorage,
|
MmapStore,
|
||||||
H5Storage,
|
Store,
|
||||||
JSONStorage,
|
StoreFactory,
|
||||||
MultiSegmentFetcher,
|
|
||||||
available_storage_types,
|
|
||||||
create_storage,
|
|
||||||
detect_format,
|
detect_format,
|
||||||
|
load_bin,
|
||||||
load_h5,
|
load_h5,
|
||||||
load_json,
|
save_bin,
|
||||||
save_h5,
|
save_h5,
|
||||||
save_json,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
"BaseSegmentFetcher",
|
"Store",
|
||||||
"MultiSegmentFetcher",
|
"StoreFactory",
|
||||||
"BaseStorage",
|
"H5Store",
|
||||||
"H5Storage",
|
"MmapStore",
|
||||||
"JSONStorage",
|
|
||||||
"create_storage",
|
|
||||||
"detect_format",
|
"detect_format",
|
||||||
"available_storage_types",
|
|
||||||
"save_h5",
|
"save_h5",
|
||||||
"load_h5",
|
"load_h5",
|
||||||
"save_json",
|
"save_bin",
|
||||||
"load_json",
|
"load_bin",
|
||||||
"ResumableDistributedSampler",
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
BaseStorage,
|
Store,
|
||||||
create_storage,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
)
|
)
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
@ -26,33 +26,47 @@ class BaseDataset(Dataset, ABC):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.storage: Optional[BaseStorage] = None
|
self.storage: Optional[Store] = None
|
||||||
|
|
||||||
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
"""Return required storage keys for this dataset type.
|
||||||
|
|
||||||
|
Subclasses should override to specify expected keys.
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _validate_keys(self):
|
||||||
|
if not self.required_keys:
|
||||||
|
return
|
||||||
|
actual_keys = set(self.storage.keys)
|
||||||
|
missing = [k for k in self.required_keys if k not in actual_keys]
|
||||||
|
if missing:
|
||||||
|
raise KeyError(
|
||||||
|
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
|
||||||
|
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
|
||||||
|
f"Missing: {missing}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def load(self, load_path: str, storage_type: Optional[str] = None):
|
||||||
"""Load dataset from the given path.
|
"""Load dataset from the given path.
|
||||||
|
|
||||||
Auto-detects the storage format if not specified.
|
Auto-detects the storage format if not specified.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
load_path: Path to the data directory or file
|
load_path: Path to the data directory or file
|
||||||
storage_type: Force a specific storage type ("h5", "json"),
|
storage_type: Force a specific storage type ("h5", "bin"),
|
||||||
or None for auto-detection
|
or None for auto-detection
|
||||||
tokenizer: Callable str -> List[int], used to tokenize raw text
|
|
||||||
in JSON files. Ignored for HDF5.
|
Raises:
|
||||||
|
KeyError: If the loaded storage is missing required keys.
|
||||||
"""
|
"""
|
||||||
if storage_type is None:
|
if storage_type is None:
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = create_storage(storage_type)
|
self.storage = StoreFactory.create(storage_type)
|
||||||
self.storage.load(load_path, tokenizer=tokenizer)
|
self._load_path = load_path
|
||||||
|
self.storage.load(load_path)
|
||||||
def load_json(self, load_path: str, tokenizer=None):
|
self._validate_keys()
|
||||||
"""Load dataset from JSON files explicitly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
load_path: Path to the JSON data file or directory
|
|
||||||
tokenizer: Optional tokenizer callable for raw text JSON.
|
|
||||||
"""
|
|
||||||
self.load(load_path, storage_type="json", tokenizer=tokenizer)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
|
|
@ -123,7 +137,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, dataset_cls: type) -> None:
|
def _validate_component(cls, dataset_cls: type):
|
||||||
"""Validate that the dataset class inherits from BaseDataset."""
|
"""Validate that the dataset class inherits from BaseDataset."""
|
||||||
if not issubclass(dataset_cls, BaseDataset):
|
if not issubclass(dataset_cls, BaseDataset):
|
||||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||||
|
|
@ -150,7 +164,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: Optional[int] = None,
|
stride: Optional[int] = None,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
tokenizer=None,
|
|
||||||
) -> "BaseDataset":
|
) -> "BaseDataset":
|
||||||
"""Create and load a dataset in one step.
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
|
|
@ -159,8 +172,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
load_path: Path to the data file
|
load_path: Path to the data file
|
||||||
window_size: Window size for data sampling
|
window_size: Window size for data sampling
|
||||||
stride: Stride between consecutive samples (default: same as window_size)
|
stride: Stride between consecutive samples (default: same as window_size)
|
||||||
storage_type: Storage type ("h5", "json") or None for auto-detection
|
storage_type: Storage type ("h5", "bin") or None for auto-detection
|
||||||
tokenizer: Callable str -> List[int] for raw text JSON tokenization
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded dataset instance
|
Loaded dataset instance
|
||||||
|
|
@ -169,7 +181,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
stride = window_size
|
stride = window_size
|
||||||
|
|
||||||
dataset = cls.create(train_type, window_size, stride)
|
dataset = cls.create(train_type, window_size, stride)
|
||||||
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
|
dataset.load(load_path, storage_type=storage_type)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
@ -186,6 +198,10 @@ class SEQDataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["sequence"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
||||||
|
|
||||||
|
|
@ -205,21 +221,27 @@ class SFTDataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["sequence", "loss_mask", "position_ids"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx, "sequence")
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
|
||||||
dtype=torch.long
|
position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
|
||||||
)
|
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
|
||||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
|
||||||
dtype=torch.bool
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
return {
|
||||||
|
"input_ids": x.to(dtype=torch.long),
|
||||||
|
"target_ids": y.to(dtype=torch.long),
|
||||||
|
"position_ids": position_ids.to(dtype=torch.long),
|
||||||
|
"loss_mask": loss_mask.to(dtype=torch.bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("dpo")
|
@DatasetFactory.register("dpo")
|
||||||
|
|
@ -229,6 +251,10 @@ class DPODataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
|
@ -259,15 +285,21 @@ class GRPODataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["prompts", "responses", "masks", "rewards"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
prompts = self._fetch_data(begin_idx, end_idx, "prompts")
|
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
|
||||||
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
|
||||||
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
|
self.iter = self.iter % self.num_samples_per_replica
|
||||||
|
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
|
@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _remaining(self):
|
||||||
|
remaining = self.num_samples_per_replica - self.iter
|
||||||
|
return max(remaining, 0)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples_per_replica
|
return self._remaining
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,37 @@
|
||||||
"""Storage backends for different data formats.
|
"""Storage backends for different data formats.
|
||||||
|
|
||||||
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
|
Layers:
|
||||||
a uniform interface for data access and length observation via fetchers.
|
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
|
||||||
|
return Dict[str, List[Tensor]] — format-specific, no state
|
||||||
|
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||||
|
Dict[str, List[Tensor]] per key via _normalize(),
|
||||||
|
fetch() uses bisect across segments — no forced concat
|
||||||
|
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
||||||
|
|
||||||
|
Key properties:
|
||||||
|
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
||||||
|
datasets larger than RAM
|
||||||
|
- Explicit length: _length = min(total elements across keys), set at load,
|
||||||
|
__len__ returns O(1)
|
||||||
|
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
||||||
|
workers share OS page-cache pages
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
|
|
@ -52,54 +69,30 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
return tensor_group
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
full_file_path = os.path.join(file_path, f"{file_name}.json")
|
meta = {}
|
||||||
json_data = {}
|
|
||||||
for key, tensors in tensor_group.items():
|
for key, tensors in tensor_group.items():
|
||||||
json_data[key] = [tensor.tolist() for tensor in tensors]
|
cat = torch.cat(tensors, dim=0)
|
||||||
with open(full_file_path, "w", encoding="utf-8") as f:
|
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
||||||
json.dump(json_data, f, ensure_ascii=False)
|
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
||||||
|
with open(os.path.join(file_path, "meta.json"), "w") as f:
|
||||||
|
json.dump(meta, f)
|
||||||
|
|
||||||
|
|
||||||
def load_json(
|
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||||
file_path: str,
|
with open(os.path.join(file_path, "meta.json"), "r") as f:
|
||||||
share_memory: bool = True,
|
meta = json.load(f)
|
||||||
tokenizer: Optional[Callable[[str], List[int]]] = None,
|
segments: Dict[str, List[Tensor]] = {}
|
||||||
) -> Dict[str, List[Tensor]]:
|
for key, info in meta.items():
|
||||||
"""Load tensor data from JSON files.
|
arr = np.memmap(
|
||||||
|
os.path.join(file_path, f"{key}.bin"),
|
||||||
Supports two modes:
|
dtype=info["dtype"],
|
||||||
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
|
mode="r+",
|
||||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
|
shape=tuple(info["shape"]),
|
||||||
at load time. A ``tokenizer`` receives a str and returns List[int].
|
)
|
||||||
|
segments[key] = [torch.from_numpy(arr)]
|
||||||
Non-data JSON files (e.g. config.json) with scalar/object values are
|
return segments
|
||||||
silently skipped.
|
|
||||||
"""
|
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
|
||||||
root_path = Path(file_path)
|
|
||||||
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
|
|
||||||
for json_file in json_files:
|
|
||||||
with open(json_file, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
continue
|
|
||||||
for key, sequences in data.items():
|
|
||||||
if not isinstance(sequences, list):
|
|
||||||
continue
|
|
||||||
tensors = []
|
|
||||||
for seq in sequences:
|
|
||||||
if tokenizer is not None and isinstance(seq, str):
|
|
||||||
seq = tokenizer(seq)
|
|
||||||
tensor = torch.tensor(seq, dtype=torch.long)
|
|
||||||
if share_memory:
|
|
||||||
tensor = tensor.share_memory_()
|
|
||||||
tensors.append(tensor)
|
|
||||||
if tensor_group.get(key) is None:
|
|
||||||
tensor_group[key] = []
|
|
||||||
tensor_group[key].extend(tensors)
|
|
||||||
return tensor_group
|
|
||||||
|
|
||||||
|
|
||||||
def detect_format(load_path: str) -> str:
|
def detect_format(load_path: str) -> str:
|
||||||
|
|
@ -109,7 +102,7 @@ def detect_format(load_path: str) -> str:
|
||||||
load_path: Directory or file path
|
load_path: Directory or file path
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Format string ("h5" or "json")
|
Format string ("h5" or "bin")
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: If no supported data files are found
|
FileNotFoundError: If no supported data files are found
|
||||||
|
|
@ -119,194 +112,160 @@ def detect_format(load_path: str) -> str:
|
||||||
suffix = root.suffix.lower()
|
suffix = root.suffix.lower()
|
||||||
if suffix in (".h5", ".hdf5"):
|
if suffix in (".h5", ".hdf5"):
|
||||||
return "h5"
|
return "h5"
|
||||||
if suffix in (".json", ".jsonl"):
|
|
||||||
return "json"
|
|
||||||
raise ValueError(f"Unsupported file format: {suffix}")
|
raise ValueError(f"Unsupported file format: {suffix}")
|
||||||
|
|
||||||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
h5_files = [
|
||||||
|
Path(p)
|
||||||
|
for pattern in ("*.h5", "*.hdf5")
|
||||||
|
for p in glob.glob(str(root / "**" / pattern), recursive=True)
|
||||||
|
]
|
||||||
if h5_files:
|
if h5_files:
|
||||||
return "h5"
|
return "h5"
|
||||||
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
|
||||||
if json_files:
|
if bin_files:
|
||||||
return "json"
|
has_meta = (root / "meta.json").exists() or len(
|
||||||
|
[Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
|
||||||
|
) > 0
|
||||||
|
if has_meta:
|
||||||
|
return "bin"
|
||||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class Store(ABC):
|
||||||
"""Fetches data segments across multiple tensor segments.
|
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
|
||||||
|
|
||||||
Maintains cumulative lengths for efficient range queries across
|
Each key maps to one or more tensor segments (no forced concatenation).
|
||||||
multiple discontinuous segments.
|
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
|
||||||
"""
|
total element count across all keys.
|
||||||
|
|
||||||
def __init__(self, segments: List[Tensor]):
|
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
|
||||||
self.segments = segments
|
via ``_normalize()``.
|
||||||
self.cum_lengths = []
|
|
||||||
|
|
||||||
total = 0
|
|
||||||
for seg in segments:
|
|
||||||
total += torch.numel(seg)
|
|
||||||
self.cum_lengths.append(total)
|
|
||||||
|
|
||||||
self.total_length = total
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self.total_length
|
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
|
||||||
"""Fetch data in the range [begin_idx, end_idx)."""
|
|
||||||
if not (
|
|
||||||
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
|
||||||
):
|
|
||||||
raise ValueError("begin_idx or end_idx out of bounds")
|
|
||||||
if begin_idx >= end_idx:
|
|
||||||
return torch.tensor([], dtype=torch.long)
|
|
||||||
|
|
||||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
|
||||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
|
||||||
|
|
||||||
result_segments = []
|
|
||||||
|
|
||||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
|
||||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
|
||||||
start = max(begin_idx - prev_cum, 0)
|
|
||||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
|
||||||
result_segments.append(self.segments[i][start:end])
|
|
||||||
|
|
||||||
return torch.cat(result_segments, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiSegmentFetcher:
|
|
||||||
"""Manages multiple segment fetchers for different data keys."""
|
|
||||||
|
|
||||||
def __init__(self, multi_segments: Dict):
|
|
||||||
self.multi_keys = list(multi_segments.keys())
|
|
||||||
self.multi_fetchers = {
|
|
||||||
key: BaseSegmentFetcher(segments)
|
|
||||||
for key, segments in multi_segments.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Returns the minimum length across all fetchers."""
|
|
||||||
if not self.multi_fetchers:
|
|
||||||
return 0
|
|
||||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
|
||||||
return min(len_list)
|
|
||||||
|
|
||||||
def key_fetch(
|
|
||||||
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
|
||||||
) -> Dict:
|
|
||||||
"""Fetch data for specific keys."""
|
|
||||||
fetch_dict = {}
|
|
||||||
keys = [keys] if isinstance(keys, str) else keys
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
fetcher = self.multi_fetchers[key]
|
|
||||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
|
||||||
fetch_dict[key] = fetch_tensor
|
|
||||||
|
|
||||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
|
||||||
"""Fetch all keys."""
|
|
||||||
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStorage(ABC):
|
|
||||||
"""Abstract storage backend for loading and dispatching data.
|
|
||||||
|
|
||||||
Storage encapsulates format-specific loading and provides a uniform
|
|
||||||
interface for data access and length observation. Subclasses handle
|
|
||||||
different data formats (HDF5, JSON, etc.) while exposing the same
|
|
||||||
fetch interface.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._fetcher: Optional[MultiSegmentFetcher] = None
|
self._data: Dict[str, List[Tensor]] = {}
|
||||||
|
self._cum: Dict[str, List[int]] = {}
|
||||||
|
self._length: int = 0
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, load_path: str, tokenizer=None) -> None:
|
def load(self, path: str) -> None:
|
||||||
"""Load data from the given path into internal fetcher."""
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Total number of raw elements (tokens) in storage."""
|
|
||||||
if self._fetcher is None:
|
|
||||||
return 0
|
|
||||||
return len(self._fetcher)
|
|
||||||
|
|
||||||
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
|
|
||||||
"""Fetch data for the given keys and index range.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
begin_idx: Starting index (inclusive)
|
|
||||||
end_idx: Ending index (exclusive)
|
|
||||||
keys: Single key or list of keys to fetch
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor if single key, Dict[str, Tensor] if multiple keys
|
|
||||||
"""
|
|
||||||
if self._fetcher is None:
|
|
||||||
raise RuntimeError("Storage not loaded")
|
|
||||||
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self) -> List[str]:
|
def keys(self) -> List[str]:
|
||||||
"""Return the data keys available in this storage."""
|
return list(self._data.keys())
|
||||||
if self._fetcher is None:
|
|
||||||
return []
|
def __len__(self) -> int:
|
||||||
return self._fetcher.multi_keys
|
return self._length
|
||||||
|
|
||||||
|
def fetch(
|
||||||
|
self,
|
||||||
|
begin: int,
|
||||||
|
end: int,
|
||||||
|
keys: Union[str, List[str]],
|
||||||
|
):
|
||||||
|
if not self._data:
|
||||||
|
raise RuntimeError("Store not loaded")
|
||||||
|
if not (0 <= begin < self._length and 0 <= end <= self._length):
|
||||||
|
raise ValueError(
|
||||||
|
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
|
||||||
|
)
|
||||||
|
if isinstance(keys, str):
|
||||||
|
return self._fetch_key(keys, begin, end)
|
||||||
|
return {k: self._fetch_key(k, begin, end) for k in keys}
|
||||||
|
|
||||||
|
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
|
||||||
|
"""Fetch slice [begin, end) across potentially multiple segments."""
|
||||||
|
segments = self._data[key]
|
||||||
|
cum = self._cum[key]
|
||||||
|
seg_start = bisect.bisect_right(cum, begin)
|
||||||
|
seg_end = bisect.bisect_left(cum, end)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i in range(seg_start, seg_end + 1):
|
||||||
|
prev = cum[i - 1] if i > 0 else 0
|
||||||
|
s = max(begin - prev, 0)
|
||||||
|
e = min(end - prev, segments[i].shape[0])
|
||||||
|
results.append(segments[i][s:e])
|
||||||
|
|
||||||
|
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
|
||||||
|
|
||||||
|
def _normalize(self, raw: Dict[str, List[Tensor]]):
|
||||||
|
"""Register segments and pre-compute cumulative lengths.
|
||||||
|
|
||||||
|
Does NOT concatenate — segments are kept as-is to avoid OOM on
|
||||||
|
large datasets. Sets ``self._length`` to the minimum total
|
||||||
|
element count across all keys.
|
||||||
|
"""
|
||||||
|
for key, tensors in raw.items():
|
||||||
|
self._data[key] = tensors
|
||||||
|
cum = []
|
||||||
|
total = 0
|
||||||
|
for t in tensors:
|
||||||
|
total += t.shape[0]
|
||||||
|
cum.append(total)
|
||||||
|
self._cum[key] = cum
|
||||||
|
self._length = (
|
||||||
|
min((cum[-1] if cum else 0) for cum in self._cum.values())
|
||||||
|
if self._cum
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class H5Storage(BaseStorage):
|
class StoreFactory(BaseFactory["Store"]):
|
||||||
|
"""Factory for creating Store instances by type name.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
@StoreFactory.register("custom")
|
||||||
|
class CustomStore(Store):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, store_cls: type):
|
||||||
|
if not issubclass(store_cls, Store):
|
||||||
|
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
||||||
|
|
||||||
|
|
||||||
|
@StoreFactory.register("h5")
|
||||||
|
class H5Store(Store):
|
||||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
def load(self, load_path: str, tokenizer=None) -> None:
|
def load(self, path: str):
|
||||||
segments = load_h5(load_path)
|
self._normalize(load_h5(path))
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
|
||||||
|
|
||||||
|
|
||||||
class JSONStorage(BaseStorage):
|
@StoreFactory.register("bin")
|
||||||
"""JSON-based storage backend.
|
class MmapStore(Store):
|
||||||
|
"""Memory-mapped binary storage backend.
|
||||||
|
|
||||||
Supports two modes:
|
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
|
||||||
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
|
No per-process memory duplication — all DataLoader workers share the
|
||||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
|
same OS page-cache pages.
|
||||||
callable (str -> List[int]) at load time.
|
|
||||||
|
Format on disk::
|
||||||
|
|
||||||
|
data_root/
|
||||||
|
meta.json # {key: {shape, dtype}, ...}
|
||||||
|
<key>.bin # raw numpy array, one per key
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load(self, load_path: str, tokenizer=None) -> None:
|
def load(self, path: str):
|
||||||
segments = load_json(load_path, tokenizer=tokenizer)
|
self._mmap_refs = []
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
root = Path(path)
|
||||||
|
all_raw: Dict[str, List[Tensor]] = {}
|
||||||
|
meta_paths = [
|
||||||
_STORAGE_REGISTRY: Dict[str, type] = {
|
Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
|
||||||
"h5": H5Storage,
|
]
|
||||||
"json": JSONStorage,
|
for meta_path in meta_paths:
|
||||||
}
|
raw = load_bin(str(meta_path.parent))
|
||||||
|
for key, tensors in raw.items():
|
||||||
|
if key not in all_raw:
|
||||||
def create_storage(storage_type: str) -> BaseStorage:
|
all_raw[key] = []
|
||||||
"""Create a storage instance by type name.
|
all_raw[key].extend(tensors)
|
||||||
|
if not meta_paths:
|
||||||
Args:
|
raise FileNotFoundError(f"No meta.json found under {path}")
|
||||||
storage_type: Storage type name ("h5", "json")
|
self._normalize(all_raw)
|
||||||
|
for tensors in self._data.values():
|
||||||
Returns:
|
self._mmap_refs.extend(tensors)
|
||||||
Storage instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the storage type is unknown
|
|
||||||
"""
|
|
||||||
storage_cls = _STORAGE_REGISTRY.get(storage_type)
|
|
||||||
if storage_cls is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown storage type: '{storage_type}'. "
|
|
||||||
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
|
|
||||||
)
|
|
||||||
return storage_cls()
|
|
||||||
|
|
||||||
|
|
||||||
def available_storage_types() -> List[str]:
|
|
||||||
"""Return list of registered storage type names."""
|
|
||||||
return sorted(_STORAGE_REGISTRY.keys())
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Base factory class for extensible component registration."""
|
"""Base factory class for extensible component registration."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
|
|
@ -22,7 +23,7 @@ class Registry:
|
||||||
component_cls: Type,
|
component_cls: Type,
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
) -> None:
|
):
|
||||||
"""Register a component class with optional category and priority."""
|
"""Register a component class with optional category and priority."""
|
||||||
if name in self._entries:
|
if name in self._entries:
|
||||||
raise ValueError(f"Component '{name}' is already registered")
|
raise ValueError(f"Component '{name}' is already registered")
|
||||||
|
|
@ -122,6 +123,10 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
def create(cls, name: str, *args, **kwargs) -> T:
|
def create(cls, name: str, *args, **kwargs) -> T:
|
||||||
"""Create a component instance by name.
|
"""Create a component instance by name.
|
||||||
|
|
||||||
|
Filters kwargs to match the component's __init__ signature,
|
||||||
|
so components don't need to declare **kwargs just to absorb
|
||||||
|
parameters meant for other components.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Registered name of the component
|
name: Registered name of the component
|
||||||
*args: Positional arguments passed to component constructor
|
*args: Positional arguments passed to component constructor
|
||||||
|
|
@ -139,10 +144,21 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||||
)
|
)
|
||||||
component_cls = cls._registry.get(name)
|
component_cls = cls._registry.get(name)
|
||||||
|
sig = inspect.signature(component_cls.__init__)
|
||||||
|
has_var_kwargs = any(
|
||||||
|
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||||
|
)
|
||||||
|
if not has_var_kwargs:
|
||||||
|
valid = {
|
||||||
|
p.name
|
||||||
|
for p in sig.parameters.values()
|
||||||
|
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
|
||||||
|
}
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if k in valid}
|
||||||
return component_cls(*args, **kwargs)
|
return component_cls(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, component_cls: Type[T]) -> None:
|
def _validate_component(cls, component_cls: Type[T]):
|
||||||
"""Validate that the component class is valid for this factory.
|
"""Validate that the component class is valid for this factory.
|
||||||
|
|
||||||
Override this method in subclasses to add custom validation.
|
Override this method in subclasses to add custom validation.
|
||||||
|
|
|
||||||
|
|
@ -2,24 +2,26 @@
|
||||||
|
|
||||||
Layers:
|
Layers:
|
||||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||||
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
- api/: HTTP orchestration (ProtocolHandler, server)
|
||||||
|
- protocols/: Response builders (OpenAI, Anthropic)
|
||||||
|
- transport/: SSE transport utilities
|
||||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from astrai.inference.api import (
|
from astrai.inference.api import (
|
||||||
AnthropicHandler,
|
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
GenContext,
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
OpenAIHandler,
|
|
||||||
ProtocolHandler,
|
ProtocolHandler,
|
||||||
StopChecker,
|
StopChecker,
|
||||||
StreamContext,
|
get_app,
|
||||||
app,
|
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||||
from astrai.inference.core import (
|
from astrai.inference.core import (
|
||||||
STOP,
|
STOP,
|
||||||
Allocator,
|
Allocator,
|
||||||
|
|
@ -36,10 +38,7 @@ from astrai.inference.core import (
|
||||||
TaskTable,
|
TaskTable,
|
||||||
page_hash,
|
page_hash,
|
||||||
)
|
)
|
||||||
from astrai.inference.engine import (
|
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
||||||
GenerationRequest,
|
|
||||||
InferenceEngine,
|
|
||||||
)
|
|
||||||
from astrai.inference.sample import (
|
from astrai.inference.sample import (
|
||||||
BaseSamplingStrategy,
|
BaseSamplingStrategy,
|
||||||
SamplingPipeline,
|
SamplingPipeline,
|
||||||
|
|
@ -50,17 +49,14 @@ from astrai.inference.sample import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine / Requests
|
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
# Core scheduler
|
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
"Executor",
|
"Executor",
|
||||||
"STOP",
|
"STOP",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskManager",
|
"TaskManager",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
# Core cache
|
|
||||||
"Allocator",
|
"Allocator",
|
||||||
"KVCache",
|
"KVCache",
|
||||||
"KvcacheView",
|
"KvcacheView",
|
||||||
|
|
@ -69,24 +65,21 @@ __all__ = [
|
||||||
"Storage",
|
"Storage",
|
||||||
"TaskTable",
|
"TaskTable",
|
||||||
"page_hash",
|
"page_hash",
|
||||||
# Sampling (Strategy pattern)
|
|
||||||
"sample",
|
"sample",
|
||||||
"BaseSamplingStrategy",
|
"BaseSamplingStrategy",
|
||||||
"TemperatureStrategy",
|
"TemperatureStrategy",
|
||||||
"TopKStrategy",
|
"TopKStrategy",
|
||||||
"TopPStrategy",
|
"TopPStrategy",
|
||||||
"SamplingPipeline",
|
"SamplingPipeline",
|
||||||
# Protocol
|
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"StreamContext",
|
"GenContext",
|
||||||
"AnthropicHandler",
|
"OpenAIResponseBuilder",
|
||||||
"OpenAIHandler",
|
"AnthropicResponseBuilder",
|
||||||
# Server
|
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"MessagesRequest",
|
"MessagesRequest",
|
||||||
"app",
|
"get_app",
|
||||||
"run_server",
|
"run_server",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,27 @@
|
||||||
"""Inference API: protocol handlers and FastAPI server."""
|
"""Inference API: protocol handler, stop checker, and FastAPI server.
|
||||||
|
|
||||||
from astrai.inference.api.protocol import (
|
``app`` is no longer a module-level global. Use :func:`get_app` to access the
|
||||||
AnthropicHandler,
|
lazy singleton FastAPI instance.
|
||||||
OpenAIHandler,
|
"""
|
||||||
ProtocolHandler,
|
|
||||||
StopChecker,
|
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||||
StreamContext,
|
|
||||||
)
|
|
||||||
from astrai.inference.api.server import (
|
from astrai.inference.api.server import (
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
app,
|
get_app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnthropicHandler",
|
|
||||||
"OpenAIHandler",
|
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"StreamContext",
|
"GenContext",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"MessagesRequest",
|
"MessagesRequest",
|
||||||
"app",
|
"get_app",
|
||||||
"run_server",
|
"run_server",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
"""Anthropic message completion response builder."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import (
|
||||||
|
GenContext,
|
||||||
|
ResponseBuilder,
|
||||||
|
StopInfo,
|
||||||
|
sse_event,
|
||||||
|
)
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
return block.get("text", "")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicResponseBuilder(ResponseBuilder):
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
messages: List[Dict[str, str]] = []
|
||||||
|
system = getattr(request, "system", None)
|
||||||
|
if system:
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
|
for m in request.messages:
|
||||||
|
text = _extract_text(m.content)
|
||||||
|
if text:
|
||||||
|
messages.append({"role": m.role, "content": text})
|
||||||
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
ctx = GenContext(
|
||||||
|
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||||
|
created=int(time.time()),
|
||||||
|
model=request.model,
|
||||||
|
prompt_tokens=0,
|
||||||
|
)
|
||||||
|
stop_sequences = getattr(request, "stop_sequences", None) or []
|
||||||
|
return prompt, ctx, stop_sequences
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [],
|
||||||
|
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
event="message_start",
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": 0,
|
||||||
|
"content_block": {"type": "text", "text": ""},
|
||||||
|
},
|
||||||
|
event="content_block_start",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
return sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": token},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
events: List[str] = []
|
||||||
|
if stop.matched:
|
||||||
|
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
||||||
|
unyielded = trimmed[len(stop.yielded) :]
|
||||||
|
if unyielded:
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": unyielded},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{"type": "content_block_stop", "index": 0},
|
||||||
|
event="content_block_stop",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": {
|
||||||
|
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||||
|
"stop_sequence": stop.matched,
|
||||||
|
},
|
||||||
|
"usage": {"output_tokens": ctx.completion_tokens},
|
||||||
|
},
|
||||||
|
event="message_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
||||||
|
return events
|
||||||
|
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
if stop.matched:
|
||||||
|
content = content[: content.rfind(stop.matched)]
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [{"type": "text", "text": content}],
|
||||||
|
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||||
|
"stop_sequence": stop.matched,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": ctx.prompt_tokens,
|
||||||
|
"output_tokens": ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
"""OpenAI chat completion response builder."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import (
|
||||||
|
GenContext,
|
||||||
|
ResponseBuilder,
|
||||||
|
StopInfo,
|
||||||
|
sse_event,
|
||||||
|
)
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_UNSUPPORTED_PARAMS = (
|
||||||
|
"n",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
"logit_bias",
|
||||||
|
"user",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||||
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
self._model = request.model
|
||||||
|
|
||||||
|
for param in _UNSUPPORTED_PARAMS:
|
||||||
|
value = getattr(request, param, None)
|
||||||
|
fields = getattr(type(request), "model_fields", {})
|
||||||
|
default = fields[param].default if param in fields else None
|
||||||
|
if value is not None and value != default:
|
||||||
|
logger.warning(
|
||||||
|
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||||
|
param,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
if value is not None and value != default:
|
||||||
|
logger.warning(
|
||||||
|
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||||
|
param,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = GenContext(
|
||||||
|
resp_id=self._resp_id,
|
||||||
|
created=int(time.time()),
|
||||||
|
model=self._model,
|
||||||
|
prompt_tokens=0,
|
||||||
|
)
|
||||||
|
stop = request.stop
|
||||||
|
stop_sequences = (
|
||||||
|
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||||
|
)
|
||||||
|
return prompt, ctx, stop_sequences
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant"},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
return sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 0,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -1,15 +1,13 @@
|
||||||
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
||||||
|
|
||||||
Template Method + Builder patterns eliminate the 45% code duplication between
|
ProtocolHandler orchestrates the async generation loop and delegates
|
||||||
stream/non-stream branches and across protocol adapters.
|
protocol-specific formatting to a ResponseBuilder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -17,7 +15,7 @@ from pydantic import BaseModel
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
lines: List[str] = []
|
lines: List[str] = []
|
||||||
if event:
|
if event:
|
||||||
lines.append(f"event: {event}")
|
lines.append(f"event: {event}")
|
||||||
|
|
@ -26,22 +24,28 @@ def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def _sse_done() -> str:
|
def sse_done() -> str:
|
||||||
return "data: [DONE]\n\n"
|
return "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StreamContext:
|
class GenContext:
|
||||||
"""Shared state across the streaming generation lifecycle."""
|
"""Per-generation metadata passed to builder format methods."""
|
||||||
|
|
||||||
resp_id: str
|
resp_id: str
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int = 0
|
completion_tokens: int = 0
|
||||||
accumulated: str = ""
|
|
||||||
stop_matched: Optional[str] = None
|
|
||||||
last_yield_trimmed: str = ""
|
@dataclass
|
||||||
|
class StopInfo:
|
||||||
|
"""Stop-check result passed to format_stream_end / format_response."""
|
||||||
|
|
||||||
|
matched: Optional[str] = None
|
||||||
|
body: str = ""
|
||||||
|
yielded: str = ""
|
||||||
|
|
||||||
|
|
||||||
class StopChecker:
|
class StopChecker:
|
||||||
|
|
@ -56,95 +60,60 @@ class StopChecker:
|
||||||
return seq
|
return seq
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def trim(self, text: str, matched: str) -> str:
|
|
||||||
idx = text.rfind(matched)
|
|
||||||
return text[:idx] if idx != -1 else text
|
|
||||||
|
|
||||||
@property
|
class ResponseBuilder(ABC):
|
||||||
def has_sequences(self) -> bool:
|
"""Interface for protocol-specific response formatting.
|
||||||
return len(self._sequences) > 0
|
|
||||||
|
|
||||||
|
A new protocol requires one concrete builder implementing 5 methods.
|
||||||
class ProtocolHandler(ABC):
|
|
||||||
"""Template-method base for API protocol handlers.
|
|
||||||
|
|
||||||
Subclasses implement format hooks; the base class orchestrates the
|
|
||||||
generate-async loop and SSE/JSON response construction.
|
|
||||||
|
|
||||||
Lifecycle::
|
|
||||||
|
|
||||||
handle()
|
|
||||||
├─ build_prompt() # protocol-specific prompt assembly
|
|
||||||
├─ create_response_id() # unique response identifier
|
|
||||||
├─ [stream]
|
|
||||||
│ ├─ format_stream_start()
|
|
||||||
│ ├─ format_stream_token() × N
|
|
||||||
│ │ └─ on_token() hook for stop-sequence interception
|
|
||||||
│ └─ format_stream_end()
|
|
||||||
└─ [non-stream]
|
|
||||||
├─ (accumulate tokens)
|
|
||||||
└─ format_non_stream_response()
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
request_model: type[BaseModel]
|
@abstractmethod
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
||||||
|
|
||||||
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
@abstractmethod
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
"""SSE events that open the stream."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
"""SSE event for a single generated token."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
"""SSE events that close the stream."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""JSON response body for non-streaming mode."""
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolHandler:
|
||||||
|
"""Orchestrates the generation loop, delegates formatting to a builder.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||||
|
response = await handler.handle()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
||||||
|
):
|
||||||
self.request = request
|
self.request = request
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
|
self.builder = builder
|
||||||
@abstractmethod
|
|
||||||
def build_prompt(self) -> str:
|
|
||||||
"""Build the full prompt string from the request messages."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_response_id(self) -> str:
|
|
||||||
"""Generate a unique response ID following the protocol convention."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
|
||||||
"""Yield SSE events that open the stream (role marker, metadata)."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
|
||||||
"""Yield an SSE event for a single generated token."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
|
||||||
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_non_stream_response(
|
|
||||||
self, ctx: StreamContext, content: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Build the JSON response body for non-streaming mode."""
|
|
||||||
|
|
||||||
def get_stop_sequences(self) -> List[str]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def create_stop_checker(self) -> StopChecker:
|
|
||||||
return StopChecker(self.get_stop_sequences())
|
|
||||||
|
|
||||||
def on_token(
|
|
||||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Hook after each token is appended to accumulated.
|
|
||||||
|
|
||||||
Return a matched stop-sequence string to break the loop,
|
|
||||||
or None to continue.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||||
ctx = StreamContext(
|
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
||||||
resp_id=self.create_response_id(),
|
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
||||||
created=int(time.time()),
|
|
||||||
model=self.request.model,
|
|
||||||
prompt_tokens=self._count_prompt_tokens(),
|
|
||||||
)
|
|
||||||
|
|
||||||
agen = self.engine.generate_async(
|
agen = self.engine.generate_async(
|
||||||
prompt=self.build_prompt(),
|
prompt=prompt,
|
||||||
max_tokens=self.request.max_tokens,
|
max_tokens=self.request.max_tokens,
|
||||||
temperature=self.request.temperature,
|
temperature=self.request.temperature,
|
||||||
top_p=self.request.top_p,
|
top_p=self.request.top_p,
|
||||||
|
|
@ -152,33 +121,37 @@ class ProtocolHandler(ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.request.stream:
|
if self.request.stream:
|
||||||
return self._handle_stream(agen, ctx)
|
return self._handle_stream(agen, ctx, stop_sequences)
|
||||||
else:
|
else:
|
||||||
return await self._handle_non_stream(agen, ctx)
|
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
||||||
|
|
||||||
def _count_prompt_tokens(self) -> int:
|
def _handle_stream(
|
||||||
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||||
|
) -> StreamingResponse:
|
||||||
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
checker = StopChecker(stop_sequences)
|
||||||
stop_checker = self.create_stop_checker()
|
|
||||||
|
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
for event in self.format_stream_start(ctx):
|
for event in self.builder.format_stream_start(ctx):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
body = ""
|
||||||
|
yielded = ""
|
||||||
|
matched = None
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
ctx.completion_tokens += 1
|
body += token
|
||||||
ctx.accumulated += token
|
|
||||||
|
|
||||||
matched = self.on_token(ctx, token, stop_checker)
|
matched = checker.check(body)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield self.format_stream_token(ctx, token)
|
ctx.completion_tokens += 1
|
||||||
|
yield self.builder.format_chunk(token)
|
||||||
|
yielded += token
|
||||||
|
|
||||||
for event in self.format_stream_end(ctx):
|
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||||
|
for event in self.builder.format_stream_end(ctx, stop):
|
||||||
yield event
|
yield event
|
||||||
yield _sse_done()
|
yield sse_done()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_stream(),
|
event_stream(),
|
||||||
|
|
@ -186,249 +159,24 @@ class ProtocolHandler(ABC):
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
async def _handle_non_stream(
|
||||||
stop_checker = self.create_stop_checker()
|
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
checker = StopChecker(stop_sequences)
|
||||||
chunks: List[str] = []
|
chunks: List[str] = []
|
||||||
|
body = ""
|
||||||
|
matched = None
|
||||||
|
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
ctx.completion_tokens += 1
|
|
||||||
ctx.accumulated += token
|
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
|
body += token
|
||||||
|
|
||||||
matched = self.on_token(ctx, token, stop_checker)
|
matched = checker.check(body)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
ctx.completion_tokens += 1
|
||||||
|
|
||||||
content = "".join(chunks)
|
content = "".join(chunks)
|
||||||
return self.format_non_stream_response(ctx, content)
|
stop = StopInfo(matched=matched, body=body)
|
||||||
|
return self.builder.format_response(ctx, content, stop)
|
||||||
|
|
||||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
||||||
"""Extract plain text from an Anthropic content block (string or list)."""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "text":
|
|
||||||
return block.get("text", "")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIHandler(ProtocolHandler):
|
|
||||||
"""OpenAI-compatible /v1/chat/completions handler."""
|
|
||||||
|
|
||||||
def build_prompt(self) -> str:
|
|
||||||
messages = [
|
|
||||||
{"role": m.role, "content": m.content} for m in self.request.messages
|
|
||||||
]
|
|
||||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
|
|
||||||
def create_response_id(self) -> str:
|
|
||||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": ctx.model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"role": "assistant"},
|
|
||||||
"finish_reason": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
|
||||||
return _sse_event(
|
|
||||||
{
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": ctx.model,
|
|
||||||
"choices": [
|
|
||||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": ctx.model,
|
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
|
||||||
"completion_tokens": ctx.completion_tokens,
|
|
||||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_non_stream_response(
|
|
||||||
self, ctx: StreamContext, content: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": ctx.model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": content},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
|
||||||
"completion_tokens": ctx.completion_tokens,
|
|
||||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicHandler(ProtocolHandler):
|
|
||||||
"""Anthropic-compatible /v1/messages handler."""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._yielded = ""
|
|
||||||
|
|
||||||
def build_prompt(self) -> str:
|
|
||||||
messages: List[Dict[str, str]] = []
|
|
||||||
system = getattr(self.request, "system", None)
|
|
||||||
if system:
|
|
||||||
messages.append({"role": "system", "content": system})
|
|
||||||
for m in self.request.messages:
|
|
||||||
content = _extract_text_content(m.content)
|
|
||||||
if content:
|
|
||||||
messages.append({"role": m.role, "content": content})
|
|
||||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
|
|
||||||
def create_response_id(self) -> str:
|
|
||||||
return f"msg_{uuid.uuid4().hex[:24]}"
|
|
||||||
|
|
||||||
def get_stop_sequences(self) -> List[str]:
|
|
||||||
return getattr(self.request, "stop_sequences", None) or []
|
|
||||||
|
|
||||||
def on_token(
|
|
||||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
|
||||||
) -> Optional[str]:
|
|
||||||
matched = stop_checker.check(ctx.accumulated)
|
|
||||||
if not matched:
|
|
||||||
return None
|
|
||||||
|
|
||||||
ctx.stop_matched = matched
|
|
||||||
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
|
||||||
unyielded = trimmed[len(self._yielded) :]
|
|
||||||
if unyielded:
|
|
||||||
ctx.last_yield_trimmed = unyielded
|
|
||||||
return matched
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_start",
|
|
||||||
"message": {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [],
|
|
||||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
event="message_start",
|
|
||||||
),
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": 0,
|
|
||||||
"content_block": {"type": "text", "text": ""},
|
|
||||||
},
|
|
||||||
event="content_block_start",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
|
||||||
self._yielded += token
|
|
||||||
return _sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": token},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
|
||||||
matched = ctx.stop_matched
|
|
||||||
events: List[str] = []
|
|
||||||
last_yielded = ctx.last_yield_trimmed
|
|
||||||
if last_yielded:
|
|
||||||
events.append(
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": last_yielded},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
_sse_event(
|
|
||||||
{"type": "content_block_stop", "index": 0},
|
|
||||||
event="content_block_stop",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": {
|
|
||||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
|
||||||
"stop_sequence": matched,
|
|
||||||
},
|
|
||||||
"usage": {"output_tokens": ctx.completion_tokens},
|
|
||||||
},
|
|
||||||
event="message_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
|
||||||
return events
|
|
||||||
|
|
||||||
def format_non_stream_response(
|
|
||||||
self, ctx: StreamContext, content: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
matched = ctx.stop_matched
|
|
||||||
if matched:
|
|
||||||
content = content[: content.rfind(matched)]
|
|
||||||
return {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [{"type": "text", "text": content}],
|
|
||||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
|
||||||
"stop_sequence": matched,
|
|
||||||
"usage": {
|
|
||||||
"input_tokens": ctx.prompt_tokens,
|
|
||||||
"output_tokens": ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,9 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi
|
||||||
|
|
||||||
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||||
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||||
|
|
||||||
|
``app`` is lazily constructed — importing this module does NOT create a FastAPI instance.
|
||||||
|
Use :func:`get_app` to access the singleton.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -12,17 +15,19 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import APIRouter, FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||||
|
from astrai.inference.api.protocol import ProtocolHandler
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_app_instance: Optional[FastAPI] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
|
|
@ -67,14 +72,30 @@ class MessagesRequest(BaseModel):
|
||||||
stop_sequences: Optional[List[str]] = None
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
config = app.state.server_config
|
||||||
|
if not config.get("_test", False):
|
||||||
|
try:
|
||||||
|
app.state.engine = _create_engine(**config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model: {e}")
|
||||||
|
raise
|
||||||
|
yield
|
||||||
|
if app.state.engine:
|
||||||
|
app.state.engine.shutdown()
|
||||||
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def _create_engine(
|
def _create_engine(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Path,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
) -> InferenceEngine:
|
) -> InferenceEngine:
|
||||||
if param_path is None:
|
|
||||||
param_path = _project_root / "params"
|
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
|
|
||||||
|
|
@ -92,67 +113,66 @@ def _create_engine(
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
def get_app() -> FastAPI:
|
||||||
async def lifespan(app: FastAPI):
|
"""Return the singleton FastAPI instance (lazily created on first call)."""
|
||||||
config = app.state.server_config
|
global _app_instance
|
||||||
if not config.get("_test", False):
|
if _app_instance is None:
|
||||||
try:
|
_app_instance = FastAPI(
|
||||||
app.state.engine = _create_engine(**config)
|
title="AstrAI Inference Server",
|
||||||
except Exception as e:
|
version="0.2.0",
|
||||||
logger.error(f"Failed to load model: {e}")
|
lifespan=lifespan,
|
||||||
raise
|
)
|
||||||
yield
|
_app_instance.include_router(router)
|
||||||
if app.state.engine:
|
_app_instance.state.server_config = {}
|
||||||
app.state.engine.shutdown()
|
_app_instance.state.engine = None
|
||||||
logger.info("Inference engine shutdown complete")
|
return _app_instance
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
def _get_engine() -> InferenceEngine:
|
||||||
|
engine = get_app().state.engine
|
||||||
|
|
||||||
def _get_engine(request: Request) -> InferenceEngine:
|
|
||||||
engine = request.app.state.engine
|
|
||||||
if engine is None:
|
if engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@router.get("/health")
|
||||||
async def health(request: Request):
|
async def health():
|
||||||
|
app = get_app()
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"model_loaded": request.app.state.engine is not None,
|
"model_loaded": app.state.engine is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats")
|
@router.get("/stats")
|
||||||
async def get_stats(request: Request):
|
async def get_stats():
|
||||||
return _get_engine(request).get_stats()
|
return _get_engine().get_stats()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@router.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest, req: Request):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
engine = _get_engine(req)
|
engine = _get_engine()
|
||||||
handler = OpenAIHandler(request, engine)
|
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/messages")
|
@router.post("/v1/messages")
|
||||||
async def create_message(request: MessagesRequest, req: Request):
|
async def create_message(request: MessagesRequest):
|
||||||
engine = _get_engine(req)
|
engine = _get_engine()
|
||||||
handler = AnthropicHandler(request, engine)
|
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
def run_server(
|
def run_server(
|
||||||
|
param_path: Path,
|
||||||
host: str = "0.0.0.0",
|
host: str = "0.0.0.0",
|
||||||
port: int = 8000,
|
port: int = 8000,
|
||||||
reload: bool = False,
|
reload: bool = False,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
param_path: Optional[Path] = None,
|
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
|
app = get_app()
|
||||||
app.state.server_config = {
|
app.state.server_config = {
|
||||||
"device": device,
|
"device": device,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
|
|
@ -163,4 +183,5 @@ def run_server(
|
||||||
app,
|
app,
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
|
reload=reload,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class Allocator:
|
||||||
return idx
|
return idx
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
def free(self, idx: int, keep_cached: bool = False):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._refs[idx] -= 1
|
self._refs[idx] -= 1
|
||||||
if self._refs[idx] == 0:
|
if self._refs[idx] == 0:
|
||||||
|
|
@ -51,7 +51,7 @@ class Allocator:
|
||||||
else:
|
else:
|
||||||
self._free_mask |= 1 << idx
|
self._free_mask |= 1 << idx
|
||||||
|
|
||||||
def inc_ref(self, idx: int) -> None:
|
def inc_ref(self, idx: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._refs[idx] += 1
|
self._refs[idx] += 1
|
||||||
self._lru.pop(idx, None)
|
self._lru.pop(idx, None)
|
||||||
|
|
@ -60,7 +60,7 @@ class Allocator:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._refs[idx]
|
return self._refs[idx]
|
||||||
|
|
||||||
def touch(self, idx: int) -> None:
|
def touch(self, idx: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._lru.move_to_end(idx)
|
self._lru.move_to_end(idx)
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ class PrefixCache:
|
||||||
self._hash_to_page: Dict[int, int] = {}
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def evict(self, idx: int) -> None:
|
def evict(self, idx: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
h = self._page_to_hash.pop(idx, None)
|
h = self._page_to_hash.pop(idx, None)
|
||||||
if h is not None:
|
if h is not None:
|
||||||
|
|
@ -96,9 +96,7 @@ class PrefixCache:
|
||||||
hits.append(p)
|
hits.append(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
def record(
|
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
|
||||||
) -> None:
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||||
old_h = self._page_to_hash.pop(page_idx, None)
|
old_h = self._page_to_hash.pop(page_idx, None)
|
||||||
|
|
@ -127,13 +125,13 @@ class PagePool:
|
||||||
def alloc(self) -> int:
|
def alloc(self) -> int:
|
||||||
return self._alloc.alloc()
|
return self._alloc.alloc()
|
||||||
|
|
||||||
def free(self, idx: int) -> None:
|
def free(self, idx: int):
|
||||||
keep = self._prefix.has_page(idx)
|
keep = self._prefix.has_page(idx)
|
||||||
self._alloc.free(idx, keep_cached=keep)
|
self._alloc.free(idx, keep_cached=keep)
|
||||||
if not keep:
|
if not keep:
|
||||||
self._prefix.evict(idx)
|
self._prefix.evict(idx)
|
||||||
|
|
||||||
def inc_ref(self, idx: int) -> None:
|
def inc_ref(self, idx: int):
|
||||||
self._alloc.inc_ref(idx)
|
self._alloc.inc_ref(idx)
|
||||||
|
|
||||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||||
|
|
@ -142,9 +140,7 @@ class PagePool:
|
||||||
self._alloc.touch(p)
|
self._alloc.touch(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
def record(
|
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
|
||||||
) -> None:
|
|
||||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -157,7 +153,7 @@ class TaskTable:
|
||||||
self._cached: Dict[str, int] = {}
|
self._cached: Dict[str, int] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
def set(self, task_id: str, page_table: List[int], cached: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._pages[task_id] = page_table
|
self._pages[task_id] = page_table
|
||||||
self._cached[task_id] = cached
|
self._cached[task_id] = cached
|
||||||
|
|
@ -220,7 +216,7 @@ class Storage:
|
||||||
start_pos: int,
|
start_pos: int,
|
||||||
k: Tensor,
|
k: Tensor,
|
||||||
v: Tensor,
|
v: Tensor,
|
||||||
) -> None:
|
):
|
||||||
seq_len = k.size(1)
|
seq_len = k.size(1)
|
||||||
if seq_len == 0:
|
if seq_len == 0:
|
||||||
return
|
return
|
||||||
|
|
@ -286,7 +282,7 @@ class KvcacheView:
|
||||||
self._page_table = page_table
|
self._page_table = page_table
|
||||||
self._total_len = total_len
|
self._total_len = total_len
|
||||||
|
|
||||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
def write(self, layer_id: int, k: Tensor, v: Tensor):
|
||||||
start_pos = self._total_len - k.size(1)
|
start_pos = self._total_len - k.size(1)
|
||||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||||
|
|
||||||
|
|
@ -339,7 +335,7 @@ class KVCache:
|
||||||
self._table.set(task_id, hits + new_pages, cached)
|
self._table.set(task_id, hits + new_pages, cached)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def task_free(self, task_id: str) -> None:
|
def task_free(self, task_id: str):
|
||||||
page_table, _ = self._table.pop(task_id)
|
page_table, _ = self._table.pop(task_id)
|
||||||
for idx in page_table:
|
for idx in page_table:
|
||||||
self._pool.free(idx)
|
self._pool.free(idx)
|
||||||
|
|
@ -359,7 +355,7 @@ class KVCache:
|
||||||
|
|
||||||
def task_record_hashes(
|
def task_record_hashes(
|
||||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||||
) -> None:
|
):
|
||||||
page_table = self._table.get(task_id)
|
page_table = self._table.get(task_id)
|
||||||
full_pages = len(prompt_ids) // self.page_size
|
full_pages = len(prompt_ids) // self.page_size
|
||||||
for i in range(start_logical_page, full_pages):
|
for i in range(start_logical_page, full_pages):
|
||||||
|
|
|
||||||
|
|
@ -29,9 +29,7 @@ class Executor:
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
def execute_prefill(
|
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
|
||||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
|
||||||
) -> None:
|
|
||||||
if start_pos >= prompt_len:
|
if start_pos >= prompt_len:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,14 +22,22 @@ class InferenceScheduler:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prompt_len: int = 512,
|
max_prompt_len: int = 2048,
|
||||||
page_size: int = 64,
|
page_size: int = 64,
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
config = model.config
|
config = model.config
|
||||||
|
|
||||||
self.max_seq_len = max_seq_len or config.max_len
|
if max_seq_len is not None:
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
elif config.max_len is not None:
|
||||||
|
self.max_seq_len = config.max_len
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"max_seq_len must be provided either as argument "
|
||||||
|
"or in model config (config.max_len)"
|
||||||
|
)
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
|
|
@ -63,18 +71,19 @@ class InferenceScheduler:
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._fatal_error: Optional[Exception] = None
|
||||||
|
|
||||||
def add_task(self, prompt: str, **kwargs) -> str:
|
def add_task(self, prompt: str, **kwargs) -> str:
|
||||||
return self._task_mgr.add_task(prompt, **kwargs)
|
return self._task_mgr.add_task(prompt, **kwargs)
|
||||||
|
|
||||||
def remove_task(self, task_id: str) -> None:
|
def remove_task(self, task_id: str):
|
||||||
for task in self._task_mgr.remove_task(task_id):
|
for task in self._task_mgr.remove_task(task_id):
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
return self._task_mgr.get_stats()
|
return self._task_mgr.get_stats()
|
||||||
|
|
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self):
|
||||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||||
try:
|
try:
|
||||||
while self._running:
|
while self._running:
|
||||||
|
|
@ -100,7 +109,10 @@ class InferenceScheduler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_prefill = [
|
to_prefill = [
|
||||||
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
t
|
||||||
|
for t in self._task_mgr.get_active_tasks()
|
||||||
|
if t.output_tokens == 0
|
||||||
|
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
||||||
]
|
]
|
||||||
if to_prefill:
|
if to_prefill:
|
||||||
for t in to_prefill:
|
for t in to_prefill:
|
||||||
|
|
@ -148,11 +160,15 @@ class InferenceScheduler:
|
||||||
t.output_ids.append(ntok)
|
t.output_ids.append(ntok)
|
||||||
t.output_tokens += 1
|
t.output_tokens += 1
|
||||||
pos = t.input_tokens + t.output_tokens
|
pos = t.input_tokens + t.output_tokens
|
||||||
self._page_cache.task_extend(t.task_id, pos)
|
extend_ok = self._page_cache.task_extend(t.task_id, pos)
|
||||||
if t.stream_callback:
|
if t.stream_callback:
|
||||||
t.stream_callback(
|
t.stream_callback(
|
||||||
self._task_mgr.tokenizer.decode([ntok])
|
self._task_mgr.tokenizer.decode([ntok])
|
||||||
)
|
)
|
||||||
|
if not extend_ok:
|
||||||
|
t.status = TaskStatus.ABORTED
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
for t in valid:
|
for t in valid:
|
||||||
if t.is_finished(stop_ids):
|
if t.is_finished(stop_ids):
|
||||||
|
|
@ -160,28 +176,37 @@ class InferenceScheduler:
|
||||||
t.stream_callback(STOP)
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self._fatal_error = e
|
||||||
|
self._running = False
|
||||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
for task in self._task_mgr.get_active_tasks():
|
for task in self._task_mgr.get_active_tasks():
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
task.stream_callback(STOP)
|
task.stream_callback(STOP)
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
|
for task in self._task_mgr.get_waiting_tasks():
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
raise
|
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self):
|
||||||
if not self._running:
|
if not self._running:
|
||||||
self._running = True
|
self._running = True
|
||||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||||
t.start()
|
t.start()
|
||||||
self._loop_thread = t
|
self._loop_thread = t
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self):
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task_mgr.wake()
|
self._task_mgr.wake()
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
self._loop_thread.join(timeout=2.0)
|
self._loop_thread.join(timeout=2.0)
|
||||||
for task in self._task_mgr.get_active_tasks():
|
for task in self._task_mgr.get_active_tasks():
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
|
for task in self._task_mgr.get_waiting_tasks():
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -172,12 +172,12 @@ class TaskManager:
|
||||||
to_add.append(self.waiting_queue.popleft())
|
to_add.append(self.waiting_queue.popleft())
|
||||||
return to_add
|
return to_add
|
||||||
|
|
||||||
def activate(self, task: Task) -> None:
|
def activate(self, task: Task):
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
def return_to_waiting(self, tasks: List[Task]) -> None:
|
def return_to_waiting(self, tasks: List[Task]):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for task in reversed(tasks):
|
for task in reversed(tasks):
|
||||||
self.waiting_queue.appendleft(task)
|
self.waiting_queue.appendleft(task)
|
||||||
|
|
@ -185,7 +185,10 @@ class TaskManager:
|
||||||
def has_work(self) -> bool:
|
def has_work(self) -> bool:
|
||||||
return bool(self.active_tasks or self.waiting_queue)
|
return bool(self.active_tasks or self.waiting_queue)
|
||||||
|
|
||||||
def wait_for_tasks(self, timeout: float = 1.0) -> None:
|
def wait_for_tasks(self, timeout: float = 1.0):
|
||||||
|
with self._lock:
|
||||||
|
if self.waiting_queue or self.active_tasks:
|
||||||
|
return
|
||||||
self._task_event.clear()
|
self._task_event.clear()
|
||||||
self._task_event.wait(timeout=timeout)
|
self._task_event.wait(timeout=timeout)
|
||||||
|
|
||||||
|
|
@ -193,10 +196,14 @@ class TaskManager:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return list(self.active_tasks)
|
return list(self.active_tasks)
|
||||||
|
|
||||||
def clear_queues(self) -> None:
|
def get_waiting_tasks(self) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
|
return list(self.waiting_queue)
|
||||||
|
|
||||||
|
def clear_queues(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue.clear()
|
self.waiting_queue.clear()
|
||||||
self.active_tasks.clear()
|
self.active_tasks.clear()
|
||||||
|
|
||||||
def wake(self) -> None:
|
def wake(self):
|
||||||
self._task_event.set()
|
self._task_event.set()
|
||||||
|
|
|
||||||
|
|
@ -13,17 +13,6 @@ from astrai.inference.core.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def _validate_sampling_params(
|
|
||||||
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
|
||||||
):
|
|
||||||
if not (isinstance(top_k, int) and top_k >= 0):
|
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
|
||||||
if not (0.0 <= top_p <= 1.0):
|
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
|
||||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
|
||||||
raise ValueError("temperature must be a non-negative number")
|
|
||||||
|
|
||||||
|
|
||||||
class GenerateResult:
|
class GenerateResult:
|
||||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||||
|
|
||||||
|
|
@ -59,7 +48,7 @@ class GenerateResult:
|
||||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||||
return self._event.wait(timeout=timeout)
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
def wait_completion(self, timeout: float = 300.0) -> None:
|
def wait_completion(self, timeout: float = 300.0):
|
||||||
with self._cond:
|
with self._cond:
|
||||||
if not self._cond.wait_for(
|
if not self._cond.wait_for(
|
||||||
lambda: self._completed >= self._total, timeout=timeout
|
lambda: self._completed >= self._total, timeout=timeout
|
||||||
|
|
@ -86,7 +75,12 @@ class GenerationRequest:
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
if not (isinstance(top_k, int) and top_k >= 0):
|
||||||
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
if not (0.0 <= top_p <= 1.0):
|
||||||
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
|
if not (isinstance(temperature, (int, float)) and temperature > 0):
|
||||||
|
raise ValueError("temperature must be a positive number")
|
||||||
|
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
@ -137,7 +131,6 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> Union[Generator, str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
|
|
@ -158,7 +151,6 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
|
||||||
sync_gen = self._generate_streaming(
|
sync_gen = self._generate_streaming(
|
||||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
|
|
@ -289,7 +281,7 @@ class InferenceEngine:
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
return self.scheduler.get_stats()
|
return self.scheduler.get_stats()
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self):
|
||||||
self.scheduler.stop()
|
self.scheduler.stop()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -44,10 +44,12 @@ class TemperatureStrategy(BaseSamplingStrategy):
|
||||||
def apply(self, logits, filter_value=-float("inf")):
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
t = self.temperature
|
t = self.temperature
|
||||||
if isinstance(t, Tensor):
|
if isinstance(t, Tensor):
|
||||||
|
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||||
|
t = torch.clamp(t, min=1e-8)
|
||||||
if (t != 1.0).any():
|
if (t != 1.0).any():
|
||||||
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
|
|
||||||
elif t != 1.0:
|
|
||||||
logits = logits / t
|
logits = logits / t
|
||||||
|
elif t != 1.0:
|
||||||
|
logits = logits / max(t, 1e-8)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,18 @@
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.components.attention import GQA
|
||||||
GQA,
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
MLP,
|
from astrai.model.components.linear import Linear
|
||||||
DecoderBlock,
|
from astrai.model.components.lora import (
|
||||||
Linear,
|
LoRAConfig,
|
||||||
RMSNorm,
|
inject_lora,
|
||||||
|
load_lora,
|
||||||
|
merge_lora,
|
||||||
|
save_lora,
|
||||||
)
|
)
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.components.mlp import MLP
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.encoder import EmbeddingEncoder
|
||||||
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Modules
|
# Modules
|
||||||
|
|
@ -16,6 +22,13 @@ __all__ = [
|
||||||
"GQA",
|
"GQA",
|
||||||
"DecoderBlock",
|
"DecoderBlock",
|
||||||
# Models
|
# Models
|
||||||
"Transformer",
|
"AutoRegressiveLM",
|
||||||
|
"EmbeddingEncoder",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
|
# LoRA
|
||||||
|
"LoRAConfig",
|
||||||
|
"inject_lora",
|
||||||
|
"merge_lora",
|
||||||
|
"save_lora",
|
||||||
|
"load_lora",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -6,16 +6,20 @@ from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Self, Union
|
from typing import Self, Union
|
||||||
|
|
||||||
import safetensors.torch as st
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.serialization import load_model_config, load_model_weights, save_model
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _disable_random_init(enable: bool = True):
|
def _disable_random_init(enable: bool = True):
|
||||||
init_functions = [
|
if not enable:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
names = (
|
||||||
"xavier_normal_",
|
"xavier_normal_",
|
||||||
"xavier_uniform_",
|
"xavier_uniform_",
|
||||||
"kaiming_normal_",
|
"kaiming_normal_",
|
||||||
|
|
@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True):
|
||||||
"constant_",
|
"constant_",
|
||||||
"normal_",
|
"normal_",
|
||||||
"uniform_",
|
"uniform_",
|
||||||
]
|
)
|
||||||
original_funcs = {}
|
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
||||||
for name in init_functions:
|
for n in orig:
|
||||||
if enable and hasattr(nn.init, name):
|
setattr(nn.init, n, lambda *a, **kw: None)
|
||||||
original_funcs[name] = getattr(nn.init, name)
|
|
||||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
if enable:
|
for n, fn in orig.items():
|
||||||
for name, orig_func in original_funcs.items():
|
setattr(nn.init, n, fn)
|
||||||
setattr(nn.init, name, orig_func)
|
|
||||||
|
|
||||||
|
|
||||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
|
|
@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
Provides model loading/saving, registration, and generation.
|
Provides model loading/saving, registration, and generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: BaseModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|
@ -59,24 +60,22 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
|
|
||||||
model_path = Path(path)
|
model_path = Path(path)
|
||||||
|
|
||||||
# Load config
|
|
||||||
config = ModelConfig()
|
|
||||||
config_path = model_path / "config.json"
|
config_path = model_path / "config.json"
|
||||||
if config_path.exists():
|
if not config_path.exists():
|
||||||
config.load(str(config_path))
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
model_type = config.model_type or "transformer"
|
raw = load_model_config(str(model_path))
|
||||||
|
config = ConfigFactory.load(raw)
|
||||||
|
model_type = config.model_type or "autoregressive_lm"
|
||||||
|
|
||||||
actual_cls = AutoModel.get_component_class(model_type)
|
actual_cls = AutoModel.get_component_class(model_type)
|
||||||
|
|
||||||
with _disable_random_init(enable=disable_random_init):
|
with _disable_random_init(enable=disable_random_init):
|
||||||
model = actual_cls(config)
|
model = actual_cls(config)
|
||||||
|
|
||||||
# Load weights
|
|
||||||
weights_path = model_path / "model.safetensors"
|
weights_path = model_path / "model.safetensors"
|
||||||
if weights_path.exists():
|
if weights_path.exists():
|
||||||
state_dict = st.load_file(str(weights_path))
|
state_dict = load_model_weights(str(model_path))
|
||||||
model.load_state_dict(state_dict, strict=strict)
|
model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
@ -84,15 +83,12 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
def save_pretrained(
|
def save_pretrained(
|
||||||
self,
|
self,
|
||||||
save_directory: Union[str, Path],
|
save_directory: Union[str, Path],
|
||||||
) -> None:
|
):
|
||||||
save_path = Path(save_directory)
|
save_model(
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
config=self.config.to_dict(),
|
||||||
|
state_dict=self.state_dict(),
|
||||||
# Save config
|
save_directory=str(save_directory),
|
||||||
self.config.save(str(save_path / "config.json"))
|
)
|
||||||
|
|
||||||
# Save weights
|
|
||||||
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs) -> Self:
|
def to(self, *args, **kwargs) -> Self:
|
||||||
"""Move model to device/dtype."""
|
"""Move model to device/dtype."""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
from astrai.model.components.attention import GQA, MLA, repeat_kv
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.embedding import Embedding
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.mlp import MLP
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import (
|
||||||
|
RotaryEmbedding,
|
||||||
|
apply_rotary_emb,
|
||||||
|
get_rotary_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Linear",
|
||||||
|
"RMSNorm",
|
||||||
|
"MLP",
|
||||||
|
"Embedding",
|
||||||
|
"GQA",
|
||||||
|
"MLA",
|
||||||
|
"DecoderBlock",
|
||||||
|
"RotaryEmbedding",
|
||||||
|
"apply_rotary_emb",
|
||||||
|
"get_rotary_emb",
|
||||||
|
"repeat_kv",
|
||||||
|
]
|
||||||
|
|
@ -5,11 +5,14 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
from astrai.inference.core.cache import KvcacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import apply_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
"""Repeat KV heads n_rep times for GQA."""
|
|
||||||
bs, slen, n_heads, head_dim = x.shape
|
bs, slen, n_heads, head_dim = x.shape
|
||||||
if n_rep == 1:
|
if n_rep == 1:
|
||||||
return x
|
return x
|
||||||
|
|
@ -20,88 +23,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_emb(
|
class AttnFactory(BaseFactory[nn.Module]):
|
||||||
dim: int,
|
@classmethod
|
||||||
max_len: int,
|
def create(cls, attn_type: str, **kwargs) -> nn.Module:
|
||||||
base: float = 10000,
|
return super().create(attn_type, **kwargs)
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
) -> Tensor:
|
|
||||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
|
||||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
|
||||||
freqs = torch.outer(t, theta).float()
|
|
||||||
cos = torch.cos(freqs)
|
|
||||||
sin = torch.sin(freqs)
|
|
||||||
return torch.complex(cos, sin)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
|
||||||
dtype = x.dtype
|
|
||||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
|
||||||
x_complex = torch.view_as_complex(x_)
|
|
||||||
freqs_cis = freqs_cis.unsqueeze(2)
|
|
||||||
x_rotated = x_complex * freqs_cis
|
|
||||||
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
|
||||||
return x_out.to(dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.max_len = max_len
|
|
||||||
self.base = base
|
|
||||||
self._set_rotary_buffer(self.max_len)
|
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int):
|
|
||||||
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
|
||||||
freqs_cis = torch.view_as_real(rotary_emb)
|
|
||||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
|
||||||
if position_ids is None:
|
|
||||||
position_ids = (
|
|
||||||
torch.arange(x.size(1), device=x.device)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(x.size(0), -1)
|
|
||||||
)
|
|
||||||
position_freq_cis = self.freqs_cis[position_ids].float()
|
|
||||||
return torch.view_as_complex(position_freq_cis)
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
|
||||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.linear(x, self.weight, self.bias)
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim, norm_eps):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
self.normalized_shape = (dim,)
|
|
||||||
self.norm_eps = norm_eps
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
|
||||||
def __init__(self, dim: int, dim_feed_forward: int):
|
|
||||||
super().__init__()
|
|
||||||
self.up = Linear(dim, dim_feed_forward)
|
|
||||||
self.gate = Linear(dim, dim_feed_forward)
|
|
||||||
self.down = Linear(dim_feed_forward, dim)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
gated = self.up(x) * F.silu(self.gate(x))
|
|
||||||
out = self.down(gated)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
|
@AttnFactory.register("gqa")
|
||||||
class GQA(nn.Module):
|
class GQA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -152,7 +80,6 @@ class GQA(nn.Module):
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
is_causal = attn_mask is None
|
is_causal = attn_mask is None
|
||||||
|
|
||||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
|
||||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||||
|
|
@ -167,7 +94,6 @@ class GQA(nn.Module):
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
||||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
|
||||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
sdqa_out = (
|
sdqa_out = (
|
||||||
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||||
|
|
@ -183,6 +109,7 @@ class GQA(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@AttnFactory.register("mla")
|
||||||
class MLA(nn.Module):
|
class MLA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -193,6 +120,7 @@ class MLA(nn.Module):
|
||||||
qk_nope_head_dim: int,
|
qk_nope_head_dim: int,
|
||||||
qk_rope_head_dim: int,
|
qk_rope_head_dim: int,
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
|
use_qk_norm: bool,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
):
|
):
|
||||||
|
|
@ -206,16 +134,20 @@ class MLA(nn.Module):
|
||||||
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
|
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.n_rep = n_heads // n_kv_heads
|
self.n_rep = n_heads // n_kv_heads
|
||||||
|
self.use_qk_norm = use_qk_norm
|
||||||
self.use_gated_attention = use_gated_attention
|
self.use_gated_attention = use_gated_attention
|
||||||
|
|
||||||
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||||
|
|
||||||
# fused KV: (k_nope, k_rope, v)
|
|
||||||
self.kv_b_proj = Linear(
|
self.kv_b_proj = Linear(
|
||||||
kv_lora_rank,
|
kv_lora_rank,
|
||||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
n_kv_heads * (2 * self.head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = Linear(dim, dim, bias=False)
|
self.o_proj = Linear(dim, dim, bias=False)
|
||||||
|
|
@ -248,7 +180,7 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
q_nope, q_rope = (
|
q_nope, q_rope = (
|
||||||
q[..., : self.qk_nope_head_dim],
|
q[..., : self.qk_nope_head_dim],
|
||||||
q[..., self.qk_rope_head_dim :],
|
q[..., self.qk_nope_head_dim :],
|
||||||
)
|
)
|
||||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||||
|
|
@ -256,6 +188,10 @@ class MLA(nn.Module):
|
||||||
q = torch.cat([q_nope, q_rope], dim=-1)
|
q = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
k = torch.cat([k_nope, k_rope], dim=-1)
|
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
if paged_cache is not None:
|
if paged_cache is not None:
|
||||||
paged_cache.write(self.layer_id, k, v)
|
paged_cache.write(self.layer_id, k, v)
|
||||||
k, v = paged_cache.gather(self.layer_id)
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
|
|
@ -274,57 +210,3 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
out = self.o_proj(attn_out)
|
out = self.o_proj(attn_out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
n_heads: int,
|
|
||||||
dim_ffn: int,
|
|
||||||
n_kv_heads: int,
|
|
||||||
norm_eps: int,
|
|
||||||
use_qk_norm: bool,
|
|
||||||
use_gated_attention: bool,
|
|
||||||
layer_id: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.attention = GQA(
|
|
||||||
dim,
|
|
||||||
n_heads,
|
|
||||||
n_kv_heads,
|
|
||||||
use_qk_norm,
|
|
||||||
norm_eps,
|
|
||||||
use_gated_attention,
|
|
||||||
layer_id,
|
|
||||||
)
|
|
||||||
self.input_norm = RMSNorm(dim, norm_eps)
|
|
||||||
self.mlp = MLP(dim, dim_ffn)
|
|
||||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
rotary_emb: Tensor,
|
|
||||||
attention_mask: Optional[Tensor] = None,
|
|
||||||
paged_cache: Optional[KvcacheView] = None,
|
|
||||||
) -> Tensor:
|
|
||||||
attn_output = self.attention(
|
|
||||||
self.input_norm(x),
|
|
||||||
rotary_emb,
|
|
||||||
attention_mask,
|
|
||||||
paged_cache,
|
|
||||||
)
|
|
||||||
x = attn_output + x
|
|
||||||
x = self.mlp(self.post_attention_norm(x)) + x
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Embedding(nn.Module):
|
|
||||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.embedding(x, self.weight)
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.components.attention import AttnFactory
|
||||||
|
from astrai.model.components.mlp import FFNFactory
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
dim_ffn: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
norm_eps: float,
|
||||||
|
use_qk_norm: bool,
|
||||||
|
use_gated_attention: bool,
|
||||||
|
layer_id: int,
|
||||||
|
attn_type: str = "gqa",
|
||||||
|
ffn_type: str = "mlp",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = AttnFactory.create(
|
||||||
|
attn_type,
|
||||||
|
dim=dim,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
use_qk_norm=use_qk_norm,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
use_gated_attention=use_gated_attention,
|
||||||
|
layer_id=layer_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.input_norm = RMSNorm(dim, norm_eps)
|
||||||
|
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||||
|
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
rotary_emb: Tensor,
|
||||||
|
attention_mask: Optional[Tensor] = None,
|
||||||
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
attn_output = self.attention(
|
||||||
|
self.input_norm(x),
|
||||||
|
rotary_emb,
|
||||||
|
attention_mask,
|
||||||
|
paged_cache,
|
||||||
|
)
|
||||||
|
x = attn_output + x
|
||||||
|
x = self.mlp(self.post_attention_norm(x)) + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(nn.Module):
|
||||||
|
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.normal_(self.weight, mean=0.0, std=0.02)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.embedding(x, self.weight)
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
|
||||||
|
if self.bias is not None:
|
||||||
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||||
|
bound = 1 / (fan_in**0.5)
|
||||||
|
nn.init.uniform_(self.bias, -bound, bound)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
|
@ -0,0 +1,194 @@
|
||||||
|
import logging
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Set
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.serialization import (
|
||||||
|
load_json,
|
||||||
|
load_safetensors,
|
||||||
|
save_json,
|
||||||
|
save_safetensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
|
||||||
|
TARGET_MODULES_FFN = {"up", "gate", "down"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoRAConfig:
|
||||||
|
r: int = 16
|
||||||
|
alpha: int = 32
|
||||||
|
target_modules: tuple = ("q_proj", "v_proj")
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
|
||||||
|
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
|
||||||
|
super().__init__()
|
||||||
|
self.register_parameter("weight", base.weight)
|
||||||
|
self.weight.requires_grad_(False)
|
||||||
|
self.bias = base.bias
|
||||||
|
if self.bias is not None:
|
||||||
|
self.bias.requires_grad_(False)
|
||||||
|
|
||||||
|
self.r = r
|
||||||
|
self.scaling = alpha / r
|
||||||
|
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
|
||||||
|
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
|
||||||
|
self._merged = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = F.linear(x, self.weight, self.bias)
|
||||||
|
if not self._merged:
|
||||||
|
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
|
||||||
|
return out
|
||||||
|
|
||||||
|
def merge(self):
|
||||||
|
if self._merged:
|
||||||
|
return
|
||||||
|
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
||||||
|
self._merged = True
|
||||||
|
del self.lora_A
|
||||||
|
del self.lora_B
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_lora_info(model: nn.Module) -> dict:
|
||||||
|
names = {}
|
||||||
|
for n, m in model.named_modules():
|
||||||
|
if isinstance(m, Linear):
|
||||||
|
_, _, child = n.rpartition(".")
|
||||||
|
names.setdefault(child, []).append(n)
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lora_count(model: nn.Module) -> int:
|
||||||
|
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
|
||||||
|
|
||||||
|
|
||||||
|
def inject_lora(
|
||||||
|
model: nn.Module,
|
||||||
|
r: int = 16,
|
||||||
|
alpha: int = 32,
|
||||||
|
target_modules: Optional[Set[str]] = None,
|
||||||
|
) -> LoRAConfig:
|
||||||
|
if target_modules is None:
|
||||||
|
target_modules = TARGET_MODULES_ATTN
|
||||||
|
|
||||||
|
available = _collect_lora_info(model)
|
||||||
|
injected = 0
|
||||||
|
|
||||||
|
for name, module in list(model.named_modules()):
|
||||||
|
if not isinstance(module, Linear):
|
||||||
|
continue
|
||||||
|
parent_name, _, child_name = name.rpartition(".")
|
||||||
|
if child_name not in target_modules:
|
||||||
|
continue
|
||||||
|
parent = model.get_submodule(parent_name) if parent_name else model
|
||||||
|
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
|
||||||
|
injected += 1
|
||||||
|
|
||||||
|
if injected == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No LoRA layers injected. Available Linear child names: %s. "
|
||||||
|
"target_modules: %s. Check model type and target_modules.",
|
||||||
|
sorted(available),
|
||||||
|
sorted(target_modules),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
|
||||||
|
|
||||||
|
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora(model: nn.Module):
|
||||||
|
n = 0
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, LoRALinear):
|
||||||
|
module.merge()
|
||||||
|
n += 1
|
||||||
|
if n == 0:
|
||||||
|
logger.warning("No LoRA layers to merge.")
|
||||||
|
else:
|
||||||
|
logger.info("Merged %d LoRA layers", n)
|
||||||
|
|
||||||
|
|
||||||
|
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
||||||
|
lora_sd = {
|
||||||
|
k: v
|
||||||
|
for k, v in model.state_dict().items()
|
||||||
|
if k.endswith((".lora_A", ".lora_B"))
|
||||||
|
}
|
||||||
|
if not lora_sd:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No LoRA parameters found in model. "
|
||||||
|
"The model may not have been injected or was already merged."
|
||||||
|
)
|
||||||
|
|
||||||
|
path = Path(save_dir)
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_safetensors(lora_sd, path / "adapter_model.safetensors")
|
||||||
|
save_json(asdict(config), path / "adapter_config.json")
|
||||||
|
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||||
|
path = Path(load_dir)
|
||||||
|
raw = load_json(path / "adapter_config.json")
|
||||||
|
config = LoRAConfig(
|
||||||
|
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
||||||
|
)
|
||||||
|
|
||||||
|
existing = _get_lora_count(model)
|
||||||
|
if existing > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Model already has %d LoRA layers. Skipping injection, "
|
||||||
|
"loading weights onto existing layers only.",
|
||||||
|
existing,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inject_lora(
|
||||||
|
model,
|
||||||
|
r=config.r,
|
||||||
|
alpha=config.alpha,
|
||||||
|
target_modules=set(config.target_modules),
|
||||||
|
)
|
||||||
|
|
||||||
|
weights = load_safetensors(path / "adapter_model.safetensors")
|
||||||
|
try:
|
||||||
|
missing, unexpected = model.load_state_dict(weights, strict=False)
|
||||||
|
except RuntimeError as e:
|
||||||
|
msg = str(e)
|
||||||
|
if "size mismatch" in msg:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"LoRA weight shapes do not match the model. "
|
||||||
|
f"The adapter config (r={config.r}) may not match the injected layers. "
|
||||||
|
f"Original error: {msg}"
|
||||||
|
) from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
injected = _get_lora_count(model)
|
||||||
|
if injected == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No LoRA layers found after loading. "
|
||||||
|
"Inject LoRA before calling load_lora, or check the adapter config."
|
||||||
|
)
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
lora_missing = [k for k in missing if "lora" in k]
|
||||||
|
if lora_missing:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"LoRA weight keys not found in model: {lora_missing}. "
|
||||||
|
f"The adapter config (r={config.r}) may not match the model."
|
||||||
|
)
|
||||||
|
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
|
||||||
|
if unexpected:
|
||||||
|
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
|
||||||
|
|
||||||
|
logger.info("LoRA adapter loaded from %s", load_dir)
|
||||||
|
return config
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
|
||||||
|
|
||||||
|
class FFNFactory(BaseFactory[nn.Module]):
|
||||||
|
@classmethod
|
||||||
|
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
|
||||||
|
return super().create(ffn_type, dim, dim_ffn, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@FFNFactory.register("mlp")
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, dim: int, dim_ffn: int):
|
||||||
|
super().__init__()
|
||||||
|
self.up = Linear(dim, dim_ffn)
|
||||||
|
self.gate = Linear(dim, dim_ffn)
|
||||||
|
self.down = Linear(dim_ffn, dim)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
gated = self.up(x) * F.silu(self.gate(x))
|
||||||
|
out = self.down(gated)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@FFNFactory.register("moe")
|
||||||
|
class DeepSeekMoE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
dim_ffn: int,
|
||||||
|
n_routed_experts: int,
|
||||||
|
n_shared_experts: int = 1,
|
||||||
|
n_activated_experts: int = 2,
|
||||||
|
topk_method: str = "greedy",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.n_routed_experts = n_routed_experts
|
||||||
|
self.n_shared_experts = n_shared_experts
|
||||||
|
self.n_activated_experts = n_activated_experts
|
||||||
|
self.topk_method = topk_method
|
||||||
|
|
||||||
|
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||||
|
|
||||||
|
self.shared_experts = nn.ModuleList(
|
||||||
|
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
|
||||||
|
)
|
||||||
|
self.routed_experts = nn.ModuleList(
|
||||||
|
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
bsz, seq_len, dim = x.shape
|
||||||
|
x_flat = x.view(-1, dim)
|
||||||
|
|
||||||
|
shared_out = self._shared_forward(x_flat)
|
||||||
|
routed_out = self._routed_forward(x_flat)
|
||||||
|
|
||||||
|
out = (shared_out + routed_out).view(bsz, seq_len, dim)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _shared_forward(self, x: Tensor) -> Tensor:
|
||||||
|
if self.n_shared_experts == 0:
|
||||||
|
return torch.zeros_like(x)
|
||||||
|
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
|
||||||
|
|
||||||
|
def _routed_forward(self, x: Tensor) -> Tensor:
|
||||||
|
N, D = x.shape
|
||||||
|
K = self.n_activated_experts
|
||||||
|
|
||||||
|
router_logits = self.router(x)
|
||||||
|
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
|
||||||
|
|
||||||
|
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
|
||||||
|
for expert_idx in range(self.n_routed_experts):
|
||||||
|
expert_mask = topk_indices == expert_idx
|
||||||
|
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
|
||||||
|
if token_idx.numel() == 0:
|
||||||
|
continue
|
||||||
|
expert_input = x[token_idx]
|
||||||
|
expert_output = self.routed_experts[expert_idx](expert_input)
|
||||||
|
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
|
||||||
|
output.index_add_(0, token_idx, expert_output * weights)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim, norm_eps):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
self.normalized_shape = (dim,)
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotary_emb(
|
||||||
|
dim: int,
|
||||||
|
max_len: int,
|
||||||
|
base: float = 10000,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||||
|
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||||
|
freqs = torch.outer(t, theta).float()
|
||||||
|
cos = torch.cos(freqs)
|
||||||
|
sin = torch.sin(freqs)
|
||||||
|
return torch.complex(cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
def ntk_base(base: float, dim: int, factor: float) -> float:
|
||||||
|
return base * (factor ** (dim / (dim - 2)))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
dtype = x.dtype
|
||||||
|
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||||
|
x_complex = torch.view_as_complex(x_)
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
|
x_rotated = x_complex * freqs_cis
|
||||||
|
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||||
|
return x_out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
max_len: int,
|
||||||
|
base: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.max_len = max_len
|
||||||
|
self.base = base
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
|
||||||
|
if rope_scaling is not None:
|
||||||
|
scaling_type = rope_scaling.get("type", "ntk")
|
||||||
|
factor = rope_scaling.get("factor", 1.0)
|
||||||
|
if scaling_type == "ntk":
|
||||||
|
self.base = ntk_base(base, dim, factor)
|
||||||
|
|
||||||
|
self._set_rotary_buffer(self.max_len)
|
||||||
|
|
||||||
|
def _set_rotary_buffer(self, max_len: int):
|
||||||
|
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||||
|
freqs_cis = torch.view_as_real(rotary_emb)
|
||||||
|
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = (
|
||||||
|
torch.arange(x.size(1), device=x.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(x.size(0), -1)
|
||||||
|
)
|
||||||
|
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||||
|
return torch.view_as_complex(position_freq_cis)
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.config.model_config import EncoderConfig
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.embedding import Embedding
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import RotaryEmbedding
|
||||||
|
from astrai.model.transformer import process_attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
@AutoModel.register("embedding")
|
||||||
|
class EmbeddingEncoder(AutoModel):
|
||||||
|
def __init__(self, config: EncoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
rope_dim = config.dim // config.n_heads
|
||||||
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
|
self.rotary_embedding = RotaryEmbedding(
|
||||||
|
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||||
|
)
|
||||||
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DecoderBlock(
|
||||||
|
config.dim,
|
||||||
|
config.n_heads,
|
||||||
|
config.dim_ffn,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.norm_eps,
|
||||||
|
config.use_qk_norm,
|
||||||
|
config.use_gated_attention,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
|
|
||||||
|
self.pooling_type = config.pooling_type or "mean"
|
||||||
|
self.normalize_embeddings = config.normalize_embeddings or False
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if hasattr(module, "reset_parameters"):
|
||||||
|
module.reset_parameters()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
|
state_dict = dict(state_dict)
|
||||||
|
state_dict.pop("lm_head.weight", None)
|
||||||
|
return super().load_state_dict(state_dict, strict=strict, assign=assign)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Tensor,
|
||||||
|
input_mask: Optional[Tensor] = None,
|
||||||
|
position_ids: Optional[Tensor] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
assert input_ids.ndim == 2
|
||||||
|
B, S = input_ids.shape
|
||||||
|
|
||||||
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||||
|
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
|
||||||
|
|
||||||
|
hidden_states = self.norm(x)
|
||||||
|
|
||||||
|
if self.pooling_type == "cls":
|
||||||
|
pooled = hidden_states[:, 0]
|
||||||
|
elif self.pooling_type == "last":
|
||||||
|
if input_mask is not None:
|
||||||
|
lengths = input_mask.sum(dim=1) - 1
|
||||||
|
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
|
||||||
|
else:
|
||||||
|
pooled = hidden_states[:, -1]
|
||||||
|
else:
|
||||||
|
if input_mask is not None:
|
||||||
|
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
||||||
|
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
|
||||||
|
min=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pooled = hidden_states.mean(dim=1)
|
||||||
|
|
||||||
|
if self.normalize_embeddings:
|
||||||
|
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
|
||||||
|
|
||||||
|
return pooled
|
||||||
|
|
@ -1,19 +1,17 @@
|
||||||
from typing import Any, Mapping, Optional
|
from typing import Any, Dict, Mapping, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
from astrai.inference.core.cache import KvcacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
DecoderBlock,
|
from astrai.model.components.embedding import Embedding
|
||||||
Embedding,
|
from astrai.model.components.linear import Linear
|
||||||
Linear,
|
from astrai.model.components.norm import RMSNorm
|
||||||
RMSNorm,
|
from astrai.model.components.rope import RotaryEmbedding
|
||||||
RotaryEmbedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
|
def process_attention_mask(
|
||||||
|
|
@ -28,35 +26,38 @@ def process_attention_mask(
|
||||||
return input_mask
|
return input_mask
|
||||||
|
|
||||||
device = input_tensor.device
|
device = input_tensor.device
|
||||||
dtype = input_tensor.dtype
|
B = input_tensor.size(0)
|
||||||
B, S = input_tensor.size()[:2]
|
|
||||||
T = position_ids.max().item() + 1
|
T = position_ids.max().item() + 1
|
||||||
|
|
||||||
if input_mask is None:
|
if input_mask is None:
|
||||||
if position_ids.min().item() == 0 and is_causal:
|
if position_ids.min().item() == 0 and is_causal:
|
||||||
return None
|
return None
|
||||||
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
|
||||||
else:
|
else:
|
||||||
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
|
||||||
|
|
||||||
attend = pad.view(B, 1, T).expand(B, S, T).clone()
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||||
|
attend = attend & causal
|
||||||
|
|
||||||
return torch.full(
|
return attend.unsqueeze(1)
|
||||||
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
|
||||||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
@AutoModel.register("transformer")
|
@AutoModel.register("autoregressive_lm")
|
||||||
class Transformer(AutoModel):
|
class AutoRegressiveLM(AutoModel):
|
||||||
"""Transformer language model with paged KV cache."""
|
"""Autoregressive language model with paged KV cache."""
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: AutoRegressiveLMConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
rope_dim = (
|
||||||
|
config.qk_rope_head_dim
|
||||||
|
if config.attn_type == "mla"
|
||||||
|
else config.dim // config.n_heads
|
||||||
|
)
|
||||||
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
self.rotary_embedding = RotaryEmbedding(
|
self.rotary_embedding = RotaryEmbedding(
|
||||||
config.dim // config.n_heads, config.max_len
|
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||||
)
|
)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
|
|
@ -71,6 +72,15 @@ class Transformer(AutoModel):
|
||||||
config.use_qk_norm,
|
config.use_qk_norm,
|
||||||
config.use_gated_attention,
|
config.use_gated_attention,
|
||||||
layer_id,
|
layer_id,
|
||||||
|
attn_type=config.attn_type,
|
||||||
|
ffn_type=config.ffn_type,
|
||||||
|
n_routed_experts=config.n_routed_experts,
|
||||||
|
n_shared_experts=config.n_shared_experts,
|
||||||
|
n_activated_experts=config.n_activated_experts,
|
||||||
|
topk_method=config.topk_method,
|
||||||
|
kv_lora_rank=config.kv_lora_rank,
|
||||||
|
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||||
|
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||||
)
|
)
|
||||||
for layer_id in range(config.n_layers)
|
for layer_id in range(config.n_layers)
|
||||||
]
|
]
|
||||||
|
|
@ -79,15 +89,14 @@ class Transformer(AutoModel):
|
||||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight is True:
|
||||||
self.lm_head.weight = self.embed_tokens.weight
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
||||||
self._init_weights()
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
def _init_weights(self):
|
def _init_weights(self, module):
|
||||||
for param in self.parameters():
|
if hasattr(module, "reset_parameters"):
|
||||||
if param.dim() > 1:
|
module.reset_parameters()
|
||||||
nn.init.normal_(param, mean=0.0, std=0.006)
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
lm_head_key = "lm_head.weight"
|
lm_head_key = "lm_head.weight"
|
||||||
|
|
@ -95,7 +104,7 @@ class Transformer(AutoModel):
|
||||||
|
|
||||||
state_dict = dict(state_dict)
|
state_dict = dict(state_dict)
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight is True:
|
||||||
# same tensor for embed and lm_head
|
# same tensor for embed and lm_head
|
||||||
if embed_key in state_dict:
|
if embed_key in state_dict:
|
||||||
state_dict[lm_head_key] = state_dict[embed_key]
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
|
|
@ -111,7 +120,7 @@ class Transformer(AutoModel):
|
||||||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight is True:
|
||||||
lm_head_key = prefix + "lm_head.weight"
|
lm_head_key = prefix + "lm_head.weight"
|
||||||
if lm_head_key in state_dict:
|
if lm_head_key in state_dict:
|
||||||
del state_dict[lm_head_key]
|
del state_dict[lm_head_key]
|
||||||
|
|
@ -124,7 +133,7 @@ class Transformer(AutoModel):
|
||||||
input_mask: Optional[Tensor] = None,
|
input_mask: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[KvcacheView] = None,
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
position_ids: Optional[Tensor] = None,
|
position_ids: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Dict[str, Tensor]:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,13 @@
|
||||||
|
from astrai.parallel.executor import (
|
||||||
|
AccumOptimizer,
|
||||||
|
AccumScheduler,
|
||||||
|
BaseExecutor,
|
||||||
|
DDPExecutor,
|
||||||
|
ExecutorFactory,
|
||||||
|
FSDPExecutor,
|
||||||
|
GradientState,
|
||||||
|
NoneExecutor,
|
||||||
|
)
|
||||||
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
||||||
from astrai.parallel.setup import (
|
from astrai.parallel.setup import (
|
||||||
get_current_device,
|
get_current_device,
|
||||||
|
|
@ -17,4 +27,12 @@ __all__ = [
|
||||||
"spawn_parallel_fn",
|
"spawn_parallel_fn",
|
||||||
"RowParallelLinear",
|
"RowParallelLinear",
|
||||||
"ColumnParallelLinear",
|
"ColumnParallelLinear",
|
||||||
|
"ExecutorFactory",
|
||||||
|
"BaseExecutor",
|
||||||
|
"GradientState",
|
||||||
|
"AccumOptimizer",
|
||||||
|
"AccumScheduler",
|
||||||
|
"NoneExecutor",
|
||||||
|
"DDPExecutor",
|
||||||
|
"FSDPExecutor",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,272 @@
|
||||||
|
"""Unified training executor — parallel strategy + gradient accumulation."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.parallel.setup import get_rank, get_world_size
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GradientState:
|
||||||
|
def __init__(self, grad_accum_steps: int = 1):
|
||||||
|
self.num_steps = max(grad_accum_steps, 1)
|
||||||
|
self._step: int = 0
|
||||||
|
self._sync_gradients: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sync_gradients(self) -> bool:
|
||||||
|
return self._sync_gradients
|
||||||
|
|
||||||
|
def _do_sync(self):
|
||||||
|
self._step += 1
|
||||||
|
self._sync_gradients = self._step % self.num_steps == 0
|
||||||
|
|
||||||
|
|
||||||
|
class AccumOptimizer:
|
||||||
|
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.gradient_state = gradient_state
|
||||||
|
|
||||||
|
def step(self, closure=None):
|
||||||
|
if self.gradient_state.sync_gradients:
|
||||||
|
self.optimizer.step(closure)
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
if self.gradient_state.sync_gradients:
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_groups(self):
|
||||||
|
return self.optimizer.param_groups
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return self.optimizer.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, d):
|
||||||
|
self.optimizer.load_state_dict(d)
|
||||||
|
|
||||||
|
|
||||||
|
class AccumScheduler:
|
||||||
|
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.gradient_state = gradient_state
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
if self.gradient_state.sync_gradients:
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return self.scheduler.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, d):
|
||||||
|
self.scheduler.load_state_dict(d)
|
||||||
|
|
||||||
|
def get_last_lr(self):
|
||||||
|
return self.scheduler.get_last_lr()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseExecutor:
|
||||||
|
def __init__(self, grad_accum_steps: int = 1):
|
||||||
|
self.gradient_state = GradientState(grad_accum_steps)
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[Optimizer] = None,
|
||||||
|
dataloader: Optional[DataLoader] = None,
|
||||||
|
scheduler: Optional[LRScheduler] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
|
||||||
|
]:
|
||||||
|
model = self._prepare_model(model)
|
||||||
|
if optimizer is not None:
|
||||||
|
optimizer = AccumOptimizer(optimizer, self.gradient_state)
|
||||||
|
if scheduler is not None:
|
||||||
|
scheduler = AccumScheduler(scheduler, self.gradient_state)
|
||||||
|
return model, optimizer, dataloader, scheduler
|
||||||
|
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def accumulate(self, model: nn.Module):
|
||||||
|
self.gradient_state._do_sync()
|
||||||
|
if not self.gradient_state.sync_gradients:
|
||||||
|
with self._no_sync(model):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
|
def backward(self, loss: torch.Tensor):
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module):
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_distributed(self) -> bool:
|
||||||
|
return get_world_size() > 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sync_gradients(self) -> bool:
|
||||||
|
return self.gradient_state.sync_gradients
|
||||||
|
|
||||||
|
@property
|
||||||
|
def grad_accum_steps(self) -> int:
|
||||||
|
return self.gradient_state.num_steps
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutorFactory(BaseFactory[BaseExecutor]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ExecutorFactory.register("none")
|
||||||
|
class NoneExecutor(BaseExecutor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ExecutorFactory.register("ddp")
|
||||||
|
class DDPExecutor(BaseExecutor):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grad_accum_steps: int = 1,
|
||||||
|
dim: int = 0,
|
||||||
|
broadcast_buffers: bool = True,
|
||||||
|
init_sync: bool = True,
|
||||||
|
process_group=None,
|
||||||
|
bucket_cap_mb: int = 25,
|
||||||
|
find_unused_parameters: bool = False,
|
||||||
|
check_reduction: bool = False,
|
||||||
|
gradient_as_bucket_view: bool = False,
|
||||||
|
static_graph: bool = False,
|
||||||
|
delay_all_reduce_named_params=None,
|
||||||
|
param_to_hook_all_reduce=None,
|
||||||
|
mixed_precision=None,
|
||||||
|
device_mesh=None,
|
||||||
|
):
|
||||||
|
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||||
|
self._ddp_kwargs = dict(
|
||||||
|
dim=dim,
|
||||||
|
broadcast_buffers=broadcast_buffers,
|
||||||
|
init_sync=init_sync,
|
||||||
|
process_group=process_group,
|
||||||
|
bucket_cap_mb=bucket_cap_mb,
|
||||||
|
find_unused_parameters=find_unused_parameters,
|
||||||
|
check_reduction=check_reduction,
|
||||||
|
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||||
|
static_graph=static_graph,
|
||||||
|
delay_all_reduce_named_params=delay_all_reduce_named_params,
|
||||||
|
param_to_hook_all_reduce=param_to_hook_all_reduce,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
if not self.use_distributed:
|
||||||
|
logger.warning("DDP backend selected but world_size=1, model not wrapped")
|
||||||
|
return model
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
|
||||||
|
model = DDP(
|
||||||
|
model,
|
||||||
|
device_ids=[local_rank],
|
||||||
|
output_device=local_rank,
|
||||||
|
**self._ddp_kwargs,
|
||||||
|
)
|
||||||
|
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
return model.no_sync()
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module):
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
return model.module.state_dict()
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@ExecutorFactory.register("fsdp")
|
||||||
|
class FSDPExecutor(BaseExecutor):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grad_accum_steps: int = 1,
|
||||||
|
process_group=None,
|
||||||
|
sharding_strategy=None,
|
||||||
|
cpu_offload=None,
|
||||||
|
auto_wrap_policy=None,
|
||||||
|
backward_prefetch=None,
|
||||||
|
mixed_precision=None,
|
||||||
|
ignored_modules=None,
|
||||||
|
param_init_fn=None,
|
||||||
|
sync_module_states: bool = False,
|
||||||
|
forward_prefetch: bool = False,
|
||||||
|
limit_all_gathers: bool = True,
|
||||||
|
ignored_states=None,
|
||||||
|
device_mesh=None,
|
||||||
|
):
|
||||||
|
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||||
|
self._fsdp_kwargs = {
|
||||||
|
k: v
|
||||||
|
for k, v in dict(
|
||||||
|
process_group=process_group,
|
||||||
|
sharding_strategy=sharding_strategy,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
|
auto_wrap_policy=auto_wrap_policy,
|
||||||
|
backward_prefetch=backward_prefetch,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
ignored_modules=ignored_modules,
|
||||||
|
param_init_fn=param_init_fn,
|
||||||
|
sync_module_states=sync_module_states,
|
||||||
|
forward_prefetch=forward_prefetch,
|
||||||
|
limit_all_gathers=limit_all_gathers,
|
||||||
|
use_orig_params=True,
|
||||||
|
ignored_states=ignored_states,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
).items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
self._original_model: Optional[nn.Module] = None
|
||||||
|
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
if not self.use_distributed:
|
||||||
|
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
|
||||||
|
return model
|
||||||
|
self._original_model = model
|
||||||
|
device_id = torch.device("cuda", get_rank())
|
||||||
|
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
|
||||||
|
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
if isinstance(model, FSDP):
|
||||||
|
return model.no_sync()
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module):
|
||||||
|
if isinstance(model, FSDP) and self.use_distributed:
|
||||||
|
with FSDP.state_dict_type(
|
||||||
|
model,
|
||||||
|
StateDictType.FULL_STATE_DICT,
|
||||||
|
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
||||||
|
):
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
return model.state_dict()
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
@ -30,6 +31,7 @@ def get_rank() -> int:
|
||||||
def setup_parallel(
|
def setup_parallel(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
local_rank: int,
|
||||||
backend: str = "nccl",
|
backend: str = "nccl",
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
|
|
@ -41,14 +43,18 @@ def setup_parallel(
|
||||||
return
|
return
|
||||||
|
|
||||||
if world_size <= 1:
|
if world_size <= 1:
|
||||||
|
device_id = torch.device(device_type, local_rank)
|
||||||
|
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||||
|
os.environ["WORLD_SIZE"] = "1"
|
||||||
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
device_id = torch.device(device_type, rank)
|
device_id = torch.device(device_type, local_rank)
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
os.environ["MASTER_PORT"] = master_port
|
os.environ["MASTER_PORT"] = master_port
|
||||||
os.environ["LOCAL_RANK"] = str(rank)
|
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||||
os.environ["WORLD_SIZE"] = str(world_size)
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
|
|
@ -90,7 +96,7 @@ def only_on_rank(rank, sync=False):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def wrapper_spawn_func(
|
def _run_single_rank(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
backend: str,
|
backend: str,
|
||||||
|
|
@ -100,10 +106,10 @@ def wrapper_spawn_func(
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
):
|
):
|
||||||
try:
|
|
||||||
with setup_parallel(
|
with setup_parallel(
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
|
local_rank=rank,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
|
|
@ -111,11 +117,99 @@ def wrapper_spawn_func(
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error in rank {rank}: {e}")
|
class LaunchStrategy(ABC):
|
||||||
|
"""Strategy for launching a function in a distributed context."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
world_size: int,
|
||||||
|
backend: str,
|
||||||
|
master_addr: str,
|
||||||
|
master_port: str,
|
||||||
|
device_type: str,
|
||||||
|
start_method: str,
|
||||||
|
):
|
||||||
|
self.world_size = world_size
|
||||||
|
self.backend = backend
|
||||||
|
self.master_addr = master_addr
|
||||||
|
self.master_port = master_port
|
||||||
|
self.device_type = device_type
|
||||||
|
self.start_method = start_method
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def launch(self, func: Callable, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TorchrunStrategy(LaunchStrategy):
|
||||||
|
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
|
||||||
|
|
||||||
|
def launch(self, func: Callable, **kwargs):
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", rank))
|
||||||
|
with setup_parallel(
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
local_rank=local_rank,
|
||||||
|
backend=self.backend,
|
||||||
|
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
|
||||||
|
master_port=os.environ.get("MASTER_PORT", self.master_port),
|
||||||
|
device_type=self.device_type,
|
||||||
|
):
|
||||||
|
func(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalStrategy(LaunchStrategy):
|
||||||
|
"""Local launcher — single-process or mp.start_processes."""
|
||||||
|
|
||||||
|
def launch(self, func: Callable, **kwargs):
|
||||||
|
args = (
|
||||||
|
self.world_size,
|
||||||
|
self.backend,
|
||||||
|
self.master_addr,
|
||||||
|
self.master_port,
|
||||||
|
self.device_type,
|
||||||
|
func,
|
||||||
|
kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.world_size == 1:
|
||||||
|
_run_single_rank(0, *args)
|
||||||
|
return
|
||||||
|
|
||||||
|
ctx = mp.start_processes(
|
||||||
|
_run_single_rank,
|
||||||
|
args=args,
|
||||||
|
nprocs=self.world_size,
|
||||||
|
start_method=self.start_method,
|
||||||
|
join=False,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
while not ctx.join():
|
||||||
|
pass
|
||||||
|
except BaseException:
|
||||||
|
for p in ctx.processes:
|
||||||
|
p.terminate()
|
||||||
|
ctx.join()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_launcher() -> str:
|
||||||
|
"""Detect the distributed launcher from environment.
|
||||||
|
|
||||||
|
Returns one of: "torchelastic", "torchrun", "external", "local".
|
||||||
|
"""
|
||||||
|
if dist.is_torchelastic_launched():
|
||||||
|
return "torchelastic"
|
||||||
|
if "LOCAL_WORLD_SIZE" in os.environ:
|
||||||
|
return "torchrun"
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
return "external"
|
||||||
|
return "local"
|
||||||
|
|
||||||
|
|
||||||
def spawn_parallel_fn(
|
def spawn_parallel_fn(
|
||||||
func: Callable,
|
func: Callable,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
|
@ -123,39 +217,16 @@ def spawn_parallel_fn(
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
|
start_method: str = "spawn",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# clear environment variables
|
launcher = _detect_launcher()
|
||||||
for key in [
|
if launcher in ("torchelastic", "torchrun", "external"):
|
||||||
"MASTER_ADDR",
|
strategy = TorchrunStrategy(
|
||||||
"MASTER_PORT",
|
world_size, backend, master_addr, master_port, device_type, start_method
|
||||||
"RANK",
|
|
||||||
"WORLD_SIZE",
|
|
||||||
"LOCAL_RANK",
|
|
||||||
"LOCAL_DEVICE",
|
|
||||||
]:
|
|
||||||
if key in os.environ:
|
|
||||||
del os.environ[key]
|
|
||||||
|
|
||||||
if world_size == 1:
|
|
||||||
device_id = torch.device(device_type, 0)
|
|
||||||
os.environ["LOCAL_RANK"] = "0"
|
|
||||||
os.environ["WORLD_SIZE"] = "1"
|
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
|
||||||
|
|
||||||
func(**kwargs)
|
|
||||||
return
|
|
||||||
|
|
||||||
wrapper_spawn_func_args = (
|
|
||||||
world_size,
|
|
||||||
backend,
|
|
||||||
master_addr,
|
|
||||||
master_port,
|
|
||||||
device_type,
|
|
||||||
func,
|
|
||||||
kwargs,
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
mp.spawn(
|
strategy = LocalStrategy(
|
||||||
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
|
world_size, backend, master_addr, master_port, device_type, start_method
|
||||||
)
|
)
|
||||||
|
strategy.launch(func, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
from astrai.preprocessing.builder import (
|
||||||
|
BaseMaskBuilder,
|
||||||
|
MaskBuilderFactory,
|
||||||
|
SectionedMaskBuilder,
|
||||||
|
)
|
||||||
|
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseMaskBuilder",
|
||||||
|
"MaskBuilderFactory",
|
||||||
|
"SectionedMaskBuilder",
|
||||||
|
"Pipeline",
|
||||||
|
"filter_by_length",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,338 @@
|
||||||
|
"""Mask building strategies for preprocessing pipeline.
|
||||||
|
|
||||||
|
The single :class:`SectionedMaskBuilder` handles all input formats
|
||||||
|
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
|
||||||
|
for single-output or ``input.sources`` for multi-output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMaskBuilder(ABC):
|
||||||
|
"""Convert a JSONL item into token ids and optional loss_mask."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
|
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
||||||
|
|
||||||
|
Returns ``None`` to skip the item entirely.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, component_cls: type):
|
||||||
|
if not issubclass(component_cls, BaseMaskBuilder):
|
||||||
|
raise TypeError(
|
||||||
|
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||||||
|
if not domain_key:
|
||||||
|
return "__default__"
|
||||||
|
val = item.get(domain_key, "__default__")
|
||||||
|
return val if isinstance(val, str) else "__default__"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_action(action: str, role: str, config) -> str:
|
||||||
|
"""Resolve action to "train" or "mask".
|
||||||
|
|
||||||
|
- ``"train"`` / ``"mask"`` → literal
|
||||||
|
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
||||||
|
"""
|
||||||
|
if action == "$role":
|
||||||
|
return config.mask.get(role, config.mask_default)
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
@MaskBuilderFactory.register("sectioned")
|
||||||
|
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
|
"""Config-driven builder supporting single and multi-output modes.
|
||||||
|
|
||||||
|
Single-output (backward-compatible)::
|
||||||
|
|
||||||
|
{"input": {"sections": [
|
||||||
|
{"field": "messages", "action": "$role", "template": true}
|
||||||
|
]}}
|
||||||
|
→ {"sequence": [...], "loss_mask": [...], "domain": "..."}
|
||||||
|
|
||||||
|
Multi-output (DPO / GRPO)::
|
||||||
|
|
||||||
|
{"input": {"sources": {
|
||||||
|
"chosen": {"sections": [
|
||||||
|
{"field": "chosen", "action": "$role", "template": true}
|
||||||
|
]},
|
||||||
|
"rejected": {"sections": [
|
||||||
|
{"field": "rejected", "action": "$role", "template": true}
|
||||||
|
]}
|
||||||
|
}}}
|
||||||
|
→ {"chosen": [...], "chosen_mask": [...],
|
||||||
|
"rejected": [...], "rejected_mask": [...], "domain": "..."}
|
||||||
|
|
||||||
|
Output spec fields::
|
||||||
|
|
||||||
|
sections – list of section specs (same format as single-output)
|
||||||
|
list_field – True when the JSONL field holds a list of values to
|
||||||
|
tokenise individually and concatenate (GRPO responses)
|
||||||
|
mask_key – explicit output key for the loss mask
|
||||||
|
(default: ``"{output_key}_mask"``)
|
||||||
|
dtype – explicit tensor dtype for this output key
|
||||||
|
(default: "int32")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
|
sources_spec = getattr(config.input, "sources", None)
|
||||||
|
if sources_spec:
|
||||||
|
return self._build_multi(item, sources_spec, config, tokenizer)
|
||||||
|
return self._build_single(item, config, tokenizer)
|
||||||
|
|
||||||
|
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
|
sections = config.input.sections
|
||||||
|
if not sections:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ids, mask = self._process_sections(
|
||||||
|
item, sections, config, tokenizer, is_top_level=True
|
||||||
|
)
|
||||||
|
if ids is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result: dict = {
|
||||||
|
"sequence": ids,
|
||||||
|
"domain": _extract_domain(item, config.output.domain_key),
|
||||||
|
}
|
||||||
|
if not all(m == 1 for m in mask):
|
||||||
|
result["loss_mask"] = mask
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _build_multi(
|
||||||
|
self, item: dict, sources_spec: dict, config, tokenizer
|
||||||
|
) -> Optional[dict]:
|
||||||
|
result: dict = {}
|
||||||
|
any_output = False
|
||||||
|
|
||||||
|
for output_key, spec in sources_spec.items():
|
||||||
|
sections = spec.get("sections", [])
|
||||||
|
if not sections:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self._is_value_section(sections):
|
||||||
|
ids = self._extract_raw_value(item, sections)
|
||||||
|
if ids is None:
|
||||||
|
continue
|
||||||
|
result[output_key] = ids
|
||||||
|
any_output = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
list_field = spec.get("list_field", False)
|
||||||
|
mask_key = spec.get("mask_key", f"{output_key}_mask")
|
||||||
|
|
||||||
|
if list_field:
|
||||||
|
ids, mask = self._process_list_field(item, sections, config, tokenizer)
|
||||||
|
else:
|
||||||
|
ids, mask = self._process_sections(
|
||||||
|
item, sections, config, tokenizer, is_top_level=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if ids is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result[output_key] = ids
|
||||||
|
if not all(m == 1 for m in mask):
|
||||||
|
result[mask_key] = mask
|
||||||
|
elif "mask_key" in spec:
|
||||||
|
result[mask_key] = mask
|
||||||
|
|
||||||
|
any_output = True
|
||||||
|
|
||||||
|
if not any_output:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result["domain"] = _extract_domain(item, config.output.domain_key)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_value_section(sections: list) -> bool:
|
||||||
|
return len(sections) == 1 and sections[0].get("action") == "value"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_raw_value(item: dict, sections: list):
|
||||||
|
"""Extract a raw value from a JSONL field without tokenisation.
|
||||||
|
|
||||||
|
Used for GRPO rewards where the field contains float values.
|
||||||
|
"""
|
||||||
|
sec = sections[0]
|
||||||
|
field = sec["field"]
|
||||||
|
raw = item.get(field)
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
if isinstance(raw, list):
|
||||||
|
return [float(v) for v in raw]
|
||||||
|
return [float(raw)]
|
||||||
|
|
||||||
|
def _process_sections(
|
||||||
|
self,
|
||||||
|
item: dict,
|
||||||
|
sections: list,
|
||||||
|
config,
|
||||||
|
tokenizer,
|
||||||
|
*,
|
||||||
|
is_top_level: bool = False,
|
||||||
|
):
|
||||||
|
"""Process a list of sections into ``(ids, loss_mask)``.
|
||||||
|
|
||||||
|
Returns ``(None, None)`` if the item should be skipped.
|
||||||
|
"""
|
||||||
|
all_ids: list[int] = []
|
||||||
|
loss_mask: list[int] = []
|
||||||
|
|
||||||
|
has_template = any(s.get("template") for s in sections)
|
||||||
|
is_text_config = not has_template and all(
|
||||||
|
s["action"] == "train" for s in sections
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_top_level and has_template and tokenizer.bos_token_id is not None:
|
||||||
|
all_ids.append(tokenizer.bos_token_id)
|
||||||
|
loss_mask.append(0)
|
||||||
|
|
||||||
|
first_section = True
|
||||||
|
for sec in sections:
|
||||||
|
field = sec["field"]
|
||||||
|
action = sec["action"]
|
||||||
|
use_template = sec.get("template", False)
|
||||||
|
add_special = sec.get(
|
||||||
|
"add_special_tokens", not use_template and first_section
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_template:
|
||||||
|
success = self._append_template_section(
|
||||||
|
item, field, action, tokenizer, config, all_ids, loss_mask
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
success = self._append_text_section(
|
||||||
|
item,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
add_special,
|
||||||
|
is_text_config,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
continue
|
||||||
|
|
||||||
|
first_section = False
|
||||||
|
|
||||||
|
max_len = config.preprocessing.max_seq_len
|
||||||
|
all_ids = all_ids[:max_len]
|
||||||
|
loss_mask = loss_mask[: len(all_ids)]
|
||||||
|
|
||||||
|
if not all_ids:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if is_top_level and has_template and len(all_ids) <= 1:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
return all_ids, loss_mask
|
||||||
|
|
||||||
|
def _append_template_section(
|
||||||
|
self, item, field, action, tokenizer, config, all_ids, loss_mask
|
||||||
|
):
|
||||||
|
messages = item.get(field)
|
||||||
|
if not isinstance(messages, list) or not messages:
|
||||||
|
return False
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "")
|
||||||
|
act = _resolve_action(action, role, config)
|
||||||
|
rendered = tokenizer.apply_chat_template(
|
||||||
|
[msg], tokenize=False, add_generation_prompt=False
|
||||||
|
)
|
||||||
|
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||||||
|
all_ids.extend(ids)
|
||||||
|
val = 1 if act == "train" else 0
|
||||||
|
loss_mask.extend([val] * len(ids))
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _append_text_section(
|
||||||
|
self,
|
||||||
|
item,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
add_special,
|
||||||
|
is_text_config,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
):
|
||||||
|
text = str(item.get(field, ""))
|
||||||
|
if not text.strip():
|
||||||
|
return False
|
||||||
|
if is_text_config:
|
||||||
|
pp = config.preprocessing
|
||||||
|
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
||||||
|
return False
|
||||||
|
if len(text) > pp.max_chars:
|
||||||
|
return False
|
||||||
|
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
||||||
|
all_ids.extend(ids)
|
||||||
|
val = 1 if action == "train" else 0
|
||||||
|
loss_mask.extend([val] * len(ids))
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
|
||||||
|
all_ids: list[int] = []
|
||||||
|
loss_mask: list[int] = []
|
||||||
|
|
||||||
|
for sec in sections:
|
||||||
|
field = sec["field"]
|
||||||
|
action = sec["action"]
|
||||||
|
use_template = sec.get("template", False)
|
||||||
|
|
||||||
|
values = item.get(field)
|
||||||
|
if not isinstance(values, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for val in values:
|
||||||
|
if use_template:
|
||||||
|
if isinstance(val, list):
|
||||||
|
wrapper = {field: val}
|
||||||
|
self._append_template_section(
|
||||||
|
wrapper,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
wrapper = {field: str(val)}
|
||||||
|
self._append_text_section(
|
||||||
|
wrapper,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_len = config.preprocessing.max_seq_len
|
||||||
|
all_ids = all_ids[:max_len]
|
||||||
|
loss_mask = loss_mask[: len(all_ids)]
|
||||||
|
|
||||||
|
if not all_ids:
|
||||||
|
return None, None
|
||||||
|
return all_ids, loss_mask
|
||||||
|
|
@ -0,0 +1,257 @@
|
||||||
|
"""Config-driven JSONL preprocessing pipeline.
|
||||||
|
|
||||||
|
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||||
|
sharding and flush to ``.h5`` / ``.bin`` storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from itertools import chain
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
|
from astrai.dataset.storage import save_bin, save_h5
|
||||||
|
from astrai.preprocessing.builder import SectionedMaskBuilder
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||||
|
"bool": torch.bool,
|
||||||
|
"uint8": torch.uint8,
|
||||||
|
"int8": torch.int8,
|
||||||
|
"int16": torch.int16,
|
||||||
|
"int32": torch.int32,
|
||||||
|
"int64": torch.int64,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float64": torch.float64,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
||||||
|
return min_len <= len(text) <= max_len
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
||||||
|
if len(seq) <= max_len:
|
||||||
|
return seq
|
||||||
|
if mode == "keep_end":
|
||||||
|
return seq[-max_len:]
|
||||||
|
return seq[:max_len]
|
||||||
|
|
||||||
|
|
||||||
|
def pack_sequences(
|
||||||
|
sequences: List[list],
|
||||||
|
max_packed_len: int,
|
||||||
|
strategy: str,
|
||||||
|
truncation_mode: str,
|
||||||
|
) -> List[Tuple[int, int]]:
|
||||||
|
"""Pack *sequences* into bins and return a reorder plan.
|
||||||
|
|
||||||
|
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
|
||||||
|
All keys (sequence, loss_mask, …) must be reordered and truncated
|
||||||
|
identically according to this plan.
|
||||||
|
|
||||||
|
Supported *strategy* values:
|
||||||
|
|
||||||
|
- ``"simple"``: sequential, no reordering.
|
||||||
|
- ``"bfd"``: best-fit decreasing bin packing.
|
||||||
|
"""
|
||||||
|
n = len(sequences)
|
||||||
|
if strategy == "simple":
|
||||||
|
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
|
||||||
|
|
||||||
|
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
||||||
|
bins: List[List[int]] = []
|
||||||
|
bin_lengths: List[int] = []
|
||||||
|
|
||||||
|
for orig_idx in order:
|
||||||
|
seq_len = min(len(sequences[orig_idx]), max_packed_len)
|
||||||
|
|
||||||
|
best_bin = None
|
||||||
|
best_remain = max_packed_len + 1
|
||||||
|
for i, bl in enumerate(bin_lengths):
|
||||||
|
remain = max_packed_len - bl
|
||||||
|
if seq_len <= remain < best_remain:
|
||||||
|
best_remain = remain
|
||||||
|
best_bin = i
|
||||||
|
|
||||||
|
if best_bin is not None:
|
||||||
|
bins[best_bin].append(orig_idx)
|
||||||
|
bin_lengths[best_bin] += seq_len
|
||||||
|
else:
|
||||||
|
bins.append([orig_idx])
|
||||||
|
bin_lengths.append(seq_len)
|
||||||
|
|
||||||
|
plan: List[Tuple[int, int]] = []
|
||||||
|
for bin_indices in bins:
|
||||||
|
for orig_idx in bin_indices:
|
||||||
|
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
|
||||||
|
|
||||||
|
return plan
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
|
||||||
|
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
config = PipelineConfig.from_json("sft_pipeline.json")
|
||||||
|
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PipelineConfig,
|
||||||
|
input_paths: list[str],
|
||||||
|
output_dir: str,
|
||||||
|
tokenizer_path: str,
|
||||||
|
):
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
self.config = config
|
||||||
|
self.paths = input_paths
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.tokenizer_path = tokenizer_path
|
||||||
|
|
||||||
|
self.mask_builder = SectionedMaskBuilder()
|
||||||
|
|
||||||
|
def transform(self, item: dict) -> Optional[dict]:
|
||||||
|
return self.mask_builder.build(item, self.config, self._tokenizer)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
||||||
|
domains: dict = defaultdict(lambda: defaultdict(list))
|
||||||
|
total_tokens = 0
|
||||||
|
shard_idx: dict[str, int] = defaultdict(int)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
pp = self.config.preprocessing
|
||||||
|
|
||||||
|
for item in tqdm.tqdm(
|
||||||
|
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
|
||||||
|
):
|
||||||
|
if pp.max_items and count >= pp.max_items:
|
||||||
|
break
|
||||||
|
|
||||||
|
result = self.transform(item)
|
||||||
|
if result is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
domain = result.pop("domain", "__default__")
|
||||||
|
|
||||||
|
is_multi = bool(getattr(self.config.input, "sources", None))
|
||||||
|
if is_multi:
|
||||||
|
ids = self._primary_ids(result)
|
||||||
|
else:
|
||||||
|
ids = result.pop("sequence")
|
||||||
|
result["sequence"] = ids
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
bucket = domains[domain]
|
||||||
|
self._align_bucket(bucket, result, ids, is_multi)
|
||||||
|
for key, val in result.items():
|
||||||
|
bucket[key].append(val)
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
total_tokens += len(ids)
|
||||||
|
|
||||||
|
if total_tokens >= self.config.output.max_tokens_per_shard:
|
||||||
|
self._flush(domains, shard_idx)
|
||||||
|
domains.clear()
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
if total_tokens > 0:
|
||||||
|
self._flush(domains, shard_idx)
|
||||||
|
|
||||||
|
print(f"Done. {count} documents tokenized.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _primary_ids(result: dict) -> list:
|
||||||
|
"""Return the first list-valued entry in *result* as the primary id
|
||||||
|
sequence for token counting."""
|
||||||
|
for val in result.values():
|
||||||
|
if isinstance(val, list) and val and isinstance(val[0], int):
|
||||||
|
return val
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
|
||||||
|
"""Pad previously-accumulated keys that are missing from *result*."""
|
||||||
|
for key in list(bucket.keys()):
|
||||||
|
if key in result:
|
||||||
|
continue
|
||||||
|
if is_multi:
|
||||||
|
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
|
||||||
|
bucket[key].append(pad)
|
||||||
|
else:
|
||||||
|
bucket[key].append([1] * len(ids))
|
||||||
|
|
||||||
|
def _iter_items(self):
|
||||||
|
for path in self.paths:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
yield json.loads(line)
|
||||||
|
|
||||||
|
def _flush(self, domains, shard_idx):
|
||||||
|
for domain, keys in domains.items():
|
||||||
|
idx = shard_idx[domain]
|
||||||
|
chunk_dir = os.path.join(self.output_dir, domain)
|
||||||
|
|
||||||
|
pp = self.config.preprocessing
|
||||||
|
if pp.packing_strategy != "simple" and "sequence" in keys:
|
||||||
|
plan = pack_sequences(
|
||||||
|
keys["sequence"],
|
||||||
|
pp.max_packed_len,
|
||||||
|
pp.packing_strategy,
|
||||||
|
pp.truncation_mode,
|
||||||
|
)
|
||||||
|
reordered = defaultdict(list)
|
||||||
|
for orig_idx, truncated_len in plan:
|
||||||
|
for k, vals in keys.items():
|
||||||
|
reordered[k].append(
|
||||||
|
_truncate(
|
||||||
|
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
|
||||||
|
)
|
||||||
|
)
|
||||||
|
keys = reordered
|
||||||
|
|
||||||
|
tensors = {}
|
||||||
|
for key, ids_list in keys.items():
|
||||||
|
dt = _STR_TO_DTYPE.get(
|
||||||
|
self.config.output.dtype.get(key, "int32"), torch.int32
|
||||||
|
)
|
||||||
|
tensors[key] = [
|
||||||
|
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||||
|
]
|
||||||
|
|
||||||
|
pid_mode = self.config.output.position_ids_mode
|
||||||
|
if pid_mode and pid_mode != "none" and "sequence" in tensors:
|
||||||
|
pos_ids = []
|
||||||
|
if pid_mode == "doc_reset":
|
||||||
|
for item in keys["sequence"]:
|
||||||
|
pos_ids.extend(range(len(item)))
|
||||||
|
else:
|
||||||
|
total = sum(len(item) for item in keys["sequence"])
|
||||||
|
pos_ids = list(range(total))
|
||||||
|
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
|
||||||
|
|
||||||
|
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
|
||||||
|
fmt = self.config.output.storage_format
|
||||||
|
if fmt == "bin":
|
||||||
|
save_bin(shard_path, tensors)
|
||||||
|
else:
|
||||||
|
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||||
|
shard_idx[domain] = idx + 1
|
||||||
|
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
|
||||||
|
tqdm.tqdm.write(
|
||||||
|
f" saved {domain}/shard_{idx:04d} "
|
||||||
|
f"({tensors[first_key][0].numel():,} tokens)"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
||||||
|
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class OptimizerProtocol(Protocol):
|
||||||
|
def step(self, closure=None): ...
|
||||||
|
def zero_grad(self): ...
|
||||||
|
@property
|
||||||
|
def param_groups(self) -> Any: ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, d: dict): ...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SchedulerProtocol(Protocol):
|
||||||
|
def step(self): ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, d: dict): ...
|
||||||
|
def get_last_lr(self): ...
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -8,70 +11,172 @@ import torch.distributed as dist
|
||||||
|
|
||||||
from astrai.parallel.setup import get_rank
|
from astrai.parallel.setup import get_rank
|
||||||
|
|
||||||
|
_META_FILE = "meta.json"
|
||||||
|
_CONFIG_FILE = "config.json"
|
||||||
|
_WEIGHTS_FILE = "model.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def save_safetensors(state_dict: dict, path: Union[str, Path]):
|
||||||
|
st.save_file(state_dict, str(path))
|
||||||
|
|
||||||
|
|
||||||
|
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
|
return st.load_file(str(path))
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
state_dict = st.load_file(str(path))
|
||||||
|
else:
|
||||||
|
state_dict = {}
|
||||||
|
tmp = [state_dict]
|
||||||
|
dist.broadcast_object_list(tmp, src=0)
|
||||||
|
return tmp[0]
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(data: dict, path: Union[str, Path]):
|
||||||
|
with open(str(path), "w") as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
|
with open(str(path), "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
with open(str(path), "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
else:
|
||||||
|
data = {}
|
||||||
|
tmp = [data]
|
||||||
|
dist.broadcast_object_list(tmp, src=0)
|
||||||
|
return tmp[0]
|
||||||
|
|
||||||
|
|
||||||
|
def save_torch(obj: Any, path: Union[str, Path]):
|
||||||
|
torch.save(obj, str(path))
|
||||||
|
|
||||||
|
|
||||||
|
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
|
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
rank = get_rank()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
raw = f.read()
|
||||||
|
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
||||||
|
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
|
||||||
|
else:
|
||||||
|
num_bytes = torch.tensor([0], dtype=torch.long)
|
||||||
|
|
||||||
|
dist.broadcast(num_bytes, src=0)
|
||||||
|
|
||||||
|
if rank != 0:
|
||||||
|
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
|
||||||
|
|
||||||
|
dist.broadcast(data_tensor, src=0)
|
||||||
|
|
||||||
|
buf = io.BytesIO(data_tensor.numpy().tobytes())
|
||||||
|
return torch.load(buf, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(config: dict, state_dict: dict, save_directory: str):
|
||||||
|
save_path = Path(save_directory)
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_json(config, save_path / _CONFIG_FILE)
|
||||||
|
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_config(save_directory: str) -> dict:
|
||||||
|
return load_json(Path(save_directory) / _CONFIG_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_weights(save_directory: str) -> dict:
|
||||||
|
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
path = Path(path)
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
|
return load_safetensors(path)
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
state_dict = load_safetensors(path)
|
||||||
|
specs = [
|
||||||
|
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
||||||
|
for k in sorted(state_dict)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
state_dict = {}
|
||||||
|
specs = []
|
||||||
|
|
||||||
|
specs_list = [specs]
|
||||||
|
dist.broadcast_object_list(specs_list, src=0)
|
||||||
|
specs = specs_list[0]
|
||||||
|
|
||||||
|
for key, shape, dtype_name in specs:
|
||||||
|
dtype = getattr(torch, dtype_name)
|
||||||
|
if rank != 0:
|
||||||
|
tensor = torch.empty(shape, dtype=dtype, device="cpu")
|
||||||
|
else:
|
||||||
|
tensor = state_dict[key].contiguous().cpu()
|
||||||
|
dist.broadcast(tensor, src=0)
|
||||||
|
if rank != 0:
|
||||||
|
state_dict[key] = tensor
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
def __init__(
|
state_dict: Dict[str, Any] = field(default_factory=dict)
|
||||||
self,
|
epoch: int = 0
|
||||||
state_dict: Dict[str, Any],
|
iteration: int = 0
|
||||||
epoch: int = 0,
|
extra: Dict[str, Any] = field(default_factory=dict)
|
||||||
iteration: int = 0,
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
extra: Optional[Dict[str, Any]] = None,
|
config: Dict[str, Any] = field(default_factory=dict)
|
||||||
):
|
|
||||||
self.state_dict = state_dict
|
|
||||||
self.epoch = epoch
|
|
||||||
self.iteration = iteration
|
|
||||||
self.extra = extra or {}
|
|
||||||
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
save_dir: str,
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
|
def save(self, save_dir: str):
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
rank = get_rank()
|
if get_rank() != 0:
|
||||||
if rank == 0:
|
return
|
||||||
|
|
||||||
meta = {
|
meta = {
|
||||||
"epoch": self.epoch,
|
"epoch": self.epoch,
|
||||||
"iteration": self.iteration,
|
"iteration": self.iteration,
|
||||||
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
|
**self.meta,
|
||||||
}
|
}
|
||||||
with open(save_path / "meta.json", "w") as f:
|
save_json(meta, save_path / _META_FILE)
|
||||||
json.dump(meta, f, indent=2)
|
save_json(self.config, save_path / _CONFIG_FILE)
|
||||||
|
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
||||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
for key, value in self.extra.items():
|
||||||
if self.extra:
|
save_torch(value, save_path / f"{key}.pt")
|
||||||
torch.save(self.extra, save_path / "extra.pt")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||||
cls,
|
|
||||||
save_dir: str,
|
|
||||||
) -> "Checkpoint":
|
|
||||||
|
|
||||||
rank = get_rank()
|
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = {}
|
meta = load_json(save_path / _META_FILE, broadcast)
|
||||||
if rank == 0:
|
config = load_json(save_path / _CONFIG_FILE, broadcast)
|
||||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
|
||||||
meta = json.load(f)
|
|
||||||
|
|
||||||
if dist.is_initialized():
|
extra = {}
|
||||||
meta_list = [meta]
|
for f in sorted(save_path.iterdir()):
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
if f.suffix == ".pt":
|
||||||
meta = meta_list[0]
|
extra[f.stem] = load_torch(f, broadcast=broadcast)
|
||||||
|
|
||||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
|
||||||
|
|
||||||
extra = None
|
|
||||||
extra_path = save_path / "extra.pt"
|
|
||||||
if extra_path.exists():
|
|
||||||
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=meta["epoch"],
|
epoch=meta.get("epoch", 0),
|
||||||
iteration=meta["iteration"],
|
iteration=meta.get("iteration", 0),
|
||||||
extra=extra,
|
extra=extra,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,10 @@
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
|
||||||
# Message type for chat messages
|
|
||||||
type MessageType = Dict[str, Any]
|
type MessageType = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ChatTemplate:
|
class ChatTemplate:
|
||||||
"""A chat template with Jinja2 rendering support.
|
"""A chat template with Jinja2 rendering support.
|
||||||
|
|
||||||
|
|
@ -15,23 +12,24 @@ class ChatTemplate:
|
||||||
name: Unique identifier for the template.
|
name: Unique identifier for the template.
|
||||||
template_str: Jinja2 template string.
|
template_str: Jinja2 template string.
|
||||||
description: Optional description.
|
description: Optional description.
|
||||||
default_variables: Optional dictionary of default variable values
|
default_variables: Optional dictionary of default variable values.
|
||||||
that will be passed to the template if not overridden during rendering.
|
|
||||||
special_tokens: Optional dictionary mapping token names to their string values.
|
special_tokens: Optional dictionary mapping token names to their string values.
|
||||||
These tokens are automatically added to the template variables.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
def __init__(
|
||||||
template_str: str
|
self,
|
||||||
description: str = ""
|
name: str = "",
|
||||||
default_variables: Dict[str, Any] = None
|
template_str: str = "",
|
||||||
special_tokens: Dict[str, str] = None
|
description: str = "",
|
||||||
|
default_variables: Optional[Dict[str, Any]] = None,
|
||||||
def __post_init__(self):
|
special_tokens: Optional[Dict[str, str]] = None,
|
||||||
if self.default_variables is None:
|
):
|
||||||
self.default_variables = {}
|
self.name = name
|
||||||
if self.special_tokens is None:
|
self.template_str = template_str
|
||||||
self.special_tokens = {}
|
self.description = description
|
||||||
|
self.default_variables = default_variables or {}
|
||||||
|
self.special_tokens = special_tokens or {}
|
||||||
|
self._compiled: Template = Template(template_str)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(
|
def from_string(
|
||||||
|
|
@ -43,7 +41,7 @@ class ChatTemplate:
|
||||||
) -> "ChatTemplate":
|
) -> "ChatTemplate":
|
||||||
"""Create a ChatTemplate instance directly from a template string."""
|
"""Create a ChatTemplate instance directly from a template string."""
|
||||||
return cls(
|
return cls(
|
||||||
name="", # empty name for ad‑hoc templates
|
name="",
|
||||||
template_str=template_str,
|
template_str=template_str,
|
||||||
description=description,
|
description=description,
|
||||||
default_variables=default_variables,
|
default_variables=default_variables,
|
||||||
|
|
@ -73,5 +71,4 @@ class ChatTemplate:
|
||||||
if system_prompt is not None:
|
if system_prompt is not None:
|
||||||
variables["system_prompt"] = system_prompt
|
variables["system_prompt"] = system_prompt
|
||||||
|
|
||||||
jinja_template = Template(self.template_str)
|
return self._compiled.render(**variables)
|
||||||
return jinja_template.render(**variables)
|
|
||||||
|
|
|
||||||
|
|
@ -51,9 +51,26 @@ class AutoTokenizer:
|
||||||
self.set_chat_template(config["chat_template"])
|
self.set_chat_template(config["chat_template"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
|
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
|
||||||
"""Load tokenizer from pretrained directory."""
|
"""Load tokenizer from pretrained directory.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If tokenizer.json is missing.
|
||||||
|
RuntimeError: If tokenizer failed to initialize.
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
tokenizer_file = path / "tokenizer.json"
|
||||||
|
if not tokenizer_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Tokenizer file not found: {tokenizer_file}. "
|
||||||
|
"A valid tokenizer.json is required."
|
||||||
|
)
|
||||||
instance = cls(path)
|
instance = cls(path)
|
||||||
|
if instance._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to load tokenizer from {path}. "
|
||||||
|
"The tokenizer.json may be corrupted or incompatible."
|
||||||
|
)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def save_pretrained(self, save_path: str):
|
def save_pretrained(self, save_path: str):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from astrai.trainer.optim import Muon
|
||||||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
from astrai.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
|
|
@ -9,6 +10,8 @@ from astrai.trainer.trainer import Trainer
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Main trainer
|
# Main trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
# Optimizer
|
||||||
|
"Muon",
|
||||||
# Strategy factory
|
# Strategy factory
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"BaseStrategy",
|
"BaseStrategy",
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,10 @@ def ctx_get_lr(ctx):
|
||||||
return ctx.optimizer.param_groups[-1]["lr"]
|
return ctx.optimizer.param_groups[-1]["lr"]
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_val_loss(ctx):
|
||||||
|
return ctx.val_loss
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_norm(ctx):
|
def ctx_get_grad_norm(ctx):
|
||||||
return grad_norm(ctx.model)
|
return grad_norm(ctx.model)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,143 @@
|
||||||
|
import torch
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||||
|
assert G.ndim == 2
|
||||||
|
X = G
|
||||||
|
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||||
|
X = X / (X.norm() + 1e-7) * scale
|
||||||
|
if steps == 0:
|
||||||
|
return X
|
||||||
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||||
|
for _ in range(steps):
|
||||||
|
A = X @ X.T
|
||||||
|
B = A @ X
|
||||||
|
X = a * X + b * B + c * (A @ B)
|
||||||
|
return X
|
||||||
|
|
||||||
|
|
||||||
|
class Muon(Optimizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params,
|
||||||
|
lr: float = 2e-3,
|
||||||
|
momentum: float = 0.95,
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
nesterov: bool = True,
|
||||||
|
ns_steps: int = 5,
|
||||||
|
adamw_lr: float = None,
|
||||||
|
adamw_betas: tuple = (0.9, 0.95),
|
||||||
|
adamw_eps: float = 1e-8,
|
||||||
|
adamw_wd: float = 0.0,
|
||||||
|
):
|
||||||
|
defaults = dict(
|
||||||
|
lr=lr,
|
||||||
|
momentum=momentum,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
nesterov=nesterov,
|
||||||
|
ns_steps=ns_steps,
|
||||||
|
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
|
||||||
|
adamw_betas=adamw_betas,
|
||||||
|
adamw_eps=adamw_eps,
|
||||||
|
adamw_wd=adamw_wd,
|
||||||
|
)
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
params_2d, params_1d = [], []
|
||||||
|
grads_2d, grads_1d = [], []
|
||||||
|
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
if p.grad.is_sparse:
|
||||||
|
raise RuntimeError("Muon does not support sparse gradients")
|
||||||
|
if p.ndim >= 2:
|
||||||
|
params_2d.append(p)
|
||||||
|
grads_2d.append(p.grad)
|
||||||
|
else:
|
||||||
|
params_1d.append(p)
|
||||||
|
grads_1d.append(p.grad)
|
||||||
|
|
||||||
|
if params_2d:
|
||||||
|
self._muon_update_foreach(params_2d, grads_2d, group)
|
||||||
|
if params_1d:
|
||||||
|
self._adamw_update_foreach(params_1d, grads_1d, group)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _muon_update_foreach(self, params_2d, grads_2d, group):
|
||||||
|
lr = group["lr"]
|
||||||
|
momentum = group["momentum"]
|
||||||
|
wd = group["weight_decay"]
|
||||||
|
nesterov = group["nesterov"]
|
||||||
|
ns_steps = group["ns_steps"]
|
||||||
|
|
||||||
|
if wd != 0:
|
||||||
|
torch._foreach_mul_(params_2d, 1 - lr * wd)
|
||||||
|
|
||||||
|
if nesterov:
|
||||||
|
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
|
||||||
|
|
||||||
|
bufs = []
|
||||||
|
for p, grad in zip(params_2d, grads_2d):
|
||||||
|
state = self.state[p]
|
||||||
|
if "momentum_buffer" not in state:
|
||||||
|
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||||
|
bufs.append(state["momentum_buffer"])
|
||||||
|
|
||||||
|
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
|
||||||
|
|
||||||
|
for p, buf in zip(params_2d, bufs):
|
||||||
|
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||||
|
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||||
|
p.add_(update, alpha=-lr * scale)
|
||||||
|
|
||||||
|
def _adamw_update_foreach(self, params_1d, grads_1d, group):
|
||||||
|
lr = group["adamw_lr"]
|
||||||
|
betas = group["adamw_betas"]
|
||||||
|
eps = group["adamw_eps"]
|
||||||
|
wd = group["adamw_wd"]
|
||||||
|
|
||||||
|
steps: list[int] = []
|
||||||
|
exp_avgs, exp_avg_sqs = [], []
|
||||||
|
has_state = []
|
||||||
|
for p in params_1d:
|
||||||
|
state = self.state[p]
|
||||||
|
if not state:
|
||||||
|
state["step"] = 0
|
||||||
|
state["exp_avg"] = torch.zeros_like(p)
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||||
|
has_state.append(False)
|
||||||
|
else:
|
||||||
|
has_state.append(True)
|
||||||
|
state["step"] += 1
|
||||||
|
steps.append(state["step"])
|
||||||
|
exp_avgs.append(state["exp_avg"])
|
||||||
|
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||||
|
|
||||||
|
beta1, beta2 = betas
|
||||||
|
|
||||||
|
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
|
||||||
|
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
|
||||||
|
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
|
||||||
|
|
||||||
|
bias_correction1 = [1 - beta1**s for s in steps]
|
||||||
|
bias_correction2 = [1 - beta2**s for s in steps]
|
||||||
|
|
||||||
|
if wd != 0:
|
||||||
|
torch._foreach_mul_(params_1d, 1 - lr * wd)
|
||||||
|
|
||||||
|
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
|
||||||
|
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
|
||||||
|
denom = torch._foreach_sqrt(denom)
|
||||||
|
torch._foreach_add_(denom, eps)
|
||||||
|
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
|
||||||
|
|
@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
||||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||||
if not issubclass(scheduler_cls, BaseScheduler):
|
if not issubclass(scheduler_cls, BaseScheduler):
|
||||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
"""Training strategy implementations with factory pattern."""
|
"""Training strategy implementations with factory pattern."""
|
||||||
|
|
||||||
import copy
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, Union
|
from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
|
|
@ -8,26 +7,14 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
||||||
"""Unwrap DDP wrapper if present to get the original model."""
|
"""Create a frozen reference model from model_fn + full state dict."""
|
||||||
if isinstance(model, DDP):
|
ref_model = model_fn()
|
||||||
return model.module
|
ref_model.load_state_dict(state_dict)
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def create_ref_model(model: nn.Module) -> nn.Module:
|
|
||||||
"""Create a reference model for DPO/GRPO training.
|
|
||||||
|
|
||||||
Handles DDP-wrapped models safely by unwrapping first,
|
|
||||||
then creating a deep copy with frozen gradients.
|
|
||||||
"""
|
|
||||||
original_model = unwrap_model(model)
|
|
||||||
ref_model = copy.deepcopy(original_model)
|
|
||||||
ref_model.requires_grad_(False)
|
ref_model.requires_grad_(False)
|
||||||
ref_model.eval()
|
ref_model.eval()
|
||||||
return ref_model
|
return ref_model
|
||||||
|
|
@ -81,6 +68,22 @@ def get_logprobs(
|
||||||
return token_logprobs * shifted_mask
|
return token_logprobs * shifted_mask
|
||||||
|
|
||||||
|
|
||||||
|
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
|
||||||
|
S = position_ids.size(1)
|
||||||
|
device = position_ids.device
|
||||||
|
boundaries = position_ids[:, 1:] <= position_ids[:, :-1]
|
||||||
|
doc_ids = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device),
|
||||||
|
boundaries.long().cumsum(dim=1),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
|
||||||
|
causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
|
||||||
|
return (same_doc & causal).unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
"""Abstract base class for training strategies."""
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
|
|
@ -89,6 +92,8 @@ class BaseStrategy(ABC):
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.executor = kwargs.pop("executor", None)
|
||||||
|
self.model_fn = kwargs.pop("model_fn", None)
|
||||||
self.extra_kwargs = kwargs
|
self.extra_kwargs = kwargs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -123,7 +128,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, strategy_cls: type) -> None:
|
def _validate_component(cls, strategy_cls: type):
|
||||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||||
if not issubclass(strategy_cls, BaseStrategy):
|
if not issubclass(strategy_cls, BaseStrategy):
|
||||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||||
|
|
@ -191,15 +196,19 @@ class SFTStrategy(BaseStrategy):
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
input_ids, target_ids, loss_mask = (
|
input_ids, target_ids, position_ids, loss_mask = (
|
||||||
batch["input_ids"],
|
batch["input_ids"],
|
||||||
batch["target_ids"],
|
batch["target_ids"],
|
||||||
|
batch["position_ids"],
|
||||||
batch["loss_mask"],
|
batch["loss_mask"],
|
||||||
)
|
)
|
||||||
|
|
||||||
ignore_index = -100
|
ignore_index = -100
|
||||||
logits = self.model(input_ids=input_ids)["logits"]
|
input_mask = make_doc_boundary_mask(position_ids)
|
||||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||||
|
logits = self.model(
|
||||||
|
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
|
||||||
|
)["logits"]
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
|
|
@ -228,7 +237,9 @@ class DPOStrategy(BaseStrategy):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(
|
||||||
|
self.model_fn, self.executor.unwrap_model(model)
|
||||||
|
).to(device=self.device)
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
|
||||||
|
|
@ -282,7 +293,9 @@ class GRPOStrategy(BaseStrategy):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(
|
||||||
|
self.model_fn, self.executor.unwrap_model(model)
|
||||||
|
).to(device=self.device)
|
||||||
self.clip_eps = clip_eps
|
self.clip_eps = clip_eps
|
||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
|
|
@ -292,8 +305,7 @@ class GRPOStrategy(BaseStrategy):
|
||||||
|
|
||||||
def sync_ref_model(self):
|
def sync_ref_model(self):
|
||||||
"""Copy current model weights to ref model."""
|
"""Copy current model weights to ref model."""
|
||||||
ref_state = self.model.state_dict()
|
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
|
||||||
self.ref_model.load_state_dict(ref_state)
|
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
self._step += 1
|
self._step += 1
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,21 @@
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.parallel import only_on_rank
|
from astrai.parallel import only_on_rank
|
||||||
|
from astrai.parallel.setup import get_current_device, get_rank
|
||||||
from astrai.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer.metric_util import (
|
from astrai.trainer.metric_util import (
|
||||||
ctx_get_grad_max,
|
ctx_get_grad_max,
|
||||||
|
|
@ -20,9 +26,12 @@ from astrai.trainer.metric_util import (
|
||||||
ctx_get_grad_std,
|
ctx_get_grad_std,
|
||||||
ctx_get_loss,
|
ctx_get_loss,
|
||||||
ctx_get_lr,
|
ctx_get_lr,
|
||||||
|
ctx_get_val_loss,
|
||||||
)
|
)
|
||||||
from astrai.trainer.train_context import TrainContext
|
from astrai.trainer.train_context import TrainContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class TrainCallback(Protocol):
|
class TrainCallback(Protocol):
|
||||||
|
|
@ -42,18 +51,15 @@ class TrainCallback(Protocol):
|
||||||
def on_epoch_end(self, context: TrainContext):
|
def on_epoch_end(self, context: TrainContext):
|
||||||
"""Called at the end of each epoch."""
|
"""Called at the end of each epoch."""
|
||||||
|
|
||||||
def on_step_begin(self, context: TrainContext):
|
|
||||||
"""Called at the beginning of each step."""
|
|
||||||
|
|
||||||
def on_step_end(self, context: TrainContext):
|
|
||||||
"""Called at the end of each step."""
|
|
||||||
|
|
||||||
def on_batch_begin(self, context: TrainContext):
|
def on_batch_begin(self, context: TrainContext):
|
||||||
"""Called at the beginning of each batch."""
|
"""Called at the beginning of each batch."""
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
"""Called at the end of each batch."""
|
"""Called at the end of each batch."""
|
||||||
|
|
||||||
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
|
"""Called on every optimizer step (sync step only)."""
|
||||||
|
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
"""Called when an error occurs during training."""
|
"""Called when an error occurs during training."""
|
||||||
|
|
||||||
|
|
@ -79,53 +85,83 @@ class GradientClippingCallback(TrainCallback):
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_step_end(self, context: TrainContext):
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
_ = context
|
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("gradient_checkpointing")
|
||||||
|
class GradientCheckpointingCallback(TrainCallback):
|
||||||
|
"""
|
||||||
|
Activation checkpointing callback — trades compute for memory
|
||||||
|
by recomputing specified module activations during the backward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modules: Module types to apply checkpointing to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, modules: Optional[List[type]] = None):
|
||||||
|
self.modules = tuple(modules) if modules else ()
|
||||||
|
|
||||||
|
def _enable(self, module: nn.Module):
|
||||||
|
if self.modules and isinstance(module, self.modules):
|
||||||
|
fn = module.forward
|
||||||
|
module._original_forward = fn
|
||||||
|
module.forward = lambda *a, **kw: torch_checkpoint(
|
||||||
|
fn, *a, use_reentrant=False, **kw
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _disable(module: nn.Module):
|
||||||
|
if hasattr(module, "_original_forward"):
|
||||||
|
module.forward = module._original_forward
|
||||||
|
del module._original_forward
|
||||||
|
|
||||||
|
def on_train_begin(self, context: TrainContext):
|
||||||
|
context.model.apply(self._enable)
|
||||||
|
logger.info("Gradient checkpointing enabled")
|
||||||
|
|
||||||
|
def on_train_end(self, context: TrainContext):
|
||||||
|
context.model.apply(self._disable)
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("checkpoint")
|
@CallbackFactory.register("checkpoint")
|
||||||
class CheckpointCallback(TrainCallback):
|
class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
extra_keys = ("optimizer", "scheduler")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
|
||||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||||
self.save_extra_fn = save_extra_fn
|
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
|
state_dict = context.executor.unwrap_model(context.model)
|
||||||
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
|
if get_rank() == 0:
|
||||||
save_path = os.path.join(
|
save_path = os.path.join(
|
||||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||||
)
|
)
|
||||||
state_dict = (
|
extra = self.save_extra_fn(context)
|
||||||
self.state_dict_fn(context.model)
|
|
||||||
if self.state_dict_fn
|
|
||||||
else context.model.state_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
extra = self.save_extra_fn(context) if self.save_extra_fn else None
|
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=context.epoch,
|
epoch=context.epoch,
|
||||||
iteration=context.iteration,
|
iteration=context.iteration,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
|
config=context.model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.iteration
|
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||||
|
|
@ -138,6 +174,15 @@ class CheckpointCallback(TrainCallback):
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_extra(context: TrainContext) -> dict:
|
||||||
|
extra = {}
|
||||||
|
for name in CheckpointCallback.extra_keys:
|
||||||
|
obj = getattr(context, name, None)
|
||||||
|
if obj:
|
||||||
|
extra[name] = obj.state_dict()
|
||||||
|
return extra
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("progress_bar")
|
@CallbackFactory.register("progress_bar")
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
|
|
@ -145,8 +190,12 @@ class ProgressBarCallback(TrainCallback):
|
||||||
Progress bar callback for trainer.
|
Progress bar callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_epoch: int):
|
def __init__(
|
||||||
|
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
||||||
|
):
|
||||||
self.num_epoch = num_epoch
|
self.num_epoch = num_epoch
|
||||||
|
self.log_interval = log_interval
|
||||||
|
self.file = file
|
||||||
self.progress_bar: tqdm = None
|
self.progress_bar: tqdm = None
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -155,16 +204,18 @@ class ProgressBarCallback(TrainCallback):
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True,
|
dynamic_ncols=True,
|
||||||
|
file=self.file or sys.stdout,
|
||||||
)
|
)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
self.progress_bar.set_postfix(
|
postfix = {
|
||||||
{
|
|
||||||
"loss": f"{context.loss:.4f}",
|
"loss": f"{context.loss:.4f}",
|
||||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||||
}
|
}
|
||||||
)
|
if context.val_loss > 0:
|
||||||
|
postfix["val_loss"] = f"{context.val_loss:.4f}"
|
||||||
|
self.progress_bar.set_postfix(postfix)
|
||||||
self.progress_bar.update(1)
|
self.progress_bar.update(1)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -196,6 +247,7 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
self._metric_funcs = {
|
self._metric_funcs = {
|
||||||
"loss": ctx_get_loss,
|
"loss": ctx_get_loss,
|
||||||
"lr": ctx_get_lr,
|
"lr": ctx_get_lr,
|
||||||
|
"val_loss": ctx_get_val_loss,
|
||||||
"grad_norm": ctx_get_grad_norm,
|
"grad_norm": ctx_get_grad_norm,
|
||||||
"grad_std": ctx_get_grad_std,
|
"grad_std": ctx_get_grad_std,
|
||||||
"grad_max": ctx_get_grad_max,
|
"grad_max": ctx_get_grad_max,
|
||||||
|
|
@ -206,7 +258,7 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
|
|
||||||
def _get_log_data(self, context: TrainContext):
|
def _get_log_data(self, context: TrainContext):
|
||||||
return {
|
return {
|
||||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
"epoch": context.epoch,
|
"epoch": context.epoch,
|
||||||
"iter": context.iteration,
|
"iter": context.iteration,
|
||||||
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
||||||
|
|
@ -239,3 +291,43 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
|
|
||||||
def on_error(self, context):
|
def on_error(self, context):
|
||||||
self._save_log(context.epoch, context.iteration)
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("validation")
|
||||||
|
class ValidationCallback(TrainCallback):
|
||||||
|
def _run_validation(self, context: TrainContext):
|
||||||
|
context.model.eval()
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
num_batches = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in context.val_dataloader:
|
||||||
|
loss = context.strategy(batch)
|
||||||
|
total_loss += loss.item()
|
||||||
|
num_batches += 1
|
||||||
|
|
||||||
|
avg_loss = total_loss / max(num_batches, 1)
|
||||||
|
|
||||||
|
if context.world_size > 1 and dist.is_initialized():
|
||||||
|
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
|
||||||
|
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
|
||||||
|
avg_loss = loss_tensor.item()
|
||||||
|
|
||||||
|
context.val_loss = avg_loss
|
||||||
|
context.model.train()
|
||||||
|
|
||||||
|
step_count = context.iteration // context.config.grad_accum_steps
|
||||||
|
logger.info(
|
||||||
|
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
|
if context.val_dataloader is None:
|
||||||
|
return
|
||||||
|
cfg = context.config
|
||||||
|
if cfg.val_step <= 0:
|
||||||
|
return
|
||||||
|
step_count = context.iteration // cfg.grad_accum_steps
|
||||||
|
if step_count % cfg.val_step == 0:
|
||||||
|
self._run_validation(context)
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,18 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Optional, Self
|
from pathlib import Path
|
||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.utils.data import DataLoader, random_split
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.dataset import ResumableDistributedSampler
|
from astrai.dataset import ResumableDistributedSampler
|
||||||
|
from astrai.model.components.lora import inject_lora
|
||||||
|
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
from astrai.serialization import Checkpoint
|
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||||
|
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,13 +21,18 @@ class TrainContext:
|
||||||
model: nn.Module = field(default=None)
|
model: nn.Module = field(default=None)
|
||||||
strategy: BaseStrategy = field(default=None)
|
strategy: BaseStrategy = field(default=None)
|
||||||
dataloader: DataLoader = field(default=None)
|
dataloader: DataLoader = field(default=None)
|
||||||
optimizer: Optimizer = field(default=None)
|
optimizer: OptimizerProtocol = field(default=None)
|
||||||
scheduler: LRScheduler = field(default=None)
|
scheduler: SchedulerProtocol = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
|
config: TrainConfig = field(default=None)
|
||||||
|
model_config: dict = field(default_factory=dict)
|
||||||
|
executor: BaseExecutor = field(default=None)
|
||||||
|
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
iteration: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
|
val_dataloader: DataLoader = field(default=None)
|
||||||
|
val_loss: float = field(default=0.0)
|
||||||
|
|
||||||
world_size: int = field(default=1)
|
world_size: int = field(default=1)
|
||||||
rank: int = field(default=0)
|
rank: int = field(default=0)
|
||||||
|
|
@ -35,67 +43,141 @@ class TrainContextBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: TrainConfig,
|
config: TrainConfig,
|
||||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._checkpoint: Optional[Checkpoint] = None
|
self._resume_dir: Optional[str] = None
|
||||||
self._load_extra_fn = load_extra_fn
|
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
|
||||||
self._checkpoint = checkpoint
|
self._resume_dir = resume_dir
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
|
cfg = self.config
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
executor = ExecutorFactory.create(
|
||||||
|
cfg.parallel_mode,
|
||||||
|
grad_accum_steps=cfg.grad_accum_steps,
|
||||||
|
**cfg.executor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = cfg.model_fn()
|
||||||
|
model = model.to(device=device)
|
||||||
|
|
||||||
|
model_config = {}
|
||||||
|
if self._resume_dir:
|
||||||
|
config_path = Path(self._resume_dir) / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
model_config = load_json(config_path)
|
||||||
|
|
||||||
|
if not model_config and hasattr(model, "config"):
|
||||||
|
model_config = model.config.to_dict()
|
||||||
|
|
||||||
context = TrainContext(
|
context = TrainContext(
|
||||||
model=self.config.model,
|
model=model,
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
|
config=cfg,
|
||||||
|
model_config=model_config,
|
||||||
|
executor=executor,
|
||||||
)
|
)
|
||||||
|
|
||||||
device = get_current_device()
|
if self._resume_dir is not None:
|
||||||
context.model = context.model.to(device=device)
|
resume_path = Path(self._resume_dir)
|
||||||
|
if (resume_path / "meta.json").exists():
|
||||||
if self.config.nprocs > 1 and self.config.parallel_wrapper:
|
checkpoint = Checkpoint.load(self._resume_dir)
|
||||||
context.model = self.config.parallel_wrapper(context.model)
|
state_dict = checkpoint.state_dict
|
||||||
|
if checkpoint.config:
|
||||||
if self._checkpoint is not None:
|
context.model_config = checkpoint.config
|
||||||
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
|
|
||||||
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
|
|
||||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
|
||||||
context.checkpoint = self._checkpoint
|
|
||||||
else:
|
else:
|
||||||
context.checkpoint = Checkpoint(
|
checkpoint = None
|
||||||
state_dict=context.model.state_dict(),
|
state_dict = load_model_weights(self._resume_dir)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
if checkpoint is not None:
|
||||||
|
context.epoch = cfg.start_epoch
|
||||||
|
context.iteration = cfg.start_batch
|
||||||
|
context.checkpoint = checkpoint
|
||||||
|
|
||||||
|
if cfg.lora is not None:
|
||||||
|
inject_lora(
|
||||||
|
model,
|
||||||
|
r=cfg.lora.r,
|
||||||
|
alpha=cfg.lora.alpha,
|
||||||
|
target_modules=set(cfg.lora.target_modules),
|
||||||
)
|
)
|
||||||
|
|
||||||
context.optimizer = self.config.optimizer_fn(context.model)
|
context.optimizer = cfg.optimizer_fn(model)
|
||||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
|
train_dataset = cfg.dataset
|
||||||
self._load_extra_fn(self._checkpoint.extra, context)
|
val_dataset = cfg.val_dataset
|
||||||
|
|
||||||
cfg = self.config
|
if val_dataset is None and cfg.val_split is not None:
|
||||||
sampler_offset = context.iteration * cfg.batch_size
|
n_total = len(cfg.dataset)
|
||||||
|
n_val = max(1, int(n_total * cfg.val_split))
|
||||||
|
n_train = n_total - n_val
|
||||||
|
generator = torch.Generator().manual_seed(cfg.random_seed)
|
||||||
|
train_dataset, val_dataset = random_split(
|
||||||
|
cfg.dataset, [n_train, n_val], generator=generator
|
||||||
|
)
|
||||||
|
|
||||||
|
sampler_offset = context.iteration * cfg.batch_per_device
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
data_source=cfg.dataset,
|
data_source=train_dataset,
|
||||||
start_epoch=context.epoch,
|
start_epoch=context.epoch,
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
seed=cfg.random_seed,
|
seed=cfg.random_seed,
|
||||||
)
|
)
|
||||||
context.dataloader = DataLoader(
|
context.dataloader = DataLoader(
|
||||||
cfg.dataset,
|
train_dataset,
|
||||||
batch_size=cfg.batch_size,
|
batch_size=cfg.batch_per_device,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
pin_memory=cfg.pin_memory,
|
pin_memory=cfg.pin_memory,
|
||||||
prefetch_factor=cfg.prefetch_factor,
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if val_dataset is not None:
|
||||||
|
val_sampler = ResumableDistributedSampler(
|
||||||
|
data_source=val_dataset,
|
||||||
|
start_epoch=0,
|
||||||
|
start_iter=0,
|
||||||
|
seed=cfg.random_seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
context.val_dataloader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=cfg.batch_per_device,
|
||||||
|
sampler=val_sampler,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
pin_memory=cfg.pin_memory,
|
||||||
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||||
|
executor.prepare(
|
||||||
|
model,
|
||||||
|
context.optimizer,
|
||||||
|
context.dataloader,
|
||||||
|
context.scheduler,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if context.checkpoint and context.checkpoint.extra:
|
||||||
|
extra = context.checkpoint.extra
|
||||||
|
for name in ("optimizer", "scheduler"):
|
||||||
|
if name in extra:
|
||||||
|
obj = getattr(context, name, None)
|
||||||
|
if obj is not None:
|
||||||
|
obj.load_state_dict(extra[name])
|
||||||
|
|
||||||
context.strategy = StrategyFactory.create(
|
context.strategy = StrategyFactory.create(
|
||||||
model=context.model,
|
model=context.model,
|
||||||
train_type=self.config.strategy,
|
train_type=cfg.strategy,
|
||||||
device=device,
|
device=device,
|
||||||
**self.config.extra_kwargs,
|
executor=executor,
|
||||||
|
model_fn=cfg.model_fn,
|
||||||
|
**cfg.extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
from itertools import batched
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from astrai.config import TrainConfig
|
from astrai.config import TrainConfig
|
||||||
from astrai.parallel.setup import spawn_parallel_fn
|
from astrai.parallel.setup import spawn_parallel_fn
|
||||||
from astrai.serialization import Checkpoint
|
|
||||||
from astrai.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
CallbackFactory,
|
CallbackFactory,
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
|
|
@ -26,17 +24,28 @@ class Trainer:
|
||||||
|
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
cfg = self.train_config
|
cfg = self.train_config
|
||||||
return [
|
callbacks = [
|
||||||
|
CallbackFactory.create(
|
||||||
|
"gradient_checkpointing",
|
||||||
|
modules=cfg.gradient_checkpointing_modules,
|
||||||
|
),
|
||||||
|
CallbackFactory.create(
|
||||||
|
"checkpoint",
|
||||||
|
cfg.ckpt_dir,
|
||||||
|
cfg.ckpt_interval,
|
||||||
|
),
|
||||||
|
CallbackFactory.create(
|
||||||
|
"metric_logger",
|
||||||
|
log_dir=cfg.log_dir,
|
||||||
|
save_interval=cfg.ckpt_interval,
|
||||||
|
log_interval=cfg.log_interval,
|
||||||
|
metrics=cfg.metrics,
|
||||||
|
),
|
||||||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||||
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
|
|
||||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
|
CallbackFactory.create("validation"),
|
||||||
]
|
]
|
||||||
|
return callbacks
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
|
||||||
return (
|
|
||||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
|
|
@ -44,45 +53,33 @@ class Trainer:
|
||||||
if method:
|
if method:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
def _trainer_loop(self, resume_dir: Optional[str] = None):
|
||||||
config = self.train_config
|
context = (
|
||||||
spawn_parallel_fn(
|
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
|
||||||
self._train_impl,
|
|
||||||
backend=config.backend,
|
|
||||||
world_size=config.nprocs,
|
|
||||||
master_addr=config.master_addr,
|
|
||||||
master_port=config.master_port,
|
|
||||||
device_type=config.device_type,
|
|
||||||
checkpoint=checkpoint,
|
|
||||||
)
|
)
|
||||||
|
executor = context.executor
|
||||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
|
||||||
context = self._build_context(checkpoint)
|
|
||||||
self._call_callbacks("on_train_begin", context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.model.train()
|
context.model.train()
|
||||||
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
|
||||||
|
|
||||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
for epoch in range(context.epoch, context.config.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for steps in batched(context.dataloader, accumulation_steps):
|
for batch in context.dataloader:
|
||||||
self._call_callbacks("on_step_begin", context)
|
|
||||||
|
|
||||||
step_batch_nums = len(steps)
|
|
||||||
for batch in steps:
|
|
||||||
self._call_callbacks("on_batch_begin", context)
|
self._call_callbacks("on_batch_begin", context)
|
||||||
|
|
||||||
|
with executor.accumulate(context.model):
|
||||||
loss = context.strategy(batch)
|
loss = context.strategy(batch)
|
||||||
context.loss = loss.item()
|
context.loss = loss.item()
|
||||||
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
|
executor.backward(stand_loss)
|
||||||
context.iteration += 1
|
context.iteration += 1
|
||||||
|
|
||||||
stand_loss = loss / step_batch_nums
|
|
||||||
stand_loss.backward()
|
|
||||||
self._call_callbacks("on_batch_end", context)
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
self._call_callbacks("on_step_end", context)
|
if executor.sync_gradients:
|
||||||
|
self._call_callbacks("on_optimizer_step", context)
|
||||||
context.optimizer.step()
|
context.optimizer.step()
|
||||||
context.optimizer.zero_grad()
|
context.optimizer.zero_grad()
|
||||||
|
|
||||||
|
|
@ -92,8 +89,21 @@ class Trainer:
|
||||||
self._call_callbacks("on_epoch_end", context)
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
logger.error("Training failed: %s", str(e), exc_info=True)
|
||||||
self._call_callbacks("on_error", context)
|
self._call_callbacks("on_error", context)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks("on_train_end", context)
|
self._call_callbacks("on_train_end", context)
|
||||||
|
|
||||||
|
def train(self, resume_dir: Optional[str] = None):
|
||||||
|
cfg = self.train_config
|
||||||
|
spawn_parallel_fn(
|
||||||
|
self._trainer_loop,
|
||||||
|
backend=cfg.backend,
|
||||||
|
world_size=cfg.nprocs,
|
||||||
|
master_addr=cfg.master_addr,
|
||||||
|
master_port=cfg.master_port,
|
||||||
|
device_type=cfg.device_type,
|
||||||
|
start_method=cfg.start_method,
|
||||||
|
resume_dir=resume_dir,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
services:
|
services:
|
||||||
server:
|
server:
|
||||||
build: .
|
build:
|
||||||
image: astrai:latest
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
user: "${UID:-1000}:${GID:-1000}"
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./params:/app/params:ro
|
- ./params:/app/params:ro
|
||||||
- ./checkpoints:/app/checkpoints
|
|
||||||
command: python -m scripts.tools.server --port 8000 --device cuda
|
command: python -m scripts.tools.server --port 8000 --device cuda
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
|
|
@ -25,13 +26,14 @@ services:
|
||||||
|
|
||||||
server-cpu:
|
server-cpu:
|
||||||
profiles: [cpu]
|
profiles: [cpu]
|
||||||
build: .
|
build:
|
||||||
image: astrai:latest
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
user: "${UID:-1000}:${GID:-1000}"
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./params:/app/params:ro
|
- ./params:/app/params:ro
|
||||||
- ./checkpoints:/app/checkpoints
|
|
||||||
command: python -m scripts.tools.server --port 8000 --device cpu
|
command: python -m scripts.tools.server --port 8000 --device cpu
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
def generate_text():
|
def generate_text():
|
||||||
# Load model from pretrained
|
|
||||||
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
@ -22,16 +21,15 @@ def generate_text():
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
response = engine.generate(
|
for token in engine.generate(
|
||||||
prompt=query,
|
prompt=query,
|
||||||
stream=False,
|
stream=True,
|
||||||
max_tokens=2048,
|
max_tokens=2048,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
)
|
):
|
||||||
|
print(token, end="", flush=True)
|
||||||
print(response)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ NC='\033[0m' # No Color
|
||||||
IMAGE_NAME="astrai"
|
IMAGE_NAME="astrai"
|
||||||
IMAGE_TAG="latest"
|
IMAGE_TAG="latest"
|
||||||
REGISTRY=""
|
REGISTRY=""
|
||||||
|
CONTAINER_ID=""
|
||||||
|
|
||||||
# Print colored messages
|
# Print colored messages
|
||||||
print_info() {
|
print_info() {
|
||||||
|
|
@ -175,6 +176,10 @@ main() {
|
||||||
PORT="$2"
|
PORT="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
|
--container)
|
||||||
|
CONTAINER_ID="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
--gpu)
|
--gpu)
|
||||||
GPU=true
|
GPU=true
|
||||||
shift
|
shift
|
||||||
|
|
@ -197,6 +202,7 @@ main() {
|
||||||
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
|
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
|
||||||
echo " --context PATH Build context (default: .)"
|
echo " --context PATH Build context (default: .)"
|
||||||
echo " --port PORT Port for run (default: 8000)"
|
echo " --port PORT Port for run (default: 8000)"
|
||||||
|
echo " --container ID Container ID for logs"
|
||||||
echo " --gpu Enable GPU support"
|
echo " --gpu Enable GPU support"
|
||||||
echo " --help Show this help message"
|
echo " --help Show this help message"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
@ -205,6 +211,7 @@ main() {
|
||||||
echo " $0 build --tag v1.0.0"
|
echo " $0 build --tag v1.0.0"
|
||||||
echo " $0 run --port 8080"
|
echo " $0 run --port 8080"
|
||||||
echo " $0 run --gpu"
|
echo " $0 run --gpu"
|
||||||
|
echo " $0 logs --container abc123"
|
||||||
echo " $0 push --registry ghcr.io/username"
|
echo " $0 push --registry ghcr.io/username"
|
||||||
exit 0
|
exit 0
|
||||||
;;
|
;;
|
||||||
|
|
@ -237,7 +244,7 @@ main() {
|
||||||
show_info
|
show_info
|
||||||
;;
|
;;
|
||||||
logs)
|
logs)
|
||||||
show_logs "$2"
|
show_logs "$CONTAINER_ID"
|
||||||
;;
|
;;
|
||||||
"")
|
"")
|
||||||
print_error "No command specified. Use --help for usage"
|
print_error "No command specified. Use --help for usage"
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
"""Benchmark Transformer with KVCache"""
|
"""Benchmark AutoRegressiveLM with KVCache"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config import AutoRegressiveLMConfig
|
||||||
from astrai.inference import KVCache
|
from astrai.inference import KVCache
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -21,7 +21,7 @@ class BenchmarkResult:
|
||||||
class GenerationBenchmark:
|
class GenerationBenchmark:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ModelConfig,
|
config: AutoRegressiveLMConfig,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
page_size: int = 128,
|
page_size: int = 128,
|
||||||
|
|
@ -29,7 +29,7 @@ class GenerationBenchmark:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.device = device
|
self.device = device
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
head_dim = config.dim // config.n_heads
|
head_dim = config.dim // config.n_heads
|
||||||
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||||
|
|
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config = ModelConfig(
|
config = AutoRegressiveLMConfig(
|
||||||
vocab_size=10000,
|
vocab_size=10000,
|
||||||
dim=1536,
|
dim=1536,
|
||||||
n_heads=24,
|
n_heads=24,
|
||||||
|
|
@ -230,7 +230,7 @@ if __name__ == "__main__":
|
||||||
benchmark = GenerationBenchmark(config)
|
benchmark = GenerationBenchmark(config)
|
||||||
|
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("Running Transformer Generation Benchmark (KVCache)")
|
print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
prefill_result = benchmark.run_prefill_benchmark(
|
prefill_result = benchmark.run_prefill_benchmark(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,336 @@
|
||||||
|
"""HumanEval code generation benchmark.
|
||||||
|
|
||||||
|
Generates n completions per problem, extracts function bodies, executes
|
||||||
|
against hidden tests, and computes pass@k.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
python scripts/tools/evaluate_humaneval.py --param_path ./params \
|
||||||
|
--data_path HumanEval.jsonl.gz --output results.json \
|
||||||
|
--num_samples 200 --temperature 0.8 --max_tokens 512
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from math import prod
|
||||||
|
from multiprocessing import Process, Queue
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from astrai.inference import InferenceEngine
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
HUMANEVAL_URL = (
|
||||||
|
"https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
_STOP_SEQUENCES = [
|
||||||
|
"\nclass ",
|
||||||
|
"\ndef ",
|
||||||
|
"\n# ",
|
||||||
|
"\nif __name__",
|
||||||
|
"\nprint(",
|
||||||
|
"\n\n\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _download_humaneval(data_path: str):
|
||||||
|
if os.path.exists(data_path):
|
||||||
|
return
|
||||||
|
import gzip
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
|
||||||
|
print(f"Downloading HumanEval from {HUMANEVAL_URL} ...")
|
||||||
|
tmp = data_path + ".tmp"
|
||||||
|
urllib.request.urlretrieve(HUMANEVAL_URL, tmp)
|
||||||
|
with gzip.open(tmp, "rb") as f_in:
|
||||||
|
with open(data_path, "wb") as f_out:
|
||||||
|
f_out.write(f_in.read())
|
||||||
|
os.remove(tmp)
|
||||||
|
print(f" saved to {data_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_problems(data_path: str) -> List[dict]:
|
||||||
|
problems = []
|
||||||
|
with open(data_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
problems.append(json.loads(line))
|
||||||
|
return problems
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_function_body(code: str, entry_point: str) -> Optional[str]:
|
||||||
|
"""Extract the function body from a completion."""
|
||||||
|
pattern = rf"def\s+{re.escape(entry_point)}\b[^:]*:"
|
||||||
|
match = re.search(pattern, code)
|
||||||
|
if not match:
|
||||||
|
# Use the full code as-is if we can't find the function
|
||||||
|
return code
|
||||||
|
|
||||||
|
body_start = match.end()
|
||||||
|
lines = code[body_start:].split("\n")
|
||||||
|
body_lines = []
|
||||||
|
started = False
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
stripped = line.rstrip()
|
||||||
|
if not stripped and not started:
|
||||||
|
continue
|
||||||
|
if not stripped and started:
|
||||||
|
body_lines.append("")
|
||||||
|
continue
|
||||||
|
if not started:
|
||||||
|
started = True
|
||||||
|
if stripped.lstrip() == stripped and started:
|
||||||
|
break
|
||||||
|
body_lines.append(stripped)
|
||||||
|
|
||||||
|
body = "\n".join(body_lines)
|
||||||
|
if not body.strip():
|
||||||
|
return None
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
def _trim_stop_sequences(text: str) -> str:
|
||||||
|
for stop in _STOP_SEQUENCES:
|
||||||
|
idx = text.find(stop)
|
||||||
|
if idx != -1:
|
||||||
|
text = text[:idx]
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_code(problem: dict, completion: str, timeout: float = 3.0) -> bool:
|
||||||
|
"""Run the completion against hidden tests in a subprocess."""
|
||||||
|
|
||||||
|
def _worker(queue, full_code):
|
||||||
|
try:
|
||||||
|
namespace = {}
|
||||||
|
exec(full_code, namespace)
|
||||||
|
check = namespace.get("check")
|
||||||
|
if check is None:
|
||||||
|
queue.put(False)
|
||||||
|
return
|
||||||
|
check(namespace.get(problem["entry_point"]))
|
||||||
|
queue.put(True)
|
||||||
|
except Exception:
|
||||||
|
queue.put(False)
|
||||||
|
|
||||||
|
full_code = problem["prompt"] + completion + "\n" + problem["test"]
|
||||||
|
|
||||||
|
queue: Queue = Queue()
|
||||||
|
proc = Process(target=_worker, args=(queue, full_code))
|
||||||
|
proc.start()
|
||||||
|
proc.join(timeout)
|
||||||
|
|
||||||
|
if proc.is_alive():
|
||||||
|
proc.terminate()
|
||||||
|
proc.join()
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
return queue.get_nowait()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _pass_at_k(n: int, c: int, k: int) -> float:
|
||||||
|
"""Unbiased estimator of pass@k."""
|
||||||
|
if n - c < k:
|
||||||
|
return 1.0
|
||||||
|
return 1.0 - float(prod(1.0 - k / np.arange(n - c + 1, n + 1)))
|
||||||
|
|
||||||
|
|
||||||
|
def _deduplicate(completions: List[str]) -> List[str]:
|
||||||
|
seen = set()
|
||||||
|
unique = []
|
||||||
|
for c in completions:
|
||||||
|
if c not in seen:
|
||||||
|
seen.add(c)
|
||||||
|
unique.append(c)
|
||||||
|
return unique
|
||||||
|
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
engine: InferenceEngine,
|
||||||
|
prompt: str,
|
||||||
|
num_samples: int,
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
|
batch_size: int,
|
||||||
|
) -> List[str]:
|
||||||
|
batches = [prompt] * min(batch_size, num_samples)
|
||||||
|
completions = []
|
||||||
|
remaining = num_samples
|
||||||
|
|
||||||
|
while remaining > 0:
|
||||||
|
current = min(batch_size, remaining)
|
||||||
|
batch_prompts = batches[:current]
|
||||||
|
outputs = engine.generate(
|
||||||
|
prompt=batch_prompts,
|
||||||
|
stream=False,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
if isinstance(outputs, str):
|
||||||
|
outputs = [outputs]
|
||||||
|
completions.extend(outputs)
|
||||||
|
remaining -= current
|
||||||
|
|
||||||
|
return _deduplicate(completions)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
engine: InferenceEngine,
|
||||||
|
problems: List[dict],
|
||||||
|
num_samples: int,
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
|
batch_size: int,
|
||||||
|
k_values: Tuple[int, ...] = (1, 10, 100),
|
||||||
|
) -> Dict:
|
||||||
|
results = {}
|
||||||
|
all_pass_at_k = {k: [] for k in k_values}
|
||||||
|
|
||||||
|
for problem in tqdm.tqdm(problems, desc="HumanEval", unit="problem"):
|
||||||
|
task_id = problem["task_id"]
|
||||||
|
prompt = problem["prompt"]
|
||||||
|
entry_point = problem["entry_point"]
|
||||||
|
|
||||||
|
raw_completions = _generate(
|
||||||
|
engine,
|
||||||
|
prompt,
|
||||||
|
num_samples,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
top_k,
|
||||||
|
batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
completions = []
|
||||||
|
for raw in raw_completions:
|
||||||
|
trimmed = _trim_stop_sequences(raw)
|
||||||
|
body = _extract_function_body(trimmed, entry_point)
|
||||||
|
if body:
|
||||||
|
completions.append(body)
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
for comp in completions:
|
||||||
|
if _execute_code(problem, comp):
|
||||||
|
passed += 1
|
||||||
|
|
||||||
|
n = len(completions)
|
||||||
|
c = passed
|
||||||
|
result = {"task_id": task_id, "n": n, "passed": c}
|
||||||
|
for k in k_values:
|
||||||
|
result[f"pass@{k}"] = round(_pass_at_k(n, c, k), 4)
|
||||||
|
all_pass_at_k[k].append(_pass_at_k(n, c, k))
|
||||||
|
results[task_id] = result
|
||||||
|
|
||||||
|
summary = {}
|
||||||
|
for k in k_values:
|
||||||
|
vals = all_pass_at_k[k]
|
||||||
|
summary[f"pass@{k}"] = round(float(np.mean(vals)), 4)
|
||||||
|
results["_summary"] = summary
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="HumanEval benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--param_path", type=str, default="./params", help="Model directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_path",
|
||||||
|
type=str,
|
||||||
|
default="./humaneval/HumanEval.jsonl",
|
||||||
|
help="HumanEval JSONL file (auto-download if missing)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--output", type=str, default=None, help="Output JSON path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_samples",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="Completions per problem",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_tokens", type=int, default=512, help="Max generation tokens"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature", type=float, default=0.8, help="Sampling temperature"
|
||||||
|
)
|
||||||
|
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
|
||||||
|
parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="Inference batch size"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--problems",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Specific problem indices (0-based)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
_download_humaneval(args.data_path)
|
||||||
|
problems = _load_problems(args.data_path)
|
||||||
|
if args.problems:
|
||||||
|
problems = [problems[i] for i in args.problems if i < len(problems)]
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(args.param_path)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
||||||
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
engine = InferenceEngine(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = evaluate(
|
||||||
|
engine=engine,
|
||||||
|
problems=problems,
|
||||||
|
num_samples=args.num_samples,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.top_p,
|
||||||
|
top_k=args.top_k,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
k_values=(1, 10, 100),
|
||||||
|
)
|
||||||
|
|
||||||
|
summary = results.pop("_summary")
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
for k, v in summary.items():
|
||||||
|
print(f" {k}: {v:.2%}")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
results["_summary"] = summary
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||||
|
print(f"Results saved to {args.output}")
|
||||||
|
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,319 @@
|
||||||
|
"""MMLU evaluation via log-likelihood ranking."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tarfile
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
|
||||||
|
MMLU_SUBJECTS = [
|
||||||
|
"abstract_algebra",
|
||||||
|
"anatomy",
|
||||||
|
"astronomy",
|
||||||
|
"business_ethics",
|
||||||
|
"clinical_knowledge",
|
||||||
|
"college_biology",
|
||||||
|
"college_chemistry",
|
||||||
|
"college_computer_science",
|
||||||
|
"college_mathematics",
|
||||||
|
"college_medicine",
|
||||||
|
"college_physics",
|
||||||
|
"computer_security",
|
||||||
|
"conceptual_physics",
|
||||||
|
"econometrics",
|
||||||
|
"electrical_engineering",
|
||||||
|
"elementary_mathematics",
|
||||||
|
"formal_logic",
|
||||||
|
"global_facts",
|
||||||
|
"high_school_biology",
|
||||||
|
"high_school_chemistry",
|
||||||
|
"high_school_computer_science",
|
||||||
|
"high_school_european_history",
|
||||||
|
"high_school_geography",
|
||||||
|
"high_school_government_and_politics",
|
||||||
|
"high_school_macroeconomics",
|
||||||
|
"high_school_mathematics",
|
||||||
|
"high_school_microeconomics",
|
||||||
|
"high_school_physics",
|
||||||
|
"high_school_psychology",
|
||||||
|
"high_school_statistics",
|
||||||
|
"high_school_us_history",
|
||||||
|
"high_school_world_history",
|
||||||
|
"human_aging",
|
||||||
|
"human_sexuality",
|
||||||
|
"international_law",
|
||||||
|
"jurisprudence",
|
||||||
|
"logical_fallacies",
|
||||||
|
"machine_learning",
|
||||||
|
"management",
|
||||||
|
"marketing",
|
||||||
|
"medical_genetics",
|
||||||
|
"miscellaneous",
|
||||||
|
"moral_disputes",
|
||||||
|
"moral_scenarios",
|
||||||
|
"nutrition",
|
||||||
|
"philosophy",
|
||||||
|
"prehistory",
|
||||||
|
"professional_accounting",
|
||||||
|
"professional_law",
|
||||||
|
"professional_medicine",
|
||||||
|
"professional_psychology",
|
||||||
|
"public_relations",
|
||||||
|
"security_studies",
|
||||||
|
"sociology",
|
||||||
|
"us_foreign_policy",
|
||||||
|
"virology",
|
||||||
|
"world_religions",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _download_and_extract(url: str, data_dir: str):
|
||||||
|
tar_path = os.path.join(data_dir, "data.tar")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
print(f"Downloading MMLU data from {url}...")
|
||||||
|
resp = requests.get(url, stream=True, timeout=300)
|
||||||
|
resp.raise_for_status()
|
||||||
|
total = int(resp.headers.get("content-length", 0))
|
||||||
|
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
|
||||||
|
with open(tar_path, "wb") as f:
|
||||||
|
for chunk in resp.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
bar.update(len(chunk))
|
||||||
|
print("Extracting...")
|
||||||
|
with tarfile.open(tar_path, "r") as tf:
|
||||||
|
tf.extractall(data_dir)
|
||||||
|
os.remove(tar_path)
|
||||||
|
|
||||||
|
|
||||||
|
def download_mmlu(data_dir: str):
|
||||||
|
_download_and_extract(MMLU_URL, data_dir)
|
||||||
|
src = os.path.join(data_dir, "data")
|
||||||
|
if os.path.exists(src):
|
||||||
|
for item in os.listdir(src):
|
||||||
|
src_item = os.path.join(src, item)
|
||||||
|
dst_item = os.path.join(data_dir, item)
|
||||||
|
if os.path.exists(dst_item):
|
||||||
|
if os.path.isdir(dst_item):
|
||||||
|
shutil.rmtree(dst_item)
|
||||||
|
else:
|
||||||
|
os.remove(dst_item)
|
||||||
|
os.rename(src_item, dst_item)
|
||||||
|
os.rmdir(src)
|
||||||
|
print(f"MMLU data saved to {data_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_prefix(text: str, prefix: str) -> str:
|
||||||
|
if text.startswith(prefix):
|
||||||
|
return text[len(prefix) :].strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def load_csv(path: str) -> list[dict]:
|
||||||
|
data = []
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
for row in csv.reader(f):
|
||||||
|
if len(row) < 6:
|
||||||
|
continue
|
||||||
|
if row[0].strip().lower() == "question":
|
||||||
|
continue
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"question": row[0].strip(),
|
||||||
|
"A": _strip_prefix(row[1].strip(), "A)"),
|
||||||
|
"B": _strip_prefix(row[2].strip(), "B)"),
|
||||||
|
"C": _strip_prefix(row[3].strip(), "C)"),
|
||||||
|
"D": _strip_prefix(row[4].strip(), "D)"),
|
||||||
|
"answer": row[5].strip(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
|
||||||
|
) -> str:
|
||||||
|
prompt = ""
|
||||||
|
if n_shot > 0 and dev_data:
|
||||||
|
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
|
||||||
|
for item in dev_data[:n_shot]:
|
||||||
|
prompt += f"Question: {item['question']}\n"
|
||||||
|
for k in ("A", "B", "C", "D"):
|
||||||
|
prompt += f"{k}. {item[k]}\n"
|
||||||
|
prompt += f"Answer: {item['answer']}\n\n"
|
||||||
|
prompt += f"Question: {question}\n"
|
||||||
|
for k in ("A", "B", "C", "D"):
|
||||||
|
prompt += f"{k}. {choices[k]}\n"
|
||||||
|
prompt += "Answer:"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def apply_chat(
|
||||||
|
tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None
|
||||||
|
) -> str:
|
||||||
|
"""Wrap raw MMLU prompt in the model's chat template format.
|
||||||
|
|
||||||
|
For few-shot, prepend example Q&A pairs as a second user/assistant exchange.
|
||||||
|
"""
|
||||||
|
messages = []
|
||||||
|
if n_shot > 0 and dev_data:
|
||||||
|
for item in dev_data[:n_shot]:
|
||||||
|
q = f"Question: {item['question']}\n"
|
||||||
|
for k in ("A", "B", "C", "D"):
|
||||||
|
q += f"{k}. {item[k]}\n"
|
||||||
|
q += "Answer:"
|
||||||
|
messages.append({"role": "user", "content": q})
|
||||||
|
messages.append({"role": "assistant", "content": item["answer"]})
|
||||||
|
messages.append({"role": "user", "content": raw_prompt})
|
||||||
|
return tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def choice_logprob(
|
||||||
|
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
|
||||||
|
) -> float:
|
||||||
|
choice_text = choice_letter
|
||||||
|
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
|
||||||
|
input_ids = context_ids + choice_ids
|
||||||
|
max_len = model.config.max_len
|
||||||
|
if len(input_ids) > max_len:
|
||||||
|
overflow = len(input_ids) - max_len
|
||||||
|
input_ids = input_ids[overflow:]
|
||||||
|
ctx_len = len(input_ids) - len(choice_ids)
|
||||||
|
else:
|
||||||
|
ctx_len = len(context_ids)
|
||||||
|
|
||||||
|
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
|
||||||
|
with torch.inference_mode():
|
||||||
|
logits = model(input_tensor)["logits"][0]
|
||||||
|
|
||||||
|
score = 0.0
|
||||||
|
for i, tid in enumerate(choice_ids):
|
||||||
|
pos = ctx_len - 1 + i
|
||||||
|
if pos >= len(logits):
|
||||||
|
break
|
||||||
|
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_subject(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
subject: str,
|
||||||
|
test_data: list[dict],
|
||||||
|
dev_data: list[dict] | None,
|
||||||
|
device: str,
|
||||||
|
n_shot: int,
|
||||||
|
) -> tuple[float, int, int]:
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
|
||||||
|
raw_prompt = build_prompt(
|
||||||
|
item["question"], item, subject, n_shot, dev_data or []
|
||||||
|
)
|
||||||
|
context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or [])
|
||||||
|
context_ids = tokenizer.encode(context)
|
||||||
|
scores = {
|
||||||
|
c: choice_logprob(model, tokenizer, context_ids, c, device)
|
||||||
|
for c in ("A", "B", "C", "D")
|
||||||
|
}
|
||||||
|
if max(scores, key=scores.get) == item["answer"]:
|
||||||
|
correct += 1
|
||||||
|
total += 1
|
||||||
|
return correct / total, correct, total
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="MMLU evaluation")
|
||||||
|
parser.add_argument(
|
||||||
|
"--param_path", type=str, default="./params", help="Model directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
|
||||||
|
)
|
||||||
|
parser.add_argument("--download", action="store_true", help="Download MMLU data")
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
|
||||||
|
)
|
||||||
|
parser.add_argument("--output", type=str, help="Output JSON path")
|
||||||
|
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
help="Device",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=str,
|
||||||
|
default="bfloat16" if torch.cuda.is_available() else "float32",
|
||||||
|
help="Torch dtype",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.download or not os.path.exists(args.data_dir):
|
||||||
|
download_mmlu(args.data_dir)
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(args.param_path)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
||||||
|
device = args.device
|
||||||
|
dtype = getattr(torch, args.dtype)
|
||||||
|
model.to(device=device, dtype=dtype)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
subjects = args.subjects or MMLU_SUBJECTS
|
||||||
|
results = {}
|
||||||
|
total_correct = 0
|
||||||
|
total_questions = 0
|
||||||
|
|
||||||
|
for subject in subjects:
|
||||||
|
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
|
||||||
|
test_path = os.path.join(
|
||||||
|
args.data_dir, args.split, f"{subject}_{args.split}.csv"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.exists(test_path):
|
||||||
|
print(f" Skipping {subject}: test file not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
|
||||||
|
test_data = load_csv(test_path)
|
||||||
|
|
||||||
|
acc, corr, tot = evaluate_subject(
|
||||||
|
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
|
||||||
|
)
|
||||||
|
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
|
||||||
|
total_correct += corr
|
||||||
|
total_questions += tot
|
||||||
|
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
|
||||||
|
|
||||||
|
overall = total_correct / total_questions if total_questions else 0
|
||||||
|
print(f"\n{'=' * 70}")
|
||||||
|
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
|
||||||
|
results["_overall"] = {
|
||||||
|
"accuracy": round(overall, 4),
|
||||||
|
"correct": total_correct,
|
||||||
|
"total": total_questions,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
print(f"Results saved to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||||
):
|
):
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
model = AutoModel.from_pretrained(model_dir)
|
model = AutoModel.from_pretrained(param_path)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
with open(input_file, "r", encoding="utf-8") as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
|
|
@ -44,8 +44,8 @@ def process_file(
|
||||||
|
|
||||||
for seq in batch_encoded:
|
for seq in batch_encoded:
|
||||||
pad_len = max_len - len(seq)
|
pad_len = max_len - len(seq)
|
||||||
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
padded_seq = seq + [tokenizer.pad_id] * pad_len
|
||||||
mask = [False] * pad_len + [True] * len(seq)
|
mask = [True] * len(seq) + [False] * pad_len
|
||||||
padded_ids.append(padded_seq)
|
padded_ids.append(padded_seq)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
|
|
||||||
|
|
@ -88,7 +88,7 @@ def process_file(
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_dir", type=str, required=True, help="Path to the model directory."
|
"--param_path", type=str, required=True, help="Path to the model directory."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input_file", type=str, required=True, help="Path to the input file."
|
"--input_file", type=str, required=True, help="Path to the input file."
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
|
from astrai.preprocessing.pipeline import Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
|
||||||
|
)
|
||||||
|
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", "-c", required=True, help="Path to pipeline config JSON"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_path",
|
||||||
|
default="params",
|
||||||
|
help="Path to tokenizer directory (default: params)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config = PipelineConfig.from_json(args.config)
|
||||||
|
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=args.inputs,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
tokenizer_path=args.tokenizer_path,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -18,7 +18,7 @@ def main():
|
||||||
"--reload", action="store_true", help="Enable auto-reload for development"
|
"--reload", action="store_true", help="Enable auto-reload for development"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--param-path",
|
"--param_path",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to model parameters (default: project_root/params)",
|
help="Path to model parameters (default: project_root/params)",
|
||||||
|
|
|
||||||
|
|
@ -2,22 +2,19 @@ import argparse
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import safetensors.torch as st
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
from astrai.config import ModelConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import Transformer
|
from astrai.model import AutoRegressiveLM
|
||||||
from astrai.parallel import get_rank
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_type",
|
"--train_type",
|
||||||
|
|
@ -42,18 +39,20 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||||
)
|
)
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--accumulation_steps",
|
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grad_accum_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="Number of iterations between each optimizer step.",
|
help="Number of iterations between each optimizer step.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--warmup_steps",
|
"--warmup_ratio",
|
||||||
type=int,
|
type=float,
|
||||||
default=1000,
|
default=0.05,
|
||||||
help="Number of warmup steps for LR scheduler.",
|
help="Fraction of total steps used for LR warmup.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||||
|
|
@ -68,13 +67,13 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--adamw_beta1",
|
"--adamw_beta1",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.9,
|
default=0.9,
|
||||||
help="Beta values for AdamW optimizer.",
|
help="Beta1 for AdamW optimizer.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adamw_beta2",
|
"--adamw_beta2",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.95,
|
default=0.95,
|
||||||
help="Beta values for AdamW optimizer.",
|
help="Beta2 for AdamW optimizer.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adamw_weight_decay",
|
"--adamw_weight_decay",
|
||||||
|
|
@ -114,9 +113,15 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--label_smoothing",
|
"--label_smoothing",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.1,
|
default=0.05,
|
||||||
help="cross_entropy function label smoothing parameter",
|
help="cross_entropy function label smoothing parameter",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient_checkpointing",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=False,
|
||||||
|
help="Enable activation checkpointing for DecoderBlock modules.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ckpt_interval",
|
"--ckpt_interval",
|
||||||
|
|
@ -130,6 +135,36 @@ def parse_args() -> argparse.Namespace:
|
||||||
default="checkpoint",
|
default="checkpoint",
|
||||||
help="Directory to save checkpoints.",
|
help="Directory to save checkpoints.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--val_split",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Ratio to split from training dataset for validation (e.g. 0.05).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--val_step",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of optimizer steps between validation runs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metrics",
|
||||||
|
nargs="*",
|
||||||
|
default=["loss", "lr"],
|
||||||
|
help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log_dir",
|
||||||
|
type=str,
|
||||||
|
default="checkpoint/logs",
|
||||||
|
help="Directory for metric logs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log_interval",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Number of batch iterations between metric logs.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--grpo_sync_interval",
|
"--grpo_sync_interval",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -143,31 +178,53 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--start_batch", type=int, default=0, help="Start batch for training."
|
"--start_batch", type=int, default=0, help="Start batch for training."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--master_addr",
|
||||||
|
type=str,
|
||||||
|
default="localhost",
|
||||||
|
help="Master node address for distributed training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--master_port",
|
||||||
|
type=str,
|
||||||
|
default="29500",
|
||||||
|
help="Master node port for distributed training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
type=str,
|
||||||
|
default="nccl",
|
||||||
|
help="Distributed training backend.",
|
||||||
|
)
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--parallel_mode",
|
||||||
|
type=str,
|
||||||
|
default="none",
|
||||||
|
choices=["none", "ddp", "fsdp"],
|
||||||
|
help="Parallel training strategy (none, ddp, fsdp).",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--start_method",
|
||||||
|
type=str,
|
||||||
|
default="spawn",
|
||||||
|
choices=["spawn", "fork", "forkserver"],
|
||||||
|
help="Multiprocessing start method.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def create_model(config):
|
||||||
local_rank = get_rank()
|
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
||||||
ddp_model = DDP(
|
|
||||||
model,
|
|
||||||
device_ids=[local_rank],
|
|
||||||
output_device=local_rank,
|
|
||||||
static_graph=True,
|
|
||||||
find_unused_parameters=False,
|
|
||||||
gradient_as_bucket_view=True,
|
|
||||||
broadcast_buffers=False,
|
|
||||||
)
|
|
||||||
return ddp_model
|
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
||||||
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -177,8 +234,21 @@ def create_scheduler(
|
||||||
return SchedulerFactory.create(optimizer, **kwargs)
|
return SchedulerFactory.create(optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
def compute_total_steps(
|
||||||
return model.module.state_dict()
|
dataset_len: int,
|
||||||
|
n_epoch: int,
|
||||||
|
batch_per_device: int,
|
||||||
|
nprocs: int,
|
||||||
|
grad_accum_steps: int,
|
||||||
|
) -> int:
|
||||||
|
|
||||||
|
def ceil_div(a: int, b: int) -> int:
|
||||||
|
return (a + b - 1) // b
|
||||||
|
|
||||||
|
samples_per_replica = ceil_div(dataset_len, nprocs)
|
||||||
|
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
|
||||||
|
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
|
||||||
|
return total_steps
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
|
|
@ -187,13 +257,18 @@ def train(
|
||||||
data_root_path: str,
|
data_root_path: str,
|
||||||
max_lr: float,
|
max_lr: float,
|
||||||
n_epoch: int,
|
n_epoch: int,
|
||||||
batch_size: int,
|
batch_per_device: int,
|
||||||
start_epoch: int,
|
start_epoch: int,
|
||||||
start_batch: int,
|
start_batch: int,
|
||||||
accumulation_steps: int,
|
grad_accum_steps: int,
|
||||||
warmup_steps: int,
|
warmup_ratio: float,
|
||||||
ckpt_interval: int,
|
ckpt_interval: int,
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
|
val_split: float,
|
||||||
|
val_step: int,
|
||||||
|
metrics: list[str],
|
||||||
|
log_dir: str,
|
||||||
|
log_interval: int,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
grpo_clip_eps: float,
|
grpo_clip_eps: float,
|
||||||
grpo_kl_coef: float,
|
grpo_kl_coef: float,
|
||||||
|
|
@ -207,36 +282,31 @@ def train(
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
|
gradient_checkpointing: bool,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
|
parallel_mode: str,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
|
backend: str,
|
||||||
|
master_addr: str,
|
||||||
|
master_port: str,
|
||||||
|
start_method: str,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
if nprocs > 1 and parallel_mode == "none":
|
||||||
|
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
config = ModelConfig()
|
|
||||||
config_path = os.path.join(param_path, "config.json")
|
config_path = os.path.join(param_path, "config.json")
|
||||||
if os.path.exists(config_path):
|
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
config.load(config_path)
|
|
||||||
|
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = config.max_len
|
window_size = config.max_len
|
||||||
|
|
||||||
# Create bare Transformer (for training, no tokenizer needed)
|
|
||||||
model = Transformer(config)
|
|
||||||
|
|
||||||
# Load weights if available
|
|
||||||
weights_path = os.path.join(param_path, "model.safetensors")
|
|
||||||
if os.path.exists(weights_path):
|
|
||||||
state_dict = st.load_file(weights_path)
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
|
|
||||||
model = model.to(dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {
|
||||||
"dpo_beta": dpo_beta,
|
"beta": dpo_beta,
|
||||||
"label_smoothing": label_smoothing,
|
"label_smoothing": label_smoothing,
|
||||||
"clip_eps": grpo_clip_eps,
|
"clip_eps": grpo_clip_eps,
|
||||||
"kl_coef": grpo_kl_coef,
|
"kl_coef": grpo_kl_coef,
|
||||||
|
|
@ -244,6 +314,12 @@ def train(
|
||||||
"sync_interval": grpo_sync_interval,
|
"sync_interval": grpo_sync_interval,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
executor_kwargs = {
|
||||||
|
"gradient_as_bucket_view": True,
|
||||||
|
"broadcast_buffers": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_fn = partial(create_model, config)
|
||||||
dataset = DatasetFactory.load(
|
dataset = DatasetFactory.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=data_root_path,
|
load_path=data_root_path,
|
||||||
|
|
@ -260,42 +336,58 @@ def train(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
|
total_steps = compute_total_steps(
|
||||||
|
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
|
||||||
|
)
|
||||||
|
warmup_steps = int(warmup_ratio * total_steps)
|
||||||
|
|
||||||
scheduler_fn = partial(
|
scheduler_fn = partial(
|
||||||
create_scheduler,
|
create_scheduler,
|
||||||
**{
|
**{
|
||||||
"schedule_type": "cosine",
|
"schedule_type": "cosine",
|
||||||
"warmup_steps": warmup_steps,
|
"warmup_steps": min(warmup_steps, total_steps),
|
||||||
"lr_decay_steps": total_steps - warmup_steps,
|
"lr_decay_steps": total_steps - min(warmup_steps, total_steps),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=model,
|
model_fn=model_fn,
|
||||||
strategy=train_type,
|
strategy=train_type,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=ckpt_dir,
|
ckpt_dir=ckpt_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_per_device=batch_per_device,
|
||||||
start_epoch=start_epoch,
|
start_epoch=start_epoch,
|
||||||
start_batch=start_batch,
|
start_batch=start_batch,
|
||||||
ckpt_interval=ckpt_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
accumulation_steps=accumulation_steps,
|
grad_accum_steps=grad_accum_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
parallel_wrapper=ddp_wrap,
|
backend=backend,
|
||||||
state_dict_fn=prepare_checkpoint,
|
master_addr=master_addr,
|
||||||
|
master_port=master_port,
|
||||||
|
parallel_mode=parallel_mode,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
|
start_method=start_method,
|
||||||
|
val_split=val_split,
|
||||||
|
val_step=val_step,
|
||||||
|
metrics=metrics,
|
||||||
|
log_dir=log_dir,
|
||||||
|
log_interval=log_interval,
|
||||||
|
gradient_checkpointing_modules=grad_ckpt_modules,
|
||||||
|
executor_kwargs=executor_kwargs,
|
||||||
extra_kwargs=strategy_kwargs,
|
extra_kwargs=strategy_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train()
|
trainer.train(resume_dir=param_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ import torch
|
||||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -104,19 +104,19 @@ def test_tokenizer():
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def test_model():
|
def test_model():
|
||||||
"""Session-scoped small Transformer model, created once."""
|
"""Session-scoped small AutoRegressiveLM model, created once."""
|
||||||
config = ModelConfig(
|
config = AutoRegressiveLMConfig(
|
||||||
vocab_size=1000,
|
vocab_size=1000,
|
||||||
dim=16,
|
dim=8,
|
||||||
n_heads=4,
|
n_heads=2,
|
||||||
n_kv_heads=2,
|
n_kv_heads=1,
|
||||||
dim_ffn=32,
|
dim_ffn=16,
|
||||||
max_len=1024,
|
max_len=64,
|
||||||
n_layers=4,
|
n_layers=2,
|
||||||
norm_eps=1e-5,
|
norm_eps=1e-5,
|
||||||
)
|
)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = Transformer(config).to(device=device)
|
model = AutoRegressiveLM(config).to(device=device)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"dim": 16,
|
"dim": 8,
|
||||||
"n_heads": 4,
|
"n_heads": 2,
|
||||||
"n_kv_heads": 2,
|
"n_kv_heads": 1,
|
||||||
"dim_ffn": 32,
|
"dim_ffn": 16,
|
||||||
"max_len": 1024,
|
"max_len": 64,
|
||||||
"n_layers": 4,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5,
|
"norm_eps": 1e-5,
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,202 @@
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
ProcessingConfig,
|
||||||
|
)
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
_SPECIAL_TOKENS_CONFIG = {
|
||||||
|
"bos_token": "<|begin_of_sentence|>",
|
||||||
|
"eos_token": "<|end_of_sentence|>",
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
"im_start": "<|im_start|>",
|
||||||
|
"im_end": "<|im_end|>",
|
||||||
|
}
|
||||||
|
|
||||||
|
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
|
||||||
|
|
||||||
|
_CHAT_TEMPLATE = (
|
||||||
|
"{% for message in messages %}"
|
||||||
|
"{% if message['role'] == 'system' %}"
|
||||||
|
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
||||||
|
"{% elif message['role'] == 'user' %}"
|
||||||
|
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
||||||
|
"{% elif message['role'] == 'assistant' %}"
|
||||||
|
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
||||||
|
"{% endif %}"
|
||||||
|
"{% endfor %}"
|
||||||
|
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
|
||||||
|
|
||||||
|
_INSTRUCTION_SECTIONS = [
|
||||||
|
{"field": "prompt", "action": "mask", "add_special_tokens": True},
|
||||||
|
{"field": "response", "action": "train"},
|
||||||
|
]
|
||||||
|
|
||||||
|
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
|
||||||
|
|
||||||
|
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_chat_tokenizer():
|
||||||
|
tok = Tokenizer(models.BPE())
|
||||||
|
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||||
|
tr = trainers.BpeTrainer(
|
||||||
|
vocab_size=512,
|
||||||
|
min_frequency=1,
|
||||||
|
special_tokens=_SPECIAL_TOKENS,
|
||||||
|
)
|
||||||
|
train_data = [
|
||||||
|
"hello world",
|
||||||
|
"Hi there!",
|
||||||
|
"You are helpful.",
|
||||||
|
"What is 2+2?",
|
||||||
|
"Tell me a story about dragons and knights.",
|
||||||
|
"Sure, here is a tale.",
|
||||||
|
"Translate to French: Hello",
|
||||||
|
"Bonjour",
|
||||||
|
"Artificial Intelligence is a field of computer science.",
|
||||||
|
"system",
|
||||||
|
"user",
|
||||||
|
"assistant",
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
*[chr(i) for i in range(32, 127)],
|
||||||
|
]
|
||||||
|
tok.train_from_iterator(train_data, tr)
|
||||||
|
|
||||||
|
auto_tok = AutoTokenizer()
|
||||||
|
auto_tok._tokenizer = tok
|
||||||
|
auto_tok._special_token_map = {
|
||||||
|
"bos_token": "<|begin_of_sentence|>",
|
||||||
|
"eos_token": "<|end_of_sentence|>",
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
||||||
|
return auto_tok
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def chat_tokenizer():
|
||||||
|
return _build_chat_tokenizer()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
d = tempfile.mkdtemp()
|
||||||
|
yield d
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
shutil.rmtree(d, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def make_chat_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_instruction_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_text_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(
|
||||||
|
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_dpo_chat_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(
|
||||||
|
sources={
|
||||||
|
"chosen": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "chosen", "action": "$role", "template": True}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"rejected": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "rejected", "action": "$role", "template": True}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
mask={"user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_grpo_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(
|
||||||
|
sources={
|
||||||
|
"prompts": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "prompt", "action": "mask", "template": True}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"sections": _GRPO_RESPONSE_SECTIONS,
|
||||||
|
"list_field": True,
|
||||||
|
"mask_key": "masks",
|
||||||
|
},
|
||||||
|
"rewards": {
|
||||||
|
"sections": [{"field": "rewards", "action": "value"}],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
mask={"user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_grpo_no_template_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(
|
||||||
|
sources={
|
||||||
|
"prompts": {
|
||||||
|
"sections": [
|
||||||
|
{
|
||||||
|
"field": "prompt",
|
||||||
|
"action": "mask",
|
||||||
|
"add_special_tokens": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"sections": _GRPO_RESPONSE_SECTIONS,
|
||||||
|
"list_field": True,
|
||||||
|
"mask_key": "masks",
|
||||||
|
},
|
||||||
|
"rewards": {
|
||||||
|
"sections": [{"field": "rewards", "action": "value"}],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
mask={"user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -35,6 +36,30 @@ def test_single_process():
|
||||||
assert loaded_checkpoint.iteration == 30
|
assert loaded_checkpoint.iteration == 30
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkpoint_with_extra():
|
||||||
|
model = torch.nn.Linear(10, 5)
|
||||||
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
extra = {
|
||||||
|
"optimizer": optimizer.state_dict(),
|
||||||
|
"scheduler": {"last_epoch": 5},
|
||||||
|
}
|
||||||
|
checkpoint = Checkpoint(
|
||||||
|
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
checkpoint.save(tmpdir)
|
||||||
|
|
||||||
|
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
||||||
|
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
||||||
|
|
||||||
|
loaded = Checkpoint.load(tmpdir)
|
||||||
|
assert loaded.extra["scheduler"]["last_epoch"] == 5
|
||||||
|
assert "state" in loaded.extra["optimizer"]
|
||||||
|
|
||||||
|
|
||||||
def simple_training():
|
def simple_training():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -7,12 +6,11 @@ import torch
|
||||||
|
|
||||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
BaseSegmentFetcher,
|
H5Store,
|
||||||
H5Storage,
|
StoreFactory,
|
||||||
MultiSegmentFetcher,
|
|
||||||
create_storage,
|
|
||||||
detect_format,
|
detect_format,
|
||||||
load_json,
|
load_bin,
|
||||||
|
save_bin,
|
||||||
save_h5,
|
save_h5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -100,6 +98,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
|
"position_ids": [torch.arange(seq_length, dtype=torch.int32)],
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir, "sft_data", dummy_data)
|
save_h5(test_dir, "sft_data", dummy_data)
|
||||||
|
|
@ -157,111 +156,6 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
assert len(dataset) > len(default_stride_dataset)
|
assert len(dataset) > len(default_stride_dataset)
|
||||||
|
|
||||||
|
|
||||||
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tokenizer_fn(tokenizer):
|
|
||||||
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
|
|
||||||
return lambda text: tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_seq_dataset_from_json_text(base_test_env):
|
|
||||||
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_text")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = [
|
|
||||||
"hello world this is a test sentence for tokenizer",
|
|
||||||
"another sentence with different words and tokens",
|
|
||||||
"machine learning is fascinating and powerful",
|
|
||||||
]
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "seq_data.json")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="seq",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=16,
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count > 0
|
|
||||||
assert "sequence" in dataset.keys
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert "input_ids" in item
|
|
||||||
assert "target_ids" in item
|
|
||||||
assert item["input_ids"].shape[0] == 16
|
|
||||||
|
|
||||||
|
|
||||||
def test_sft_dataset_from_json_text(base_test_env):
|
|
||||||
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_sft")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = [
|
|
||||||
"user asks a question about the weather",
|
|
||||||
"assistant provides a helpful response to the user",
|
|
||||||
]
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "sft_data.json")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(
|
|
||||||
{"sequence": texts, "loss_mask": texts},
|
|
||||||
f,
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="sft",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=16,
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert "loss_mask" in item
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_storage_explicit_tokenizer(base_test_env):
|
|
||||||
"""Test explicit JSON storage with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_explicit")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "data.json")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
token_count = len(tokenizer_fn(texts[0]))
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="seq",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=32,
|
|
||||||
storage_type="json",
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count == token_count
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_count_property(base_test_env):
|
def test_dataset_count_property(base_test_env):
|
||||||
"""Test the count property returns correct raw token count"""
|
"""Test the count property returns correct raw token count"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
@ -318,37 +212,29 @@ def test_unloaded_dataset_len():
|
||||||
assert len(dataset) == 0
|
assert len(dataset) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_base_segment_fetcher_empty():
|
def test_store_unloaded_len():
|
||||||
"""BaseSegmentFetcher with empty segments list"""
|
"""Unloaded Store has __len__ == 0"""
|
||||||
fetcher = BaseSegmentFetcher([])
|
store = H5Store()
|
||||||
assert len(fetcher) == 0
|
assert len(store) == 0
|
||||||
with pytest.raises(ValueError, match="out of bounds"):
|
assert store.keys == []
|
||||||
fetcher.fetch_data(0, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_base_segment_fetcher_begin_equals_end(base_test_env):
|
def test_store_fetch_begin_equals_end(base_test_env):
|
||||||
"""fetch_data with begin == end returns empty tensor"""
|
"""Store.fetch with begin == end returns empty tensor"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
||||||
save_h5(test_dir, "empty_fetch", dummy)
|
save_h5(test_dir, "empty_fetch", dummy)
|
||||||
|
|
||||||
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
||||||
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
|
result = dataset.storage.fetch(10, 10, "sequence")
|
||||||
result = fetcher.fetch_data(10, 10)
|
|
||||||
assert result.numel() == 0
|
assert result.numel() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_multi_segment_fetcher_empty_dict():
|
def test_store_fetch_before_load():
|
||||||
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
|
"""Store.fetch before load raises RuntimeError"""
|
||||||
fetcher = MultiSegmentFetcher({})
|
store = H5Store()
|
||||||
assert len(fetcher) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_storage_fetch_before_load():
|
|
||||||
"""BaseStorage.fetch before load raises RuntimeError"""
|
|
||||||
storage = H5Storage()
|
|
||||||
with pytest.raises(RuntimeError, match="not loaded"):
|
with pytest.raises(RuntimeError, match="not loaded"):
|
||||||
storage.fetch(0, 10, "sequence")
|
store.fetch(0, 10, "sequence")
|
||||||
|
|
||||||
|
|
||||||
def test_detect_format_nonexistent_path():
|
def test_detect_format_nonexistent_path():
|
||||||
|
|
@ -367,54 +253,192 @@ def test_detect_format_unsupported_file(base_test_env):
|
||||||
detect_format(path)
|
detect_format(path)
|
||||||
|
|
||||||
|
|
||||||
def test_create_storage_invalid_type():
|
def test_create_store_invalid_type():
|
||||||
"""create_storage raises ValueError for unknown type"""
|
"""StoreFactory.create raises ValueError for unknown type"""
|
||||||
with pytest.raises(ValueError, match="Unknown storage type"):
|
with pytest.raises(ValueError, match="Unknown component"):
|
||||||
create_storage("parquet")
|
StoreFactory.create("parquet")
|
||||||
|
|
||||||
|
|
||||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
def test_store_multi_segment_concat(base_test_env):
|
||||||
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
|
"""Multi-segment H5 data is concatenated into single tensor at load time"""
|
||||||
|
import os
|
||||||
|
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
data_dir = os.path.join(test_dir, "json_pretok")
|
data_dir = os.path.join(test_dir, "multi_seg")
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "data.json")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count == 10
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert item["input_ids"].tolist() == [1, 2, 3, 4]
|
|
||||||
assert item["target_ids"].tolist() == [2, 3, 4, 5]
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_json_skips_config_file(base_test_env):
|
|
||||||
"""load_json skips scalar-value config files"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
with open(os.path.join(test_dir, "config.json"), "w") as f:
|
|
||||||
json.dump({"vocab_size": 1000, "dim": 16}, f)
|
|
||||||
|
|
||||||
with open(os.path.join(test_dir, "data.json"), "w") as f:
|
|
||||||
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
|
|
||||||
|
|
||||||
result = load_json(test_dir)
|
|
||||||
assert "sequence" in result
|
|
||||||
assert "vocab_size" not in result
|
|
||||||
assert len(result["sequence"]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_base_segment_fetcher_multi_segment():
|
|
||||||
"""fetch_data across multiple segment boundaries"""
|
|
||||||
segs = [
|
segs = [
|
||||||
torch.tensor([1, 2, 3]),
|
torch.tensor([1, 2, 3]),
|
||||||
torch.tensor([4, 5, 6, 7]),
|
torch.tensor([4, 5, 6, 7]),
|
||||||
torch.tensor([8, 9]),
|
torch.tensor([8, 9]),
|
||||||
]
|
]
|
||||||
fetcher = BaseSegmentFetcher(segs)
|
save_h5(data_dir, "data", {"sequence": segs})
|
||||||
assert len(fetcher) == 9
|
|
||||||
result = fetcher.fetch_data(2, 7)
|
store = StoreFactory.create("h5")
|
||||||
|
store.load(data_dir)
|
||||||
|
assert len(store) == 9
|
||||||
|
result = store.fetch(2, 7, "sequence")
|
||||||
assert result.tolist() == [3, 4, 5, 6, 7]
|
assert result.tolist() == [3, 4, 5, 6, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_bin_roundtrip(base_test_env):
|
||||||
|
"""save_bin + load_bin roundtrip preserves data"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
|
||||||
|
"loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
result = load_bin(test_dir)
|
||||||
|
|
||||||
|
assert "sequence" in result
|
||||||
|
assert "loss_mask" in result
|
||||||
|
assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
|
||||||
|
assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmap_store_load_and_fetch(base_test_env):
|
||||||
|
"""MmapStore loads bin data and fetches correctly"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
|
||||||
|
store = StoreFactory.create("bin")
|
||||||
|
store.load(test_dir)
|
||||||
|
assert len(store) == 200
|
||||||
|
assert "sequence" in store.keys
|
||||||
|
|
||||||
|
result = store.fetch(10, 20, "sequence")
|
||||||
|
assert result.tolist() == data["sequence"][0][10:20].tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmap_dataset_load(base_test_env):
|
||||||
|
"""DatasetFactory.load auto-detects bin format"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count == 200
|
||||||
|
assert dataset[0]["input_ids"].shape[0] == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_empty_key():
|
||||||
|
"""_normalize with empty tensor list does not crash"""
|
||||||
|
store = H5Store()
|
||||||
|
store._normalize({"sequence": []})
|
||||||
|
assert len(store) == 0
|
||||||
|
assert store.keys == ["sequence"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_mixed_empty_key():
|
||||||
|
"""_normalize with empty + non-empty keys returns min=0"""
|
||||||
|
store = H5Store()
|
||||||
|
store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
|
||||||
|
assert len(store) == 0
|
||||||
|
assert set(store.keys) == {"sequence", "loss_mask"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_dataset_dtype(base_test_env):
|
||||||
|
"""GRPODataset returns correct dtypes"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
seq_len = 100
|
||||||
|
data = {
|
||||||
|
"prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
||||||
|
"responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
||||||
|
"masks": [torch.ones(seq_len, dtype=torch.int32)],
|
||||||
|
"rewards": [torch.ones(seq_len, dtype=torch.float32)],
|
||||||
|
}
|
||||||
|
save_h5(test_dir, "grpo_dtype", data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
|
||||||
|
item = dataset[0]
|
||||||
|
|
||||||
|
assert item["prompts"].dtype == torch.long
|
||||||
|
assert item["responses"].dtype == torch.long
|
||||||
|
assert item["masks"].dtype == torch.bool
|
||||||
|
assert item["rewards"].dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_dataset_load(base_test_env):
|
||||||
|
"""GRPODataset loads and returns correct keys"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
seq_len = 200
|
||||||
|
data = {
|
||||||
|
"prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
||||||
|
"responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
||||||
|
"masks": [torch.ones(seq_len, dtype=torch.int64)],
|
||||||
|
"rewards": [torch.rand(seq_len, dtype=torch.float32)],
|
||||||
|
}
|
||||||
|
save_h5(test_dir, "grpo_test", data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
item = dataset[0]
|
||||||
|
assert "prompts" in item
|
||||||
|
assert "responses" in item
|
||||||
|
assert "masks" in item
|
||||||
|
assert "rewards" in item
|
||||||
|
assert item["prompts"].shape[0] == 64
|
||||||
|
assert item["responses"].shape[0] == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_format_bin_dir(base_test_env):
|
||||||
|
"""detect_format returns 'bin' for directory with .bin + meta.json"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
|
||||||
|
assert detect_format(test_dir) == "bin"
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_fetch_multi_key(base_test_env):
|
||||||
|
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(
|
||||||
|
test_dir,
|
||||||
|
"multi_key",
|
||||||
|
{
|
||||||
|
"sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
|
||||||
|
"loss_mask": [torch.ones(100, dtype=torch.int64)],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
store = StoreFactory.create("h5")
|
||||||
|
store.load(test_dir)
|
||||||
|
result = store.fetch(10, 20, ["sequence", "loss_mask"])
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["sequence"].shape[0] == 10
|
||||||
|
assert result["loss_mask"].shape[0] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_fetch_out_of_bounds(base_test_env):
|
||||||
|
"""Store.fetch raises ValueError for out-of-bounds indices"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
|
||||||
|
|
||||||
|
store = StoreFactory.create("h5")
|
||||||
|
store.load(test_dir)
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(-1, 10, "sequence")
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(0, 51, "sequence")
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(50, 50, "sequence")
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_load_explicit_storage_type(base_test_env):
|
||||||
|
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count == 200
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,396 @@
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
OutputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
ProcessingConfig,
|
||||||
|
)
|
||||||
|
from astrai.preprocessing.builder import (
|
||||||
|
MaskBuilderFactory,
|
||||||
|
SectionedMaskBuilder,
|
||||||
|
)
|
||||||
|
from tests.data.conftest import (
|
||||||
|
_CHAT_SECTIONS,
|
||||||
|
_INSTRUCTION_SECTIONS,
|
||||||
|
_TEXT_SECTIONS,
|
||||||
|
make_chat_config,
|
||||||
|
make_dpo_chat_config,
|
||||||
|
make_grpo_config,
|
||||||
|
make_instruction_config,
|
||||||
|
make_text_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_simple(chat_tokenizer):
|
||||||
|
config = make_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "user", "content": "Hello."},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "sequence" in result
|
||||||
|
assert "loss_mask" in result
|
||||||
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
|
|
||||||
|
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
|
||||||
|
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
||||||
|
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
||||||
|
|
||||||
|
total = len(result["sequence"])
|
||||||
|
trained = sum(result["loss_mask"])
|
||||||
|
assert trained > 0
|
||||||
|
assert trained < total
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_mask_only_assistant(chat_tokenizer):
|
||||||
|
config = make_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
ids = result["sequence"]
|
||||||
|
assert len(ids) == len(mask)
|
||||||
|
|
||||||
|
trained = [i for i, m in enumerate(mask) if m == 1]
|
||||||
|
masked = [i for i, m in enumerate(mask) if m == 0]
|
||||||
|
assert len(trained) > 0
|
||||||
|
assert len(masked) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_all_masked(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert sum(result["loss_mask"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_all_trained(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={},
|
||||||
|
mask_default="train",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_empty_messages(chat_tokenizer):
|
||||||
|
config = make_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
||||||
|
assert builder.build({}, config, chat_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_domain_extraction(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(domain_key="source"),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Hello"},
|
||||||
|
],
|
||||||
|
"source": "wiki",
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result["domain"] == "wiki"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_truncation(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=10),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Tell me a very long story about dragons and knights and magic.",
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": "Sure! Here is a tale..."},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert len(result["sequence"]) <= 10
|
||||||
|
assert len(result["loss_mask"]) == len(result["sequence"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_instruction_basic(test_tokenizer):
|
||||||
|
config = make_instruction_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_instruction_prompt_masked(test_tokenizer):
|
||||||
|
config = make_instruction_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "hello", "response": "world"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
ids = result["sequence"]
|
||||||
|
|
||||||
|
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||||
|
p_len = min(len(prompt_ids), len(ids))
|
||||||
|
assert all(m == 0 for m in mask[:p_len])
|
||||||
|
if p_len < len(ids):
|
||||||
|
assert all(m == 1 for m in mask[p_len:])
|
||||||
|
|
||||||
|
|
||||||
|
def test_instruction_train_on_prompt(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(
|
||||||
|
sections=[
|
||||||
|
{"field": "prompt", "action": "train", "add_special_tokens": True},
|
||||||
|
{"field": "response", "action": "mask"},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "hello", "response": "world"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
ids = result["sequence"]
|
||||||
|
|
||||||
|
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||||
|
p_len = min(len(prompt_ids), len(ids))
|
||||||
|
assert all(m == 1 for m in mask[:p_len])
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_basic(test_tokenizer):
|
||||||
|
config = make_text_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"text": "Hello world. This is a test document."}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "sequence" in result
|
||||||
|
assert len(result["sequence"]) > 0
|
||||||
|
assert "loss_mask" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_empty(test_tokenizer):
|
||||||
|
config = make_text_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
||||||
|
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_too_short(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(min_chars=100),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_truncation(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"text": "This is a very long text that should be truncated"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert len(result["sequence"]) <= 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_sectioned_chat(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
|
assert sum(result["loss_mask"]) > 0
|
||||||
|
assert 0 in result["loss_mask"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sectioned_instruction(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "Q: Why?", "response": "A: Because."}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
assert mask[0] == 0
|
||||||
|
assert mask[-1] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_sectioned_text(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"text": "Hello world, this is a test."}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "loss_mask" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_sectioned_text_too_short(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_registered():
|
||||||
|
names = MaskBuilderFactory._registry.list_names()
|
||||||
|
assert "sectioned" in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_create():
|
||||||
|
builder = MaskBuilderFactory.create("sectioned")
|
||||||
|
assert isinstance(builder, SectionedMaskBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_chat_basic(chat_tokenizer):
|
||||||
|
config = make_dpo_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"chosen": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
],
|
||||||
|
"rejected": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "5"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "chosen" in result
|
||||||
|
assert "rejected" in result
|
||||||
|
assert "chosen_mask" in result
|
||||||
|
assert "rejected_mask" in result
|
||||||
|
assert "domain" in result
|
||||||
|
assert len(result["chosen"]) == len(result["chosen_mask"])
|
||||||
|
assert len(result["rejected"]) == len(result["rejected_mask"])
|
||||||
|
assert sum(result["chosen_mask"]) > 0
|
||||||
|
assert sum(result["rejected_mask"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_chosen_only_trained(chat_tokenizer):
|
||||||
|
config = make_dpo_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"chosen": [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Hello"},
|
||||||
|
],
|
||||||
|
"rejected": [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Go away"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert 0 in result["chosen_mask"]
|
||||||
|
assert 1 in result["chosen_mask"]
|
||||||
|
assert 0 in result["rejected_mask"]
|
||||||
|
assert 1 in result["rejected_mask"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_missing_field_is_none(chat_tokenizer):
|
||||||
|
config = make_dpo_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_basic(chat_tokenizer):
|
||||||
|
config = make_grpo_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"prompt": [{"role": "user", "content": "What is 2+2?"}],
|
||||||
|
"responses": ["4", "The answer is four", "Four", "2+2=4"],
|
||||||
|
"rewards": [1.0, 0.5, 0.8, 0.2],
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "prompts" in result
|
||||||
|
assert "responses" in result
|
||||||
|
assert "masks" in result
|
||||||
|
assert "rewards" in result
|
||||||
|
assert len(result["responses"]) == len(result["masks"])
|
||||||
|
assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_response_tokens_all_trained(chat_tokenizer):
|
||||||
|
config = make_grpo_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"prompt": [{"role": "user", "content": "Q"}],
|
||||||
|
"responses": ["A", "B"],
|
||||||
|
"rewards": [0.8, 0.2],
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
masks = result["masks"]
|
||||||
|
assert all(m == 1 for m in masks)
|
||||||
|
assert len(masks) == len(result["responses"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_single_reward(chat_tokenizer):
|
||||||
|
config = make_grpo_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"prompt": [{"role": "user", "content": "Q"}],
|
||||||
|
"responses": ["A"],
|
||||||
|
"rewards": 0.9,
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result["rewards"] == [0.9]
|
||||||
|
|
@ -0,0 +1,77 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
)
|
||||||
|
from tests.data.conftest import (
|
||||||
|
_INSTRUCTION_SECTIONS,
|
||||||
|
_TEXT_SECTIONS,
|
||||||
|
make_dpo_chat_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_values():
|
||||||
|
config = PipelineConfig()
|
||||||
|
assert config.version == 1
|
||||||
|
assert config.mask == {}
|
||||||
|
assert config.mask_default == "mask"
|
||||||
|
assert config.preprocessing.max_seq_len == 2048
|
||||||
|
assert config.output.storage_format == "bin"
|
||||||
|
assert config.input.sections is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_dict_flat():
|
||||||
|
data = {
|
||||||
|
"version": 1,
|
||||||
|
"input": {
|
||||||
|
"sections": [{"field": "messages", "action": "$role", "template": True}]
|
||||||
|
},
|
||||||
|
"mask": {"system": "mask", "assistant": "train"},
|
||||||
|
"mask_default": "mask",
|
||||||
|
"preprocessing": {"max_seq_len": 1024},
|
||||||
|
"output": {"storage_format": "h5"},
|
||||||
|
}
|
||||||
|
config = PipelineConfig.from_dict(data)
|
||||||
|
assert config.input.sections == [
|
||||||
|
{"field": "messages", "action": "$role", "template": True}
|
||||||
|
]
|
||||||
|
assert config.mask == {"system": "mask", "assistant": "train"}
|
||||||
|
assert config.preprocessing.max_seq_len == 1024
|
||||||
|
assert config.output.storage_format == "h5"
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_dict_roundtrip():
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
)
|
||||||
|
d = config.to_dict()
|
||||||
|
config2 = PipelineConfig.from_dict(d)
|
||||||
|
assert config2.input.sections == _INSTRUCTION_SECTIONS
|
||||||
|
assert config2.mask == {"prompt": "mask", "response": "train"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_json_from_json(temp_dir):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
mask={"text": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
)
|
||||||
|
path = os.path.join(temp_dir, "config.json")
|
||||||
|
config.to_json(path)
|
||||||
|
loaded = PipelineConfig.from_json(path)
|
||||||
|
assert loaded.input.sections == _TEXT_SECTIONS
|
||||||
|
assert loaded.mask == {"text": "train"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_config_roundtrip(temp_dir):
|
||||||
|
config = make_dpo_chat_config()
|
||||||
|
path = os.path.join(temp_dir, "config.json")
|
||||||
|
config.to_json(path)
|
||||||
|
loaded = PipelineConfig.from_json(path)
|
||||||
|
assert loaded.input.sources is not None
|
||||||
|
assert "chosen" in loaded.input.sources
|
||||||
|
assert "rejected" in loaded.input.sources
|
||||||
|
assert loaded.input.sections is None
|
||||||
|
|
@ -0,0 +1,349 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
OutputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
ProcessingConfig,
|
||||||
|
)
|
||||||
|
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||||
|
from tests.data.conftest import (
|
||||||
|
_CHAT_SECTIONS,
|
||||||
|
_CHAT_TEMPLATE,
|
||||||
|
_INSTRUCTION_SECTIONS,
|
||||||
|
_SPECIAL_TOKENS_CONFIG,
|
||||||
|
_TEXT_SECTIONS,
|
||||||
|
make_dpo_chat_config,
|
||||||
|
make_grpo_no_template_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_by_length():
|
||||||
|
assert filter_by_length("hello world", min_len=5)
|
||||||
|
assert not filter_by_length("hi", min_len=5)
|
||||||
|
assert not filter_by_length("x" * 100, max_len=50)
|
||||||
|
assert filter_by_length("just right", min_len=5, max_len=20)
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_chat_pipeline(temp_dir, chat_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
||||||
|
"chat_template": _CHAT_TEMPLATE,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "user", "content": "Hi."},
|
||||||
|
{"role": "assistant", "content": "Hello!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(storage_format="bin", domain_key=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "sequence" in meta
|
||||||
|
assert "loss_mask" in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "int32"
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_text_pipeline(temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"text": "Another document for testing purposes with sufficient length to be processed."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
|
||||||
|
output=OutputConfig(storage_format="bin"),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "sequence" in meta
|
||||||
|
assert "loss_mask" not in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_instruction_pipeline(temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"prompt": "Tell me a joke",
|
||||||
|
"response": "Why did the chicken cross the road?",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"prompt": "What is AI?",
|
||||||
|
"response": "Artificial Intelligence is a field of computer science.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(storage_format="bin"),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "sequence" in meta
|
||||||
|
assert "loss_mask" in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "int32"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dtype_override(temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "data.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n")
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "bool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_pipeline(temp_dir, chat_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
||||||
|
"chat_template": _CHAT_TEMPLATE,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "dpo.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"chosen": [
|
||||||
|
{"role": "user", "content": "Hi."},
|
||||||
|
{"role": "assistant", "content": "Hello!"},
|
||||||
|
],
|
||||||
|
"rejected": [
|
||||||
|
{"role": "user", "content": "Hi."},
|
||||||
|
{"role": "assistant", "content": "Go away."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=make_dpo_chat_config(),
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "chosen" in meta
|
||||||
|
assert "rejected" in meta
|
||||||
|
assert "chosen_mask" in meta
|
||||||
|
assert "rejected_mask" in meta
|
||||||
|
assert "sequence" not in meta
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_pipeline(temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "grpo.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"prompt": "Question?",
|
||||||
|
"responses": ["Answer A", "Answer B"],
|
||||||
|
"rewards": [0.8, 0.3],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=make_grpo_no_template_config(),
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "prompts" in meta
|
||||||
|
assert "responses" in meta
|
||||||
|
assert "masks" in meta
|
||||||
|
assert "rewards" in meta
|
||||||
|
assert "sequence" not in meta
|
||||||
|
|
@ -5,21 +5,22 @@ from unittest.mock import MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from astrai.inference import app
|
from astrai.inference import get_app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
"""Provide a test client for the FastAPI app."""
|
"""Provide a test client for the FastAPI app."""
|
||||||
app.state.server_config = {
|
_app = get_app()
|
||||||
|
_app.state.server_config = {
|
||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
"dtype": "bfloat16",
|
"dtype": "bfloat16",
|
||||||
"param_path": None,
|
"param_path": None,
|
||||||
"max_batch_size": 1,
|
"max_batch_size": 1,
|
||||||
"_test": True,
|
"_test": True,
|
||||||
}
|
}
|
||||||
app.state.engine = None
|
_app.state.engine = None
|
||||||
return TestClient(app)
|
return TestClient(_app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -49,5 +50,5 @@ def mock_engine():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def loaded_model(client, mock_engine):
|
def loaded_model(client, mock_engine):
|
||||||
"""Simulate that the engine is loaded."""
|
"""Simulate that the engine is loaded."""
|
||||||
app.state.engine = mock_engine
|
get_app().state.engine = mock_engine
|
||||||
return mock_engine
|
return mock_engine
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,286 @@
|
||||||
|
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||||
|
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
|
||||||
|
from astrai.inference.engine import GenerationRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ctx(**kwargs):
|
||||||
|
defaults = {
|
||||||
|
"resp_id": "test-123",
|
||||||
|
"created": 1000,
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt_tokens": 10,
|
||||||
|
"completion_tokens": 5,
|
||||||
|
}
|
||||||
|
defaults.update(kwargs)
|
||||||
|
return GenContext(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _sse_payloads(events):
|
||||||
|
payloads = []
|
||||||
|
for chunk in events:
|
||||||
|
for line in chunk.strip().split("\n"):
|
||||||
|
if line.startswith("data: "):
|
||||||
|
try:
|
||||||
|
payloads.append(json.loads(line[6:]))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return payloads
|
||||||
|
|
||||||
|
|
||||||
|
class TestStopChecker:
|
||||||
|
def test_check_finds_match(self):
|
||||||
|
sc = StopChecker(["stop", "end"])
|
||||||
|
assert sc.check("hello stop world") == "stop"
|
||||||
|
|
||||||
|
def test_check_returns_none_when_no_match(self):
|
||||||
|
sc = StopChecker(["stop"])
|
||||||
|
assert sc.check("hello world") is None
|
||||||
|
|
||||||
|
def test_check_empty_sequences(self):
|
||||||
|
sc = StopChecker([])
|
||||||
|
assert sc.check("hello") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenContext:
|
||||||
|
def test_defaults(self):
|
||||||
|
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
||||||
|
assert ctx.completion_tokens == 0
|
||||||
|
|
||||||
|
def test_fields_mutable(self):
|
||||||
|
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
||||||
|
ctx.completion_tokens = 42
|
||||||
|
assert ctx.completion_tokens == 42
|
||||||
|
|
||||||
|
|
||||||
|
class TestStopInfo:
|
||||||
|
def test_defaults(self):
|
||||||
|
s = StopInfo()
|
||||||
|
assert s.matched is None
|
||||||
|
assert s.body == ""
|
||||||
|
assert s.yielded == ""
|
||||||
|
|
||||||
|
def test_with_values(self):
|
||||||
|
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
|
||||||
|
assert s.matched == "stop"
|
||||||
|
assert s.body == "hello stop"
|
||||||
|
assert s.yielded == "hello "
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIResponseBuilder:
|
||||||
|
@pytest.fixture
|
||||||
|
def builder(self):
|
||||||
|
builder = OpenAIResponseBuilder()
|
||||||
|
req = MagicMock()
|
||||||
|
req.messages = [MagicMock(role="user", content="Hello")]
|
||||||
|
req.stop = None
|
||||||
|
req.model = "astrai"
|
||||||
|
engine = MagicMock()
|
||||||
|
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
||||||
|
builder.prepare(req, engine)
|
||||||
|
return builder
|
||||||
|
|
||||||
|
def test_prepare_returns_prompt_ctx_stops(self, builder):
|
||||||
|
req = MagicMock()
|
||||||
|
req.messages = [MagicMock(role="user", content="Hi")]
|
||||||
|
req.stop = ["END"]
|
||||||
|
req.model = "gpt"
|
||||||
|
engine = MagicMock()
|
||||||
|
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
||||||
|
prompt, ctx, stops = builder.prepare(req, engine)
|
||||||
|
assert prompt == "Hi"
|
||||||
|
assert ctx.model == "gpt"
|
||||||
|
assert ctx.prompt_tokens == 0
|
||||||
|
assert stops == ["END"]
|
||||||
|
|
||||||
|
def test_prepare_no_stop_returns_empty_list(self, builder):
|
||||||
|
req = MagicMock()
|
||||||
|
req.messages = []
|
||||||
|
req.stop = None
|
||||||
|
req.model = "x"
|
||||||
|
engine = MagicMock()
|
||||||
|
engine.tokenizer.apply_chat_template.return_value = ""
|
||||||
|
_, _, stops = builder.prepare(req, engine)
|
||||||
|
assert stops == []
|
||||||
|
|
||||||
|
def test_format_stream_start(self, builder):
|
||||||
|
ctx = _make_ctx()
|
||||||
|
events = builder.format_stream_start(ctx)
|
||||||
|
payloads = _sse_payloads(events)
|
||||||
|
assert len(payloads) == 1
|
||||||
|
p = payloads[0]
|
||||||
|
assert p["object"] == "chat.completion.chunk"
|
||||||
|
assert p["choices"][0]["delta"]["role"] == "assistant"
|
||||||
|
assert p["choices"][0]["finish_reason"] is None
|
||||||
|
|
||||||
|
def test_format_chunk(self, builder):
|
||||||
|
event = builder.format_chunk("hello")
|
||||||
|
payload = json.loads(event.split("data: ", 1)[1])
|
||||||
|
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||||
|
assert payload["choices"][0]["finish_reason"] is None
|
||||||
|
|
||||||
|
def test_format_stream_end(self, builder):
|
||||||
|
ctx = _make_ctx(completion_tokens=5)
|
||||||
|
stop = StopInfo(matched="stop")
|
||||||
|
events = builder.format_stream_end(ctx, stop)
|
||||||
|
payloads = _sse_payloads(events)
|
||||||
|
finish = payloads[0]
|
||||||
|
assert finish["choices"][0]["finish_reason"] == "stop"
|
||||||
|
usage = payloads[1]
|
||||||
|
assert usage["completion_tokens"] == 5
|
||||||
|
assert usage["total_tokens"] == 15
|
||||||
|
|
||||||
|
def test_format_response(self, builder):
|
||||||
|
ctx = _make_ctx()
|
||||||
|
stop = StopInfo()
|
||||||
|
resp = builder.format_response(ctx, "hello", stop)
|
||||||
|
assert resp["object"] == "chat.completion"
|
||||||
|
assert resp["choices"][0]["message"]["content"] == "hello"
|
||||||
|
assert resp["usage"]["prompt_tokens"] == 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnthropicResponseBuilder:
|
||||||
|
@pytest.fixture
|
||||||
|
def builder(self):
|
||||||
|
builder = AnthropicResponseBuilder()
|
||||||
|
req = MagicMock()
|
||||||
|
req.messages = [MagicMock(role="user", content="Hello")]
|
||||||
|
req.model = "claude"
|
||||||
|
engine = MagicMock()
|
||||||
|
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
||||||
|
req.system = None
|
||||||
|
builder.prepare(req, engine)
|
||||||
|
return builder
|
||||||
|
|
||||||
|
def test_prepare_messages(self, builder):
|
||||||
|
req = MagicMock()
|
||||||
|
req.messages = [MagicMock(role="user", content="Hi")]
|
||||||
|
req.model = "claude"
|
||||||
|
req.system = None
|
||||||
|
req.stop_sequences = None
|
||||||
|
engine = MagicMock()
|
||||||
|
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
||||||
|
prompt, ctx, stops = builder.prepare(req, engine)
|
||||||
|
assert prompt == "Hi"
|
||||||
|
assert stops == []
|
||||||
|
|
||||||
|
def test_prepare_with_stop_sequences(self, builder):
|
||||||
|
req = MagicMock()
|
||||||
|
req.messages = []
|
||||||
|
req.model = "x"
|
||||||
|
req.stop_sequences = ["stop", "end"]
|
||||||
|
req.system = None
|
||||||
|
engine = MagicMock()
|
||||||
|
engine.tokenizer.apply_chat_template.return_value = ""
|
||||||
|
_, _, stops = builder.prepare(req, engine)
|
||||||
|
assert stops == ["stop", "end"]
|
||||||
|
|
||||||
|
def test_format_stream_start(self, builder):
|
||||||
|
ctx = _make_ctx(prompt_tokens=3)
|
||||||
|
events = builder.format_stream_start(ctx)
|
||||||
|
payloads = _sse_payloads(events)
|
||||||
|
assert len(payloads) == 2
|
||||||
|
assert payloads[0]["type"] == "message_start"
|
||||||
|
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
|
||||||
|
assert payloads[1]["type"] == "content_block_start"
|
||||||
|
|
||||||
|
def test_format_chunk(self, builder):
|
||||||
|
event = builder.format_chunk("tok")
|
||||||
|
payload = json.loads(event.split("data: ", 1)[1])
|
||||||
|
assert payload["type"] == "content_block_delta"
|
||||||
|
assert payload["delta"]["text"] == "tok"
|
||||||
|
|
||||||
|
def test_format_stream_end_no_stop(self, builder):
|
||||||
|
ctx = _make_ctx(completion_tokens=3)
|
||||||
|
stop = StopInfo()
|
||||||
|
events = builder.format_stream_end(ctx, stop)
|
||||||
|
payloads = _sse_payloads(events)
|
||||||
|
# content_block_stop, message_delta, message_stop
|
||||||
|
types = [p["type"] for p in payloads]
|
||||||
|
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
||||||
|
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
|
||||||
|
|
||||||
|
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
|
||||||
|
ctx = _make_ctx(completion_tokens=7)
|
||||||
|
stop = StopInfo(
|
||||||
|
matched="END",
|
||||||
|
body="Hello world END extra",
|
||||||
|
yielded="Hello ",
|
||||||
|
)
|
||||||
|
events = builder.format_stream_end(ctx, stop)
|
||||||
|
payloads = _sse_payloads(events)
|
||||||
|
# unyielded delta, content_block_stop, message_delta, message_stop
|
||||||
|
types = [p["type"] for p in payloads]
|
||||||
|
assert types == [
|
||||||
|
"content_block_delta",
|
||||||
|
"content_block_stop",
|
||||||
|
"message_delta",
|
||||||
|
"message_stop",
|
||||||
|
]
|
||||||
|
assert payloads[0]["delta"]["text"] == "world "
|
||||||
|
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
|
||||||
|
assert payloads[2]["delta"]["stop_sequence"] == "END"
|
||||||
|
|
||||||
|
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
|
||||||
|
ctx = _make_ctx()
|
||||||
|
stop = StopInfo(
|
||||||
|
matched="END",
|
||||||
|
body="Hello END",
|
||||||
|
yielded="Hello ",
|
||||||
|
)
|
||||||
|
events = builder.format_stream_end(ctx, stop)
|
||||||
|
payloads = _sse_payloads(events)
|
||||||
|
# No unyielded delta (everything already sent)
|
||||||
|
types = [p["type"] for p in payloads]
|
||||||
|
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
||||||
|
|
||||||
|
def test_format_response_with_stop_trims_content(self, builder):
|
||||||
|
ctx = _make_ctx()
|
||||||
|
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
|
||||||
|
resp = builder.format_response(ctx, "text STOP extra", stop)
|
||||||
|
assert resp["content"][0]["text"] == "text "
|
||||||
|
assert resp["stop_reason"] == "stop_sequence"
|
||||||
|
assert resp["stop_sequence"] == "STOP"
|
||||||
|
|
||||||
|
def test_format_response_no_stop(self, builder):
|
||||||
|
ctx = _make_ctx()
|
||||||
|
stop = StopInfo()
|
||||||
|
resp = builder.format_response(ctx, "full text", stop)
|
||||||
|
assert resp["content"][0]["text"] == "full text"
|
||||||
|
assert resp["stop_reason"] == "end_turn"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerationRequestValidation:
|
||||||
|
def test_valid_params(self):
|
||||||
|
gr = GenerationRequest(
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
top_k=50,
|
||||||
|
top_p=0.9,
|
||||||
|
temperature=0.7,
|
||||||
|
)
|
||||||
|
assert gr.top_k == 50
|
||||||
|
|
||||||
|
def test_invalid_top_p_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="top_p"):
|
||||||
|
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
|
||||||
|
|
||||||
|
def test_invalid_top_k_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="top_k"):
|
||||||
|
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
|
||||||
|
|
||||||
|
def test_invalid_temperature_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="temperature"):
|
||||||
|
GenerationRequest(
|
||||||
|
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_top_k_zero_valid(self):
|
||||||
|
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
|
||||||
|
assert gr.top_k == 0
|
||||||
|
|
@ -173,3 +173,21 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||||
for stats in results["stats"]:
|
for stats in results["stats"]:
|
||||||
assert "total_tasks" in stats
|
assert "total_tasks" in stats
|
||||||
assert stats["total_tasks"] >= 0
|
assert stats["total_tasks"] >= 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
|
||||||
|
"""Tasks whose entire prompt is cached skip the prefill phase."""
|
||||||
|
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||||
|
|
||||||
|
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||||
|
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||||
|
scheduler = InferenceScheduler(
|
||||||
|
model=mock_model,
|
||||||
|
tokenizer=mock_tokenizer,
|
||||||
|
max_batch_size=4,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
|
||||||
|
scheduler.stop()
|
||||||
|
assert task_id.startswith("task_")
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from astrai.inference import app
|
from astrai.inference import get_app
|
||||||
|
|
||||||
|
|
||||||
def test_health_no_model(client):
|
def test_health_no_model(client):
|
||||||
"""GET /health should return 200 even when engine not loaded."""
|
"""GET /health should return 200 even when engine not loaded."""
|
||||||
app.state.engine = None
|
get_app().state.engine = None
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Assistant reply"
|
yield "Assistant reply"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model):
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Assistant reply"
|
yield "Assistant reply"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model):
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Reply"
|
yield "Reply"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -157,5 +157,60 @@ def test_messages_with_system(client, loaded_model):
|
||||||
assert data["type"] == "message"
|
assert data["type"] == "message"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completions_stop_sequence(client, loaded_model):
|
||||||
|
"""POST /v1/chat/completions with stop parameter truncates at stop sequence."""
|
||||||
|
|
||||||
|
async def async_gen():
|
||||||
|
yield "Hello"
|
||||||
|
yield "X"
|
||||||
|
yield "world"
|
||||||
|
|
||||||
|
get_app().state.engine = loaded_model
|
||||||
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 100,
|
||||||
|
"stream": False,
|
||||||
|
"stop": ["X"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
content = data["choices"][0]["message"]["content"]
|
||||||
|
assert "X" in content
|
||||||
|
assert "world" not in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completions_stop_sequence_stream(client, loaded_model):
|
||||||
|
"""POST /v1/chat/completions with stop parameter truncates SSE stream."""
|
||||||
|
|
||||||
|
async def async_gen():
|
||||||
|
yield "Hello"
|
||||||
|
yield "X"
|
||||||
|
yield "world"
|
||||||
|
|
||||||
|
get_app().state.engine = loaded_model
|
||||||
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 100,
|
||||||
|
"stream": True,
|
||||||
|
"stop": ["X"],
|
||||||
|
},
|
||||||
|
headers={"Accept": "text/event-stream"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
content = response.content.decode("utf-8")
|
||||||
|
assert "Hello" in content
|
||||||
|
assert "world" not in content
|
||||||
|
assert any(
|
||||||
|
"finish_reason" in line for line in content.split("\n") if "stop" in line
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,166 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.config.model_config import EncoderConfig
|
||||||
|
from astrai.model.encoder import EmbeddingEncoder
|
||||||
|
|
||||||
|
TINY_CONFIG = dict(
|
||||||
|
vocab_size=128,
|
||||||
|
dim=8,
|
||||||
|
n_heads=2,
|
||||||
|
n_kv_heads=1,
|
||||||
|
dim_ffn=16,
|
||||||
|
max_len=64,
|
||||||
|
n_layers=2,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_forward_mean():
|
||||||
|
config = EncoderConfig(**TINY_CONFIG)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
assert output.shape == (batch_size, config.dim)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_forward_cls():
|
||||||
|
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "cls"})
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
assert output.shape == (batch_size, config.dim)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_forward_last():
|
||||||
|
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "last"})
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
assert output.shape == (batch_size, config.dim)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_forward_with_padding():
|
||||||
|
config = EncoderConfig(**TINY_CONFIG)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||||
|
input_mask[:, 4:] = False
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids, input_mask=input_mask)
|
||||||
|
|
||||||
|
assert output.shape == (batch_size, config.dim)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_normalize():
|
||||||
|
config = EncoderConfig(
|
||||||
|
**{**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True}
|
||||||
|
)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
norms = output.norm(p=2, dim=-1)
|
||||||
|
assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_register():
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
|
||||||
|
assert AutoModel.is_registered("embedding")
|
||||||
|
cls = AutoModel.get_component_class("embedding")
|
||||||
|
assert cls is EmbeddingEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_from_transformer_checkpoint():
|
||||||
|
config = EncoderConfig(**TINY_CONFIG)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
state_dict["lm_head.weight"] = torch.randn(
|
||||||
|
config.vocab_size, config.dim, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
new_model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
new_model.load_state_dict(state_dict, strict=True)
|
||||||
|
|
||||||
|
for key in model.state_dict():
|
||||||
|
assert torch.equal(new_model.state_dict()[key], model.state_dict()[key])
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_save_load():
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
|
|
||||||
|
test_dir = tempfile.mkdtemp(prefix="encoder_test_")
|
||||||
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
weights_path = os.path.join(test_dir, "model.safetensors")
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_data = {**TINY_CONFIG, "pooling_type": "mean"}
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_data, f)
|
||||||
|
|
||||||
|
config = EncoderConfig.from_file(config_path)
|
||||||
|
original = EmbeddingEncoder(config)
|
||||||
|
st.save_file(original.state_dict(), weights_path)
|
||||||
|
|
||||||
|
loaded = EmbeddingEncoder(config)
|
||||||
|
loaded.load_state_dict(st.load_file(weights_path))
|
||||||
|
|
||||||
|
for key in original.state_dict():
|
||||||
|
assert torch.equal(original.state_dict()[key], loaded.state_dict()[key])
|
||||||
|
finally:
|
||||||
|
if os.path.exists(test_dir):
|
||||||
|
for f in os.listdir(test_dir):
|
||||||
|
os.remove(os.path.join(test_dir, f))
|
||||||
|
os.rmdir(test_dir)
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
|
TINY_CONFIG = dict(
|
||||||
|
vocab_size=128,
|
||||||
|
dim=8,
|
||||||
|
n_heads=2,
|
||||||
|
n_kv_heads=1,
|
||||||
|
dim_ffn=16,
|
||||||
|
max_len=64,
|
||||||
|
n_layers=2,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
CONFIGS = [
|
||||||
|
pytest.param(
|
||||||
|
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
|
||||||
|
id="gqa_mlp",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
**TINY_CONFIG,
|
||||||
|
"attn_type": "mla",
|
||||||
|
"ffn_type": "mlp",
|
||||||
|
"kv_lora_rank": 4,
|
||||||
|
"qk_nope_head_dim": 2,
|
||||||
|
"qk_rope_head_dim": 2,
|
||||||
|
},
|
||||||
|
id="mla_mlp",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
**TINY_CONFIG,
|
||||||
|
"attn_type": "gqa",
|
||||||
|
"ffn_type": "moe",
|
||||||
|
"n_routed_experts": 4,
|
||||||
|
"n_shared_experts": 1,
|
||||||
|
"n_activated_experts": 2,
|
||||||
|
"topk_method": "greedy",
|
||||||
|
},
|
||||||
|
id="gqa_moe",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
**TINY_CONFIG,
|
||||||
|
"attn_type": "gqa",
|
||||||
|
"ffn_type": "mlp",
|
||||||
|
"rope_theta": 100000.0,
|
||||||
|
},
|
||||||
|
id="gqa_rope_theta",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
|
||||||
|
id="gqa_qk_norm",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
|
||||||
|
id="gqa_tie_weight",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||||
|
def test_model_forward(config_kwargs):
|
||||||
|
config = AutoRegressiveLMConfig(**config_kwargs)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = AutoRegressiveLM(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
assert "logits" in output
|
||||||
|
assert "hidden_states" in output
|
||||||
|
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
|
||||||
|
assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
|
||||||
|
assert not torch.isnan(output["logits"]).any()
|
||||||
|
assert not torch.isnan(output["hidden_states"]).any()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||||
|
def test_model_forward_with_padding(config_kwargs):
|
||||||
|
config = AutoRegressiveLMConfig(**config_kwargs)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = AutoRegressiveLM(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||||
|
input_mask[:, 4:] = False
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids, input_mask=input_mask)
|
||||||
|
|
||||||
|
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
|
||||||
|
assert not torch.isnan(output["logits"]).any()
|
||||||
|
|
@ -0,0 +1,355 @@
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
|
from astrai.model import AutoRegressiveLM
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.lora import (
|
||||||
|
LoRAConfig,
|
||||||
|
LoRALinear,
|
||||||
|
_collect_lora_info,
|
||||||
|
_get_lora_count,
|
||||||
|
inject_lora,
|
||||||
|
load_lora,
|
||||||
|
merge_lora,
|
||||||
|
save_lora,
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_KWARGS = dict(
|
||||||
|
vocab_size=1000,
|
||||||
|
dim=64,
|
||||||
|
n_heads=4,
|
||||||
|
n_kv_heads=2,
|
||||||
|
dim_ffn=128,
|
||||||
|
n_layers=2,
|
||||||
|
max_len=32,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_model(**kwargs):
|
||||||
|
kw = {**MODEL_KWARGS, **kwargs}
|
||||||
|
config = AutoRegressiveLMConfig(**kw)
|
||||||
|
model = AutoRegressiveLM(config)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_init():
|
||||||
|
base = Linear(64, 128)
|
||||||
|
lora = LoRALinear(base, r=8, alpha=16)
|
||||||
|
|
||||||
|
assert lora.weight is base.weight
|
||||||
|
assert not lora.weight.requires_grad
|
||||||
|
assert lora.lora_A.shape == (8, 64)
|
||||||
|
assert lora.lora_B.shape == (128, 8)
|
||||||
|
assert lora.scaling == 2.0
|
||||||
|
assert not lora._merged
|
||||||
|
assert lora.lora_A.requires_grad
|
||||||
|
assert lora.lora_B.requires_grad
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_forward_init_zero_delta():
|
||||||
|
base = Linear(4, 4)
|
||||||
|
with torch.no_grad():
|
||||||
|
base.weight.zero_()
|
||||||
|
|
||||||
|
x = torch.randn(2, 4)
|
||||||
|
lora = LoRALinear(base, r=2, alpha=2)
|
||||||
|
base_out = base(x)
|
||||||
|
lora_out = lora(x)
|
||||||
|
|
||||||
|
assert torch.allclose(base_out, lora_out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_forward_with_delta():
|
||||||
|
base = Linear(4, 4)
|
||||||
|
with torch.no_grad():
|
||||||
|
base.weight.zero_()
|
||||||
|
|
||||||
|
x = torch.randn(2, 4)
|
||||||
|
lora = LoRALinear(base, r=2, alpha=2)
|
||||||
|
base_out = base(x)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
lora.lora_B.fill_(1.0)
|
||||||
|
|
||||||
|
lora_out = lora(x)
|
||||||
|
assert not torch.allclose(base_out, lora_out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_merge():
|
||||||
|
base = Linear(4, 4)
|
||||||
|
with torch.no_grad():
|
||||||
|
base.weight.zero_()
|
||||||
|
|
||||||
|
x = torch.randn(2, 4)
|
||||||
|
lora = LoRALinear(base, r=2, alpha=2)
|
||||||
|
with torch.no_grad():
|
||||||
|
lora.lora_B.fill_(1.0)
|
||||||
|
|
||||||
|
out_before = lora(x).clone()
|
||||||
|
lora.merge()
|
||||||
|
out_after = lora(x)
|
||||||
|
|
||||||
|
torch.testing.assert_close(out_before, out_after)
|
||||||
|
assert lora._merged
|
||||||
|
assert not hasattr(lora, "lora_A")
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_merge_is_idempotent():
|
||||||
|
base = Linear(4, 4)
|
||||||
|
with torch.no_grad():
|
||||||
|
base.weight.zero_()
|
||||||
|
|
||||||
|
lora = LoRALinear(base, r=2, alpha=2)
|
||||||
|
with torch.no_grad():
|
||||||
|
lora.lora_B.fill_(1.0)
|
||||||
|
|
||||||
|
lora.merge()
|
||||||
|
lora.merge()
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_default_target():
|
||||||
|
model = _make_model()
|
||||||
|
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8)
|
||||||
|
|
||||||
|
lora_count = _get_lora_count(model)
|
||||||
|
assert lora_count > 0
|
||||||
|
assert lora_count < n_before
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_ffn():
|
||||||
|
model = _make_model()
|
||||||
|
from astrai.model.components.lora import TARGET_MODULES_FFN
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
|
||||||
|
assert _get_lora_count(model) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_returns_config():
|
||||||
|
model = _make_model()
|
||||||
|
cfg = inject_lora(model, r=8, alpha=32)
|
||||||
|
assert isinstance(cfg, LoRAConfig)
|
||||||
|
assert cfg.r == 8
|
||||||
|
assert cfg.alpha == 32
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_no_matching_targets_warns(caplog):
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
|
||||||
|
assert "No LoRA layers injected" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_preserves_base_output():
|
||||||
|
model = _make_model()
|
||||||
|
x = torch.randint(0, 1000, (2, 16))
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_before = model(x)["logits"].clone()
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_after = model(x)["logits"]
|
||||||
|
|
||||||
|
torch.testing.assert_close(out_before, out_after)
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_does_not_reinject():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
first_count = _get_lora_count(model)
|
||||||
|
|
||||||
|
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
|
||||||
|
assert _get_lora_count(model) == first_count
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_adds_new_modules():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
first = _get_lora_count(model)
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
|
||||||
|
assert _get_lora_count(model) > first
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_on_mla_model():
|
||||||
|
model = _make_model(
|
||||||
|
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
|
||||||
|
)
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
|
||||||
|
assert _get_lora_count(model) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_on_moe_model():
|
||||||
|
model = _make_model(
|
||||||
|
ffn_type="moe",
|
||||||
|
n_routed_experts=4,
|
||||||
|
n_shared_experts=1,
|
||||||
|
n_activated_experts=2,
|
||||||
|
dim_ffn=32,
|
||||||
|
)
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
|
||||||
|
assert _get_lora_count(model) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_dict_key_format():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
sd = model.state_dict()
|
||||||
|
assert "layers.0.attention.q_proj.weight" in sd
|
||||||
|
assert "layers.0.attention.q_proj.lora_A" in sd
|
||||||
|
assert "layers.0.attention.q_proj.lora_B" in sd
|
||||||
|
|
||||||
|
|
||||||
|
def test_only_lora_params_trainable():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if isinstance(name.split(".")[-1], str) and "lora" in name:
|
||||||
|
assert param.requires_grad, f"lora param should be trainable: {name}"
|
||||||
|
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
|
||||||
|
assert not param.requires_grad, f"injected weight should be frozen: {name}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_dict_after_inject_consistent_with_original():
|
||||||
|
model = _make_model()
|
||||||
|
sd_before = {k: v for k, v in model.state_dict().items()}
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
sd_after = model.state_dict()
|
||||||
|
|
||||||
|
# original keys unchanged
|
||||||
|
for k in sd_before:
|
||||||
|
assert k in sd_after
|
||||||
|
assert sd_before[k].shape == sd_after[k].shape
|
||||||
|
|
||||||
|
# new lora keys present
|
||||||
|
lora_keys = [k for k in sd_after if "lora" in k]
|
||||||
|
assert len(lora_keys) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_roundtrip():
|
||||||
|
model = _make_model()
|
||||||
|
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
x = torch.randint(0, 1000, (2, 16))
|
||||||
|
with torch.no_grad():
|
||||||
|
out_src = model(x)["logits"].clone()
|
||||||
|
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
save_lora(model, tmpdir, cfg)
|
||||||
|
|
||||||
|
model2 = _make_model()
|
||||||
|
model2.load_state_dict(model.state_dict(), strict=False)
|
||||||
|
load_lora(model2, tmpdir)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_dst = model2(x)["logits"]
|
||||||
|
|
||||||
|
torch.testing.assert_close(out_src, out_dst)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_after_merge_raises():
|
||||||
|
model = _make_model()
|
||||||
|
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
save_lora(model, tmpdir, cfg)
|
||||||
|
merge_lora(model)
|
||||||
|
|
||||||
|
tmpdir2 = tempfile.mkdtemp()
|
||||||
|
with pytest.raises(RuntimeError, match="No LoRA parameters"):
|
||||||
|
save_lora(model, tmpdir2, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_lora_on_already_injected():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
|
||||||
|
|
||||||
|
model2 = _make_model()
|
||||||
|
model2.load_state_dict(model.state_dict(), strict=False)
|
||||||
|
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
# load onto already-injected model
|
||||||
|
load_lora(model2, tmpdir)
|
||||||
|
assert _get_lora_count(model2) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_lora_mismatched_r_raises():
|
||||||
|
model = _make_model()
|
||||||
|
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
save_lora(model, tmpdir, cfg)
|
||||||
|
|
||||||
|
model2 = _make_model()
|
||||||
|
model2.load_state_dict(model.state_dict(), strict=False)
|
||||||
|
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="size mismatch"):
|
||||||
|
load_lora(model2, tmpdir) # strict=False, only lora keys
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_preserves_output():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
x = torch.randint(0, 1000, (2, 16))
|
||||||
|
with torch.no_grad():
|
||||||
|
out_before = model(x)["logits"].clone()
|
||||||
|
|
||||||
|
merge_lora(model)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_after = model(x)["logits"]
|
||||||
|
torch.testing.assert_close(out_before, out_after)
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_no_lora_warns(caplog):
|
||||||
|
model = _make_model()
|
||||||
|
merge_lora(model)
|
||||||
|
assert "No LoRA layers to merge" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_lora_info():
|
||||||
|
model = _make_model()
|
||||||
|
info = _collect_lora_info(model)
|
||||||
|
assert "q_proj" in info
|
||||||
|
assert "o_proj" in info
|
||||||
|
assert "q_proj" in info # each layer has one
|
||||||
|
|
@ -6,8 +6,8 @@ import pytest
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -17,10 +17,10 @@ def transformer_test_env():
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"dim": 128,
|
"dim": 8,
|
||||||
"n_heads": 4,
|
"n_heads": 2,
|
||||||
"n_kv_heads": 2,
|
"n_kv_heads": 1,
|
||||||
"dim_ffn": 256,
|
"dim_ffn": 16,
|
||||||
"max_len": 64,
|
"max_len": 64,
|
||||||
"n_layers": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5,
|
"norm_eps": 1e-5,
|
||||||
|
|
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
model = Transformer(config)
|
model = AutoRegressiveLM(config)
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||||
|
|
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
model = Transformer(config)
|
model = AutoRegressiveLM(config)
|
||||||
|
|
||||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||||
|
|
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
original_model = Transformer(config)
|
original_model = AutoRegressiveLM(config)
|
||||||
|
|
||||||
st.save_file(original_model.state_dict(), model_path)
|
st.save_file(original_model.state_dict(), model_path)
|
||||||
|
|
||||||
loaded_config = ModelConfig().load(config_path)
|
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = AutoRegressiveLM(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
|
|
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
loaded_config = ModelConfig().load(config_path)
|
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = AutoRegressiveLM(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
@ -25,14 +27,14 @@ class TrainerDataset(Dataset):
|
||||||
|
|
||||||
|
|
||||||
def create_train_config(
|
def create_train_config(
|
||||||
model: torch.nn.Module,
|
model_fn,
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
test_dir: str,
|
test_dir: str,
|
||||||
device: str,
|
device: str,
|
||||||
strategy: str = "seq",
|
strategy: str = "seq",
|
||||||
n_epoch: int = 1,
|
n_epoch: int = 1,
|
||||||
batch_size: int = 2,
|
batch_per_device: int = 2,
|
||||||
accumulation_steps: int = 1,
|
grad_accum_steps: int = 1,
|
||||||
max_grad_norm: float = 1.0,
|
max_grad_norm: float = 1.0,
|
||||||
ckpt_interval: int = 5,
|
ckpt_interval: int = 5,
|
||||||
random_seed: int = 42,
|
random_seed: int = 42,
|
||||||
|
|
@ -41,14 +43,14 @@ def create_train_config(
|
||||||
"""Factory function to create common TrainConfig for tests.
|
"""Factory function to create common TrainConfig for tests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The model to train
|
model_fn: Model factory (callable returning nn.Module)
|
||||||
dataset: Training dataset
|
dataset: Training dataset
|
||||||
test_dir: Checkpoint directory
|
test_dir: Checkpoint directory
|
||||||
device: Device type ("cuda" or "cpu")
|
device: Device type ("cuda" or "cpu")
|
||||||
strategy: Training strategy type (default: "seq")
|
strategy: Training strategy type (default: "seq")
|
||||||
n_epoch: Number of epochs (default: 1)
|
n_epoch: Number of epochs (default: 1)
|
||||||
batch_size: Batch size (default: 2)
|
batch_per_device: Batch size per device (default: 2)
|
||||||
accumulation_steps: Gradient accumulation steps (default: 1)
|
grad_accum_steps: Gradient accumulation steps (default: 1)
|
||||||
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
|
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
|
||||||
ckpt_interval: Checkpoint save interval in iterations (default: 5)
|
ckpt_interval: Checkpoint save interval in iterations (default: 5)
|
||||||
random_seed: Random seed for reproducibility (default: 42)
|
random_seed: Random seed for reproducibility (default: 42)
|
||||||
|
|
@ -68,15 +70,16 @@ def create_train_config(
|
||||||
|
|
||||||
return TrainConfig(
|
return TrainConfig(
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
model=model,
|
model_fn=model_fn,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=test_dir,
|
ckpt_dir=test_dir,
|
||||||
|
log_dir=os.path.join(test_dir, "logs"),
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_per_device=batch_per_device,
|
||||||
ckpt_interval=ckpt_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
accumulation_steps=accumulation_steps,
|
grad_accum_steps=grad_accum_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
device_type=device,
|
device_type=device,
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,133 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.trainer.schedule import SchedulerFactory
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
from astrai.trainer.train_callback import TrainCallback
|
from astrai.trainer.train_callback import GradientCheckpointingCallback, TrainCallback
|
||||||
from astrai.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_enable_disable(test_model):
|
||||||
|
"""Enable wraps forward, _disable restores it."""
|
||||||
|
model = test_model["model"]
|
||||||
|
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||||
|
|
||||||
|
originals = [layer.forward for layer in model.layers]
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._enable(layer)
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
assert hasattr(layer, "_original_forward")
|
||||||
|
assert layer.forward is not originals[0]
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._disable(layer)
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
assert not hasattr(layer, "_original_forward")
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_empty_modules_noop(test_model):
|
||||||
|
"""modules=None should leave forwards untouched."""
|
||||||
|
model = test_model["model"]
|
||||||
|
callback = GradientCheckpointingCallback()
|
||||||
|
|
||||||
|
originals = [layer.forward for layer in model.layers]
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._enable(layer)
|
||||||
|
|
||||||
|
for layer, orig in zip(model.layers, originals):
|
||||||
|
assert layer.forward is orig
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_forward_unchanged(test_model):
|
||||||
|
"""Forward output unchanged after patching (no_grad)."""
|
||||||
|
model = test_model["model"]
|
||||||
|
device = test_model["device"]
|
||||||
|
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||||
|
|
||||||
|
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ref = model(input_ids)["logits"].clone()
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._enable(layer)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(input_ids)["logits"]
|
||||||
|
|
||||||
|
assert torch.equal(ref, out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_backward(test_model):
|
||||||
|
"""backward passes gradients through checkpointed layers."""
|
||||||
|
model = test_model["model"]
|
||||||
|
device = test_model["device"]
|
||||||
|
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._enable(layer)
|
||||||
|
|
||||||
|
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||||
|
target_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||||
|
|
||||||
|
logits = model(input_ids)["logits"]
|
||||||
|
loss = torch.nn.functional.cross_entropy(
|
||||||
|
logits.flatten(0, 1).float(), target_ids.flatten()
|
||||||
|
)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
assert param.grad is not None, f"{name} gradient is None"
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._disable(layer)
|
||||||
|
|
||||||
|
model.zero_grad()
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
assert p.grad is None or p.grad.sum().item() == 0, f"{name} grad not zeroed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_trainer_integration(base_test_env, random_dataset):
|
||||||
|
"""Gradient checkpointing runs end-to-end via Trainer."""
|
||||||
|
|
||||||
|
def optimizer_fn(model):
|
||||||
|
return torch.optim.AdamW(model.parameters())
|
||||||
|
|
||||||
|
def scheduler_fn(optim):
|
||||||
|
return SchedulerFactory.create(
|
||||||
|
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
|
)
|
||||||
|
|
||||||
|
train_config = TrainConfig(
|
||||||
|
model_fn=lambda: base_test_env["model"],
|
||||||
|
strategy="seq",
|
||||||
|
dataset=random_dataset,
|
||||||
|
optimizer_fn=optimizer_fn,
|
||||||
|
scheduler_fn=scheduler_fn,
|
||||||
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
|
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||||
|
n_epoch=1,
|
||||||
|
batch_per_device=2,
|
||||||
|
ckpt_interval=3,
|
||||||
|
grad_accum_steps=1,
|
||||||
|
max_grad_norm=1.0,
|
||||||
|
random_seed=42,
|
||||||
|
device_type=base_test_env["device"],
|
||||||
|
gradient_checkpointing_modules=[DecoderBlock],
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = Trainer(train_config)
|
||||||
|
trainer.train()
|
||||||
|
# no crash = callback correctly enabled/disabled
|
||||||
|
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
"""Test that all callbacks are properly integrated"""
|
"""Test that all callbacks are properly integrated"""
|
||||||
|
|
||||||
|
|
@ -18,16 +140,17 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
|
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=3,
|
ckpt_interval=3,
|
||||||
accumulation_steps=1,
|
grad_accum_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"],
|
device_type=base_test_env["device"],
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.serialization import Checkpoint
|
|
||||||
from astrai.trainer.schedule import SchedulerFactory
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
from astrai.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
@ -24,13 +23,14 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
|
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_size=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=1,
|
ckpt_interval=1,
|
||||||
accumulation_steps=2,
|
grad_accum_steps=2,
|
||||||
random_seed=np.random.randint(1e4),
|
random_seed=np.random.randint(1e4),
|
||||||
device_type=base_test_env["device"],
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
@ -38,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
||||||
# Should handle early stopping gracefully
|
# Should handle early stopping gracefully
|
||||||
checkpoint = None
|
|
||||||
try:
|
try:
|
||||||
checkpoint = trainer.train()
|
trainer.train()
|
||||||
except Exception:
|
except Exception:
|
||||||
# Handle any exceptions
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Resume from latest checkpoint
|
||||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||||
checkpoint = Checkpoint.load(load_dir)
|
trainer = Trainer(train_config)
|
||||||
trainer.train(checkpoint)
|
trainer.train(resume_dir=load_dir)
|
||||||
|
|
||||||
|
# Verify checkpoint was saved at expected iteration
|
||||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||||
checkpoint = Checkpoint.load(load_dir)
|
import json
|
||||||
assert checkpoint.iteration == 10
|
|
||||||
|
with open(os.path.join(load_dir, "meta.json")) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert meta["iteration"] == 10
|
||||||
|
|
|
||||||
|
|
@ -7,55 +7,56 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
||||||
"""Test training with different batch sizes"""
|
"""Test training with different batch sizes"""
|
||||||
batch_sizes = [1, 2, 4, 8]
|
batch_sizes = [1, 2, 4, 8]
|
||||||
|
|
||||||
for batch_size in batch_sizes:
|
for batch_per_device in batch_sizes:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_size=batch_size,
|
batch_per_device=batch_per_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.batch_size == batch_size
|
assert train_config.batch_per_device == batch_per_device
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
||||||
"""Test training with different gradient accumulation steps"""
|
"""Test training with different gradient accumulation steps"""
|
||||||
accumulation_steps_list = [1, 2, 4]
|
grad_accum_steps_list = [1, 2, 4]
|
||||||
|
|
||||||
for accumulation_steps in accumulation_steps_list:
|
for grad_accum_steps in grad_accum_steps_list:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_size=2,
|
batch_per_device=2,
|
||||||
accumulation_steps=accumulation_steps,
|
grad_accum_steps=grad_accum_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
assert train_config.accumulation_steps == accumulation_steps
|
assert train_config.grad_accum_steps == grad_accum_steps
|
||||||
|
|
||||||
|
|
||||||
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
||||||
"""Test training with memory-efficient configurations"""
|
"""Test training with memory-efficient configurations"""
|
||||||
# Test with smaller batch sizes and gradient checkpointing
|
# Test with smaller batch sizes and gradient checkpointing
|
||||||
small_batch_configs = [
|
small_batch_configs = [
|
||||||
{"batch_size": 1, "accumulation_steps": 8},
|
{"batch_per_device": 1, "grad_accum_steps": 8},
|
||||||
{"batch_size": 2, "accumulation_steps": 4},
|
{"batch_per_device": 2, "grad_accum_steps": 4},
|
||||||
{"batch_size": 4, "accumulation_steps": 2},
|
{"batch_per_device": 4, "grad_accum_steps": 2},
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in small_batch_configs:
|
for config in small_batch_configs:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_size=config["batch_size"],
|
batch_per_device=config["batch_per_device"],
|
||||||
accumulation_steps=config["accumulation_steps"],
|
grad_accum_steps=config["grad_accum_steps"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
assert train_config.grad_accum_steps == config["grad_accum_steps"]
|
||||||
|
assert train_config.batch_per_device == config["batch_per_device"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue