Compare commits
No commits in common. "main" and "v1.3.2" have entirely different histories.
|
|
@ -1,9 +0,0 @@
|
||||||
# Ignore everything
|
|
||||||
*
|
|
||||||
|
|
||||||
# Allow necessary files
|
|
||||||
!astrai/
|
|
||||||
!scripts/
|
|
||||||
!assets/
|
|
||||||
!pyproject.toml
|
|
||||||
!README.md
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
# Auto detect text files
|
|
||||||
* text=auto
|
|
||||||
|
|
||||||
# Files that MUST use LF (Unix/Linux execution)
|
|
||||||
*.sh text eol=lf
|
|
||||||
*.py text eol=lf
|
|
||||||
*.md text eol=lf
|
|
||||||
*.yml text eol=lf
|
|
||||||
|
|
||||||
Dockerfile text eol=lf
|
|
||||||
.dockerignore text eol=lf
|
|
||||||
|
|
||||||
.gitignore text eol=lf
|
|
||||||
.gitattributes text eol=lf
|
|
||||||
|
|
||||||
# Windows scripts - use CRLF
|
|
||||||
*.bat text eol=crlf
|
|
||||||
*.cmd text eol=crlf
|
|
||||||
*.ps1 text eol=crlf
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
---
|
|
||||||
name: Bug report
|
|
||||||
about: Create a report to help us improve
|
|
||||||
title: "[BUG]"
|
|
||||||
labels: bug
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Description
|
|
||||||
A clear and concise description of what the bug is.
|
|
||||||
## Steps to Reproduce
|
|
||||||
1. ...
|
|
||||||
2. ...
|
|
||||||
3. ...
|
|
||||||
## Expected Behavior
|
|
||||||
What you expected to happen.
|
|
||||||
## Actual Behavior
|
|
||||||
What actually happened.
|
|
||||||
## Environment
|
|
||||||
- Python version:
|
|
||||||
- AstrAI version (or commit hash):
|
|
||||||
- Operating System:
|
|
||||||
- GPU (if applicable):
|
|
||||||
- CUDA/cuDNN version (if applicable):
|
|
||||||
## Additional Context
|
|
||||||
Add any other context, screenshots, or logs here.
|
|
||||||
|
|
@ -1,10 +0,0 @@
|
||||||
---
|
|
||||||
name: Custom issue template
|
|
||||||
about: Describe this issue template's purpose here.
|
|
||||||
title: ''
|
|
||||||
labels: ''
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
---
|
|
||||||
name: Feature request
|
|
||||||
about: Suggest an idea for this project
|
|
||||||
title: "[FEAT]"
|
|
||||||
labels: ''
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Description
|
|
||||||
A clear and concise description of the feature you'd like to see.
|
|
||||||
## Problem Statement
|
|
||||||
What problem does this feature solve? Why is it needed?
|
|
||||||
## Proposed Solution
|
|
||||||
Describe the solution you'd like. Include any design ideas, API changes, or implementation details.
|
|
||||||
## Alternatives Considered
|
|
||||||
Describe any alternative solutions or features you've considered.
|
|
||||||
## Additional Context
|
|
||||||
Add any other context, screenshots, or references here.
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
## Description
|
|
||||||
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context.
|
|
||||||
|
|
||||||
Fixes # (issue number)
|
|
||||||
|
|
||||||
## Type of Change
|
|
||||||
Please delete options that are not relevant.
|
|
||||||
|
|
||||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
|
||||||
- [ ] New feature (non-breaking change which adds functionality)
|
|
||||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
|
||||||
- [ ] Documentation update
|
|
||||||
- [ ] Other (please describe):
|
|
||||||
|
|
||||||
## How Has This Been Tested?
|
|
||||||
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
|
|
||||||
|
|
||||||
## Checklist:
|
|
||||||
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
|
|
||||||
- [ ] I have performed a self-review of my own code
|
|
||||||
- [ ] Code is self-documenting (no unnecessary comments)
|
|
||||||
- [ ] I have made corresponding changes to the documentation
|
|
||||||
- [ ] My changes generate no new warnings
|
|
||||||
- [ ] I have added tests that prove my fix is effective or that my feature works
|
|
||||||
- [ ] New and existing unit tests pass locally with my changes
|
|
||||||
- [ ] Any dependent changes have been merged and published in downstream modules
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
name: Build and Push Docker Image
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- 'v*'
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
packages: write
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up QEMU
|
|
||||||
uses: docker/setup-qemu-action@v3
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
- name: Login to GitHub Container Registry
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
registry: ghcr.io
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Extract metadata
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@v5
|
|
||||||
with:
|
|
||||||
images: ghcr.io/${{ github.repository }}
|
|
||||||
tags: |
|
|
||||||
type=ref,event=tag
|
|
||||||
type=raw,value=latest
|
|
||||||
|
|
||||||
- name: Build and push
|
|
||||||
uses: docker/build-push-action@v5
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
platforms: linux/amd64
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
cache-from: type=gha
|
|
||||||
cache-to: type=gha,mode=max
|
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
name: Lint
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [main]
|
|
||||||
pull_request:
|
|
||||||
branches: [main]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
lint:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python 3.12
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: "3.12"
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install .[dev]
|
|
||||||
|
|
||||||
- name: Check formatting with ruff
|
|
||||||
run: |
|
|
||||||
ruff format --check .
|
|
||||||
|
|
||||||
- name: Check import sorting
|
|
||||||
run: |
|
|
||||||
ruff check . --select I
|
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
name: Spell Check
|
||||||
|
on: [push, pull_request]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
spellcheck:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Check spelling in specific files
|
||||||
|
uses: codespell-project/actions-codespell@v2
|
||||||
|
with:
|
||||||
|
check_filenames: true
|
||||||
|
only_warn: false
|
||||||
|
path: "**/*.{md, py}"
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
name: Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [main]
|
|
||||||
pull_request:
|
|
||||||
branches: [main]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
python-version: ["3.12"]
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.python-version }}
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install .[dev]
|
|
||||||
|
|
||||||
- name: Run tests with pytest
|
|
||||||
run: |
|
|
||||||
python -m pytest tests/ -v
|
|
||||||
|
|
@ -6,18 +6,7 @@
|
||||||
|
|
||||||
# Allow specific file types and root files
|
# Allow specific file types and root files
|
||||||
!*.py
|
!*.py
|
||||||
!*.sh
|
!*.md
|
||||||
|
!*.png
|
||||||
# Allow GitHub files
|
!LICENSE
|
||||||
!/.github/**
|
!pyproject.toml
|
||||||
|
|
||||||
# Allow root files
|
|
||||||
!/.gitattributes
|
|
||||||
!/.dockerignore
|
|
||||||
!/Dockerfile
|
|
||||||
!/docker-compose.yml
|
|
||||||
!/assets/**
|
|
||||||
!/CONTRIBUTING.md
|
|
||||||
!/LICENSE
|
|
||||||
!/pyproject.toml
|
|
||||||
!/README.md
|
|
||||||
100
CONTRIBUTING.md
100
CONTRIBUTING.md
|
|
@ -1,100 +0,0 @@
|
||||||
# Contributing to AstrAI
|
|
||||||
|
|
||||||
Thank you for your interest in contributing! This document provides step-by-step guidelines.
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/your-username/AstrAI.git
|
|
||||||
cd AstrAI
|
|
||||||
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Before You Commit
|
|
||||||
|
|
||||||
Run the following checks **in order** — CI will reject the change if any of them fail.
|
|
||||||
|
|
||||||
### 1. Format
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ruff format .
|
|
||||||
```
|
|
||||||
|
|
||||||
> **Note**: `ruff format` only changes code layout (whitespace, quotes, line breaks); it never renames identifiers such as `mask` → `attn_mask`.
|
|
||||||
> Always review the diff after formatting.
|
|
||||||
|
|
||||||
### 2. Import sorting
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ruff check . --select I
|
|
||||||
```
|
|
||||||
|
|
||||||
If this fails, fix the import ordering locally before pushing (CI runs the check without `--fix`, so it will not correct it for you):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ruff check . --select I --fix
|
|
||||||
ruff format . # re-format after fix
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Run tests
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -u -m pytest tests/ -v
|
|
||||||
```
|
|
||||||
|
|
||||||
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
|
|
||||||
|
|
||||||
### 4. (Optional) Full pre-commit check
|
|
||||||
|
|
||||||
If you have Git Bash available:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
bash scripts/pre_commit.sh
|
|
||||||
```
|
|
||||||
|
|
||||||
This runs format check, import sort check, and tests in one go.
|
|
||||||
|
|
||||||
## Commit Style
|
|
||||||
|
|
||||||
```
|
|
||||||
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
|
|
||||||
|
|
||||||
- bullet point body (each ~60 chars)
|
|
||||||
```
|
|
||||||
|
|
||||||
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
|
|
||||||
- **Subject line** does not end with a period.
|
|
||||||
- **Body** uses bullet points starting with `-`.
|
|
||||||
- No `(scope)` parentheses.
|
|
||||||
|
|
||||||
## Common Issues
|
|
||||||
|
|
||||||
| Problem | Cause | Fix |
|
|
||||||
|---------|-------|-----|
|
|
||||||
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix` then `ruff format .` |
|
|
||||||
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
|
|
||||||
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
|
|
||||||
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
|
|
||||||
|
|
||||||
## Submitting Changes
|
|
||||||
|
|
||||||
1. Fork the repo.
|
|
||||||
2. Create a feature branch: `git checkout -b feat/my-feature`
|
|
||||||
3. Make changes following the steps above.
|
|
||||||
4. Commit with the commit style above.
|
|
||||||
5. Push: `git push origin feat/my-feature`
|
|
||||||
6. Open a Pull Request against `main`.
|
|
||||||
|
|
||||||
## Code Review
|
|
||||||
|
|
||||||
- All PRs are reviewed. We may request changes.
|
|
||||||
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
|
|
||||||
- Ensure all tests pass.
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
|
|
||||||
55
Dockerfile
55
Dockerfile
|
|
@ -1,55 +0,0 @@
|
||||||
# AstrAI Dockerfile - Multi-stage Build (Optimized)
|
|
||||||
|
|
||||||
# Build stage - use base image with minimal build tools
|
|
||||||
FROM ubuntu:24.04 AS builder
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Install Python 3.12 and minimal build dependencies
|
|
||||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
|
||||||
python3.12 \
|
|
||||||
python3.12-dev \
|
|
||||||
python3.12-venv \
|
|
||||||
gcc \
|
|
||||||
g++ \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Create isolated virtual environment
|
|
||||||
RUN python3.12 -m venv --copies /opt/venv
|
|
||||||
ENV PATH="/opt/venv/bin:$PATH"
|
|
||||||
|
|
||||||
# Copy source code and install (deps read from pyproject.toml)
|
|
||||||
COPY astrai/ ./astrai/
|
|
||||||
COPY pyproject.toml .
|
|
||||||
RUN pip install --no-cache-dir --upgrade pip \
|
|
||||||
&& pip install --no-cache-dir . \
|
|
||||||
--extra-index-url https://download.pytorch.org/whl/cu126
|
|
||||||
|
|
||||||
# Production stage
|
|
||||||
FROM ubuntu:24.04 AS production
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Install Python 3.12 runtime and healthcheck dependency
|
|
||||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
|
||||||
python3.12 \
|
|
||||||
curl \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Copy virtual environment from builder
|
|
||||||
COPY --from=builder /opt/venv /opt/venv
|
|
||||||
ENV PATH="/opt/venv/bin:$PATH"
|
|
||||||
|
|
||||||
# Copy application code
|
|
||||||
COPY astrai/ ./astrai/
|
|
||||||
COPY scripts/ ./scripts/
|
|
||||||
COPY assets/ ./assets/
|
|
||||||
COPY pyproject.toml .
|
|
||||||
COPY README.md .
|
|
||||||
|
|
||||||
# Create non-root user
|
|
||||||
RUN useradd -m astrai && chown -R astrai:astrai /app
|
|
||||||
USER astrai
|
|
||||||
|
|
||||||
ENV PYTHONUNBUFFERED=1 \
|
|
||||||
PYTHONDONTWRITEBYTECODE=1
|
|
||||||
453
README.md
453
README.md
|
|
@ -1,255 +1,286 @@
|
||||||
<div align="center">
|

|
||||||
|
|
||||||
<img src="assets/images/logo.png" width="auto" alt="Logo">
|
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; text-align: center; font-size: 16px; font-weight: bold; margin-top: 50px;">
|
||||||
<p>
|
|
||||||
<strong>A lightweight Transformer training & inference framework</strong>
|
<div>
|
||||||
</p>
|
<a href="#english" style="text-decoration: none; margin: 0 10px; color: blue;">English</a> |
|
||||||
|
<a href="#chinese" style="text-decoration: none; margin: 0 10px; color: blue;">中文</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h1 style="margin: 20px 0 0 0; font-size: 2.5em; font-weight: bold;">KHAOSZ </h1>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div align="center">
|
<h2 id="english">English Version</h2>
|
||||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
|
||||||
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
|
|
||||||
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
|
|
||||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
|
|
||||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
|
|
||||||
</div>
|
|
||||||
<br>
|
|
||||||
|
|
||||||
<div align="center">
|
A training and inference framework for autoregressive Transformer language models.
|
||||||
<a href="#english">English</a> •
|
|
||||||
<a href="assets/docs/README-zh-CN.md">中文</a> •
|
|
||||||
<a href="https://github.com/ViperEkura/AstrAI/issues">Issue Tracker</a> •
|
|
||||||
<a href="https://github.com/ViperEkura/AstrAI/discussions">Discussions</a> •
|
|
||||||
<a href="https://huggingface.co/ViperEk/">HuggingFace</a>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<br>
|
**Model Download Options (choose one):**
|
||||||
|
|
||||||
## 📖 Table of Contents
|
1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) and check **Files and versions**
|
||||||
|
2. Run `scripts/download.py` to download model parameters
|
||||||
|
|
||||||
- [Features](#features)
|
**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
|
||||||
- [Quick Start](#quick-start)
|
|
||||||
- [Documentation](#documentation)
|
|
||||||
- [Contributing](#contributing)
|
|
||||||
- [Community](#community)
|
|
||||||
- [License](#license)
|
|
||||||
|
|
||||||
---
|
For training data sources, please refer to the **Model Card** section on the HuggingFace download page.
|
||||||
|
|
||||||
<a id="english"></a>
|
**License:** The code follows the GPL-3.0 license. Please provide attribution when using it.
|
||||||
## English
|
|
||||||
|
|
||||||
### Features
|
- **📊 Device Selection:** Uses CUDA for training by default
|
||||||
|
- **🌐 Performance Optimization:** Enable `dtype=torch.bfloat16` to accelerate training and reduce memory usage. Ensure your hardware supports this feature
|
||||||
|
- **🤖 Language Support:** The model supports training in Chinese and English. The BBPE tokenizer has not been trained on multilingual text, so OOV (out-of-vocabulary) issues are minimal for Chinese and English but may occur for other languages
|
||||||
|
|
||||||
- 🚀 **High Performance**: Optimized for both training and inference with efficient parallelization.
|
|
||||||
- 🔧 **Flexible**: Support for seq/sft/dpo/grpo training, customizable model architectures.
|
|
||||||
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
|
|
||||||
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
|
|
||||||
- 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas.
|
|
||||||
- 🤗 **HuggingFace-Style API**: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
|
|
||||||
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
|
|
||||||
|
|
||||||
### Quick Start
|
### 📌 Training Guide
|
||||||
|
|
||||||
#### Installation
|
To train this Transformer model, follow these steps:
|
||||||
|
|
||||||
|
**(1). Prepare the Dataset:**
|
||||||
|
|
||||||
|
Place the dataset in the specified root directory. This system uses the BBPE tokenizer for tokenization and requires training with pre-tokenized segments (stored as *.h5 format files).
|
||||||
|
|
||||||
|
**(2). Install Dependencies:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/ViperEkura/AstrAI.git
|
|
||||||
cd AstrAI
|
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
For development dependencies:
|
**(3). Run the Training Script:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -e ".[dev]"
|
python train.py \
|
||||||
|
--train_type=train_type[seq, sft, dpo] \
|
||||||
|
--data_root_path=/path/to/dataset \
|
||||||
|
--param_path=/path/to/param_path \
|
||||||
|
--n_epoch=5 \
|
||||||
|
--batch_size=8 \
|
||||||
|
--max_lr=2e-4 \
|
||||||
|
--checkpoint_interval=10000 \
|
||||||
|
--checkpoint_dir=checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Download Pre-trained Model
|
**Parameter Explanation:**
|
||||||
|
- `--train_type`: Training type (seq, sft, dpo)
|
||||||
|
- `--data_root_path`: Dataset root directory
|
||||||
|
- `--param_path`: Path to model training parameters
|
||||||
|
- `--n_epoch`: Total number of training epochs
|
||||||
|
- `--batch_size`: Batch size
|
||||||
|
- `--accumulation_steps`: Number of batches per training step
|
||||||
|
- `--warmup_steps`: Warmup steps
|
||||||
|
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
|
||||||
|
- `--checkpoint_interval`: Checkpoint saving interval
|
||||||
|
- `--checkpoint_dir`: Checkpoint saving directory
|
||||||
|
- `--resume_dir`: Resume training from specified path
|
||||||
|
|
||||||
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
|
|
||||||
|
|
||||||
|
### 👉 Usage Guide
|
||||||
|
|
||||||
|
**(1). Chat with the Model:**
|
||||||
|
|
||||||
|
Open `chat.py` or use the streaming/non-streaming interfaces:
|
||||||
|
|
||||||
|
**Streaming Output:**
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
|
model_dir = "your_model_parameter_dir"
|
||||||
|
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
history = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
query = input(">> ")
|
||||||
|
if query == "!exit":
|
||||||
|
break
|
||||||
|
|
||||||
|
response_size = 0
|
||||||
|
for response, history in model.stream_generate(
|
||||||
|
query=query,
|
||||||
|
history=history,
|
||||||
|
temperature=0.85,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=50
|
||||||
|
):
|
||||||
|
print(response[response_size:], end="")
|
||||||
|
response_size = len(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Non-streaming Output:**
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
|
model_dir = "your_model_parameter_dir"
|
||||||
|
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
history = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
query = input(">> ")
|
||||||
|
if query == "!exit":
|
||||||
|
break
|
||||||
|
|
||||||
|
response = model.generate(
|
||||||
|
query=query,
|
||||||
|
history=history,
|
||||||
|
temperature=0.85,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=50
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
**(2). Retrieval-Augmented Generation (RAG):**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
|
model_dir = "your_model_parameter_dir"
|
||||||
|
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
retrieved_content = model.retrieve_generate(
|
||||||
|
query=query,
|
||||||
|
retrieve_top_k=5,
|
||||||
|
temperature=0.6,
|
||||||
|
top_k=30,
|
||||||
|
top_p=0.95
|
||||||
|
)
|
||||||
|
print(retrieved_content)
|
||||||
|
```
|
||||||
|
|
||||||
|
<h2 id="chinese">中文版本</h2>
|
||||||
|
这是一个支持基于自回归模式的 Transformer 语言模型训练以及推理框架
|
||||||
|
|
||||||
|
**模型下载选项(任选其一):**
|
||||||
|
|
||||||
|
1. 访问 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 查看 **Files and versions**
|
||||||
|
2. 运行 `scripts/download.py` 下载模型参数
|
||||||
|
|
||||||
|
**演示视频:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
|
||||||
|
|
||||||
|
训练数据来源请参见 HuggingFace 下载页面中的 **Model Card** 部分。
|
||||||
|
|
||||||
|
**许可证:** 代码遵循 GPL-3.0 协议,使用时请注明出处。
|
||||||
|
|
||||||
|
- **📊 设备选择:** 默认使用 CUDA 进行训练
|
||||||
|
- **🌐 性能优化:** 启用 `dtype=torch.bfloat16` 以加速训练并减少内存占用,请确保硬件支持该特性
|
||||||
|
- **🤖 语言支持:** 模型支持中文和英文训练。由于 BBPE 分词器未使用多语言文本训练,因此中英文的 OOV(未登录词)问题较少,其他语言可能存在 OOV 问题
|
||||||
|
|
||||||
|
|
||||||
|
### 📌 训练指南
|
||||||
|
|
||||||
|
要训练该 Transformer 模型,请按照以下步骤操作:
|
||||||
|
|
||||||
|
**(1). 准备数据集:**
|
||||||
|
|
||||||
|
将数据集放置在指定的根目录下, 本系统采用 BBPE 分词器进行分词,并且要求使用已经经过分词的 token 分段训练(分段存储为 *.h5 格式)
|
||||||
|
|
||||||
|
**(2). 安装依赖:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/demo/download.py
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
|
**(3). 运行训练脚本:**
|
||||||
|
|
||||||
#### Train a Model
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
python train.py \
|
||||||
|
--train_type=train_type[seq, sft, dpo] \
|
||||||
nohup python scripts/tools/train.py \
|
--data_root_path=/path/to/dataset \
|
||||||
--nprocs=4 \
|
--param_path=/path/to/param_path \
|
||||||
--parallel_mode=ddp \
|
--n_epoch=5 \
|
||||||
--train_type=seq \
|
--batch_size=8 \
|
||||||
--data_root_path=/path/to/dataset \
|
--max_lr=2e-4 \
|
||||||
--param_path=/path/to/model \
|
--checkpoint_interval=10000 \
|
||||||
--batch_per_device=4 \
|
--checkpoint_dir=checkpoints
|
||||||
--grad_accum_steps=8 \
|
|
||||||
--warmup_ratio=0.05 \
|
|
||||||
--max_lr=1e-4 \
|
|
||||||
--max_grad_norm=1.0 \
|
|
||||||
--adamw_beta1=0.9 \
|
|
||||||
--adamw_beta2=0.95 \
|
|
||||||
--adamw_weight_decay=0.01 \
|
|
||||||
--window_size=2048 \
|
|
||||||
--ckpt_interval=10000 \
|
|
||||||
--ckpt_dir=./checkpoint \
|
|
||||||
--random_seed=3407 \
|
|
||||||
--label_smoothing=0.05 \
|
|
||||||
> out.log 2> err.log &
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Full reference at [Parameter Guide](assets/docs/params.md).
|
**参数说明:**
|
||||||
|
- `--train_type`: 训练类型(seq, sft, dpo)
|
||||||
|
- `--data_root_path`: 数据集根目录
|
||||||
|
- `--param_path`: 模型训练参数路径
|
||||||
|
- `--n_epoch`: 总训练轮数
|
||||||
|
- `--batch_size`: 批量大小
|
||||||
|
- `--accumulation_steps`: 每个训练步骤的 batch 数量
|
||||||
|
- `--warmup_steps`: 预热步数(warmup steps)
|
||||||
|
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
|
||||||
|
- `--checkpoint_interval`: 检查点保存间隔
|
||||||
|
- `--checkpoint_dir`: 检查点保存目录
|
||||||
|
- `--resume_dir`: 从指定路径恢复训练
|
||||||
|
|
||||||
#### Generate Text
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/tools/generate.py \
|
### 👉 使用指南
|
||||||
--param_path /path/to/model \
|
|
||||||
--input_json_file /path/to/input.jsonl \
|
**(1). 与模型对话:**
|
||||||
--output_json_file /path/to/output.jsonl
|
|
||||||
|
打开 `chat.py` 或使用流式/非流式接口:
|
||||||
|
|
||||||
|
**流式输出:**
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
|
model_dir = "your_model_parameter_dir"
|
||||||
|
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
history = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
query = input(">> ")
|
||||||
|
if query == "!exit":
|
||||||
|
break
|
||||||
|
|
||||||
|
response_size = 0
|
||||||
|
for response, history in model.stream_generate(
|
||||||
|
query=query,
|
||||||
|
history=history,
|
||||||
|
temperature=0.85,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=50
|
||||||
|
):
|
||||||
|
print(response[response_size:], end="")
|
||||||
|
response_size = len(response)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
**非流式输出:**
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
Build and run with Docker (recommended for GPU environments):
|
model_dir = "your_model_parameter_dir"
|
||||||
|
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
history = []
|
||||||
|
|
||||||
```bash
|
while True:
|
||||||
# Build image
|
query = input(">> ")
|
||||||
docker build -t astrai:latest .
|
if query == "!exit":
|
||||||
|
break
|
||||||
|
|
||||||
# Run with GPU support
|
response = model.generate(
|
||||||
docker run --gpus all -it astrai:latest
|
query=query,
|
||||||
|
history=history,
|
||||||
# Run with specific GPUs
|
temperature=0.85,
|
||||||
docker run --gpus '"device=0,1"' -it astrai:latest
|
top_p=0.95,
|
||||||
|
top_k=50
|
||||||
# Run inference server
|
)
|
||||||
docker run --gpus all -p 8000:8000 astrai:latest \
|
print(response)
|
||||||
python -m scripts.tools.server --port 8000 --device cuda
|
|
||||||
|
|
||||||
# Run with volume mount for data
|
|
||||||
docker run --gpus all -v /path/to/data:/data -it astrai:latest
|
|
||||||
|
|
||||||
# Docker Compose (GPU, default)
|
|
||||||
docker compose up -d
|
|
||||||
|
|
||||||
# Docker Compose (CPU only)
|
|
||||||
docker compose --profile cpu up -d
|
|
||||||
```
|
```
|
||||||
|
|
||||||
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
|
**(2). 基于检索的生成(RAG):**
|
||||||
|
|
||||||
#### Start HTTP Server
|
```python
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
Start the inference server with OpenAI and Anthropic-compatible HTTP API:
|
model_dir = "your_model_parameter_dir"
|
||||||
|
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
|
||||||
```bash
|
retrieved_content = model.retrieve_generate(
|
||||||
python -m scripts.tools.server --port 8000 --device cuda
|
query=query,
|
||||||
|
retrieve_top_k=5,
|
||||||
|
temperature=0.6,
|
||||||
|
top_k=30,
|
||||||
|
top_p=0.95
|
||||||
|
)
|
||||||
|
print(retrieved_content)
|
||||||
```
|
```
|
||||||
|
|
||||||
Make requests:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# OpenAI-compatible
|
|
||||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"max_tokens": 512
|
|
||||||
}'
|
|
||||||
|
|
||||||
# OpenAI-compatible streaming
|
|
||||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"messages": [{"role": "user", "content": "Tell a story"}],
|
|
||||||
"stream": true,
|
|
||||||
"max_tokens": 500
|
|
||||||
}'
|
|
||||||
|
|
||||||
# Anthropic-compatible
|
|
||||||
curl -X POST http://localhost:8000/v1/messages \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "astrai",
|
|
||||||
"system": "You are a helpful assistant.",
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"max_tokens": 512
|
|
||||||
}'
|
|
||||||
|
|
||||||
# Anthropic-compatible streaming with stop sequences
|
|
||||||
curl -X POST http://localhost:8000/v1/messages \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "astrai",
|
|
||||||
"messages": [{"role": "user", "content": "Write a story"}],
|
|
||||||
"max_tokens": 500,
|
|
||||||
"stream": true,
|
|
||||||
"stop_sequences": ["The end"]
|
|
||||||
}'
|
|
||||||
|
|
||||||
# Health check
|
|
||||||
curl http://localhost:8000/health
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Demo
|
|
||||||
|
|
||||||
Check out the demos in the `scripts/demo/` folder:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Download pre‑processed data (required before running demos)
|
|
||||||
python scripts/demo/download.py
|
|
||||||
|
|
||||||
# Interactive streaming chat
|
|
||||||
python scripts/demo/stream_chat.py
|
|
||||||
|
|
||||||
# Batch generation
|
|
||||||
python scripts/demo/generate_batch.py
|
|
||||||
|
|
||||||
# Auto‑regressive generation
|
|
||||||
python scripts/demo/generate_ar.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
|
|
||||||
|
|
||||||
### Documentation
|
|
||||||
|
|
||||||
| Document | Description |
|
|
||||||
|----------|-------------|
|
|
||||||
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
|
|
||||||
| [Architecture](./assets/docs/architecture.md) | System architecture, class diagram & design patterns |
|
|
||||||
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
|
||||||
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
|
||||||
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
|
||||||
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
|
|
||||||
|
|
||||||
### Contributing
|
|
||||||
|
|
||||||
We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.
|
|
||||||
|
|
||||||
1. Fork the repository.
|
|
||||||
2. Create a feature branch.
|
|
||||||
3. Commit your changes.
|
|
||||||
4. Open a Pull Request.
|
|
||||||
|
|
||||||
For major changes, please open an issue first to discuss what you would like to change.
|
|
||||||
|
|
||||||
### Community
|
|
||||||
|
|
||||||
- **GitHub Issues**: [Issue Tracker](https://github.com/ViperEkura/AstrAI/issues)
|
|
||||||
- **Discussions**: [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions)
|
|
||||||
- **HuggingFace**: [Model Hub](https://huggingface.co/ViperEk)
|
|
||||||
|
|
||||||
### License
|
|
||||||
|
|
||||||
This project is licensed under the [GPL-3.0 License](LICENSE).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<em>A lightweight Transformer framework designed for both high performance and ease of use.</em>
|
|
||||||
</div>
|
|
||||||
|
|
@ -1,261 +0,0 @@
|
||||||
<div align="center">
|
|
||||||
|
|
||||||
<img src="../images/logo.png" width="auto" alt="Logo">
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<a href="../../README.md">English</a> •
|
|
||||||
<a href="#chinese">中文</a>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<p>
|
|
||||||
<strong>轻量级 Transformer 训练与推理框架</strong>
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
|
||||||
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
|
|
||||||
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
|
|
||||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
|
|
||||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<br>
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<a href="../../README.md">English</a> •
|
|
||||||
<a href="#chinese">中文</a> •
|
|
||||||
<a href="https://github.com/ViperEkura/AstrAI/issues">问题追踪</a> •
|
|
||||||
<a href="https://github.com/ViperEkura/AstrAI/discussions">讨论区</a> •
|
|
||||||
<a href="https://huggingface.co/ViperEk">HuggingFace</a>
|
|
||||||
</div>
|
|
||||||
<br>
|
|
||||||
|
|
||||||
## 📖 目录
|
|
||||||
|
|
||||||
- [特性](#特性)
|
|
||||||
- [快速开始](#快速开始)
|
|
||||||
- [文档](#文档)
|
|
||||||
- [贡献](#贡献)
|
|
||||||
- [社区](#社区)
|
|
||||||
- [许可证](#许可证)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
<a id="chinese"></a>
|
|
||||||
## 中文
|
|
||||||
|
|
||||||
### 特性
|
|
||||||
|
|
||||||
- 🚀 **高性能**: 训练与推理双向优化,高效并行。
|
|
||||||
- 🔧 **灵活**: 支持 seq/sft/dpo/grpo 多种训练方式,可定制模型架构。
|
|
||||||
- 💡 **易用**: 简洁的 API 与丰富的示例、演示。
|
|
||||||
- 📦 **轻量**: 依赖少,部署简单。
|
|
||||||
- 🔬 **研究友好**: 模块化设计,便于实验新想法。
|
|
||||||
- 🤗 **HuggingFace 风格 API**: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。
|
|
||||||
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。
|
|
||||||
|
|
||||||
### 快速开始
|
|
||||||
|
|
||||||
#### 安装
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/ViperEkura/AstrAI.git
|
|
||||||
cd AstrAI
|
|
||||||
pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
安装开发依赖:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -e ".[dev]"
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 下载预训练模型
|
|
||||||
|
|
||||||
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/demo/download.py
|
|
||||||
```
|
|
||||||
|
|
||||||
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`。
|
|
||||||
|
|
||||||
#### 训练模型
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
|
||||||
--nprocs=4 \
|
|
||||||
--parallel_mode=ddp \
|
|
||||||
--train_type=seq \
|
|
||||||
--data_root_path=/path/to/dataset \
|
|
||||||
--param_path=/path/to/model \
|
|
||||||
--batch_per_device=4 \
|
|
||||||
--grad_accum_steps=8 \
|
|
||||||
--warmup_ratio=0.05 \
|
|
||||||
--max_lr=1e-4 \
|
|
||||||
--max_grad_norm=1.0 \
|
|
||||||
--adamw_beta1=0.9 \
|
|
||||||
--adamw_beta2=0.95 \
|
|
||||||
--adamw_weight_decay=0.01 \
|
|
||||||
--window_size=2048 \
|
|
||||||
--ckpt_interval=10000 \
|
|
||||||
--ckpt_dir=./checkpoint \
|
|
||||||
--random_seed=3407 \
|
|
||||||
--label_smoothing=0.05 \
|
|
||||||
> out.log 2> err.log &
|
|
||||||
```
|
|
||||||
|
|
||||||
完整参数列表见[参数说明](./params.md)。
|
|
||||||
|
|
||||||
#### 文本生成
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/tools/generate.py \
|
|
||||||
--param_path /path/to/model \
|
|
||||||
--input_json_file /path/to/input.jsonl \
|
|
||||||
--output_json_file /path/to/output.jsonl
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Docker
|
|
||||||
|
|
||||||
使用 Docker 构建和运行(推荐用于 GPU 环境):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 构建镜像
|
|
||||||
docker build -t astrai:latest .
|
|
||||||
|
|
||||||
# 启用 GPU 运行
|
|
||||||
docker run --gpus all -it astrai:latest
|
|
||||||
|
|
||||||
# 指定特定 GPU
|
|
||||||
docker run --gpus '"device=0,1"' -it astrai:latest
|
|
||||||
|
|
||||||
# 运行推理服务
|
|
||||||
docker run --gpus all -p 8000:8000 astrai:latest \
|
|
||||||
python -m scripts.tools.server --port 8000 --device cuda
|
|
||||||
|
|
||||||
# 挂载数据卷
|
|
||||||
docker run --gpus all -v /path/to/data:/data -it astrai:latest
|
|
||||||
|
|
||||||
# Docker Compose(GPU,默认)
|
|
||||||
docker compose up -d
|
|
||||||
|
|
||||||
# Docker Compose(仅 CPU)
|
|
||||||
docker compose --profile cpu up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
> **注意**: 必须使用 `--gpus all` 才能启用 CUDA 支持,否则 `torch.cuda.is_available()` 将返回 `False`。
|
|
||||||
|
|
||||||
#### 启动 HTTP 服务
|
|
||||||
|
|
||||||
启动推理服务器,支持 OpenAI 和 Anthropic 兼容的 HTTP API:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m scripts.tools.server --port 8000 --device cuda
|
|
||||||
```
|
|
||||||
|
|
||||||
发起请求:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# OpenAI 兼容
|
|
||||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"messages": [{"role": "user", "content": "你好"}],
|
|
||||||
"max_tokens": 512
|
|
||||||
}'
|
|
||||||
|
|
||||||
# OpenAI 兼容流式
|
|
||||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"messages": [{"role": "user", "content": "讲个故事"}],
|
|
||||||
"stream": true,
|
|
||||||
"max_tokens": 500
|
|
||||||
}'
|
|
||||||
|
|
||||||
# Anthropic 兼容
|
|
||||||
curl -X POST http://localhost:8000/v1/messages \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "astrai",
|
|
||||||
"system": "你是一个乐于助人的助手。",
|
|
||||||
"messages": [{"role": "user", "content": "你好"}],
|
|
||||||
"max_tokens": 512
|
|
||||||
}'
|
|
||||||
|
|
||||||
# Anthropic 兼容流式并设置停止序列
|
|
||||||
curl -X POST http://localhost:8000/v1/messages \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "astrai",
|
|
||||||
"messages": [{"role": "user", "content": "写个故事"}],
|
|
||||||
"max_tokens": 500,
|
|
||||||
"stream": true,
|
|
||||||
"stop_sequences": ["结束"]
|
|
||||||
}'
|
|
||||||
|
|
||||||
# 健康检查
|
|
||||||
curl http://localhost:8000/health
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 演示
|
|
||||||
|
|
||||||
查看 `scripts/demo/` 文件夹中的演示:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 下载预处理数据(运行演示前必需)
|
|
||||||
python scripts/demo/download.py
|
|
||||||
|
|
||||||
# 交互式流式聊天
|
|
||||||
python scripts/demo/stream_chat.py
|
|
||||||
|
|
||||||
# 批量生成
|
|
||||||
python scripts/demo/generate_batch.py
|
|
||||||
|
|
||||||
# 自回归生成
|
|
||||||
python scripts/demo/generate_ar.py
|
|
||||||
```
|
|
||||||
|
|
||||||
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
|
|
||||||
|
|
||||||
### 文档
|
|
||||||
|
|
||||||
| 文档 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| [参数说明](./params.md) | 训练与推理参数配置 |
|
|
||||||
| [架构文档](./architecture.md) | 系统架构、类图与设计模式 |
|
|
||||||
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
|
||||||
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
|
||||||
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
|
||||||
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
|
|
||||||
|
|
||||||
### 贡献
|
|
||||||
|
|
||||||
我们欢迎贡献!请参阅[贡献指南](../../CONTRIBUTING.md)了解详情。
|
|
||||||
|
|
||||||
1. Fork 本仓库。
|
|
||||||
2. 创建功能分支。
|
|
||||||
3. 提交更改。
|
|
||||||
4. 发起 Pull Request。
|
|
||||||
|
|
||||||
重大更改请先开 issue 讨论。
|
|
||||||
|
|
||||||
### 社区
|
|
||||||
|
|
||||||
- **GitHub Issues**: [问题追踪](https://github.com/ViperEkura/AstrAI/issues)
|
|
||||||
- **Discussions**: [GitHub 讨论区](https://github.com/ViperEkura/AstrAI/discussions)
|
|
||||||
- **HuggingFace**: [模型中心](https://huggingface.co/ViperEk)
|
|
||||||
|
|
||||||
### 许可证
|
|
||||||
|
|
||||||
本项目采用 [GPL-3.0 许可证](../../LICENSE)。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<em>专为高性能与易用性设计的轻量级 Transformer 框架。</em>
|
|
||||||
</div>
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,64 +0,0 @@
|
||||||
# Data Flow
|
|
||||||
|
|
||||||
This document describes the data pipeline: from raw text to model input tensors.
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
```
|
|
||||||
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
|
|
||||||
```
|
|
||||||
|
|
||||||
## Data Preparation
|
|
||||||
|
|
||||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
|
||||||
|
|
||||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
|
||||||
|
|
||||||
```
|
|
||||||
StoreFactory.create("h5") → H5Store
|
|
||||||
StoreFactory.create("bin") → MmapStore
|
|
||||||
```
|
|
||||||
|
|
||||||
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
|
||||||
|
|
||||||
## Data Keys by Training Type
|
|
||||||
|
|
||||||
| Type | Storage Keys |
|
|
||||||
|------|-------------|
|
|
||||||
| `seq` | `sequence` (→ input_ids, target_ids via offset-by-1) |
|
|
||||||
| `sft` | `sequence`, `loss_mask` |
|
|
||||||
| `dpo` | `chosen`, `rejected`, `chosen_mask`, `rejected_mask` |
|
|
||||||
| `grpo` | `prompts`, `responses`, `masks`, `rewards` |
|
|
||||||
|
|
||||||
## Dataset Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
|
|
||||||
→ BaseDataset.load(load_path, storage_type=None)
|
|
||||||
→ detect_format(load_path)
|
|
||||||
→ StoreFactory.create(storage_type)
|
|
||||||
→ Store.load(load_path)
|
|
||||||
→ H5Store._normalize() / MmapStore._normalize()
|
|
||||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
|
||||||
→ BaseDataset.__getitem__(idx)
|
|
||||||
→ get_index(idx) → [begin, end)
|
|
||||||
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
|
|
||||||
```
|
|
||||||
|
|
||||||
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
|
|
||||||
|
|
||||||
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
|
|
||||||
|
|
||||||
## Sampler
|
|
||||||
|
|
||||||
`ResumableDistributedSampler` supports checkpoint-aware distributed sampling:
|
|
||||||
|
|
||||||
- Tracks `start_epoch` / `start_iter` for resume
|
|
||||||
- Shuffle via `torch.Generator(seed + epoch)`
|
|
||||||
- Per-replica index slicing for DDP
|
|
||||||
|
|
||||||
## DataLoader
|
|
||||||
|
|
||||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
|
||||||
|
|
@ -1,152 +0,0 @@
|
||||||
# Inference
|
|
||||||
|
|
||||||
## KV Cache
|
|
||||||
|
|
||||||
At decode time, only the last query token matters. All previous K/V are cached to avoid recomputation:
|
|
||||||
|
|
||||||
$$
|
|
||||||
o_n = \sum_j \text{softmax}\left(\frac{q_n k_j}{\sqrt{d_k}}\right) v_j
|
|
||||||
$$
|
|
||||||
|
|
||||||
RoPE is applied **before** KV cache write, not after — otherwise position encoding drift occurs.
|
|
||||||
|
|
||||||
## KVCache System
|
|
||||||
|
|
||||||
Six classes (plus two helpers) working together:
|
|
||||||
|
|
||||||
```
|
|
||||||
KVCache (facade)
|
|
||||||
├── PagePool orchestrates page allocation + prefix matching
|
|
||||||
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
|
|
||||||
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
|
|
||||||
├── TaskTable maps task_id → page_table + cached token count
|
|
||||||
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
|
||||||
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
|
|
||||||
```
|
|
||||||
|
|
||||||
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
|
||||||
|
|
||||||
## Continuous Batching
|
|
||||||
|
|
||||||
`InferenceScheduler` runs a daemon thread with a 4-phase loop:
|
|
||||||
|
|
||||||
```
|
|
||||||
1. Cleanup → Remove finished tasks, free KV pages
|
|
||||||
2. Refill → Pop from waiting_queue, task_alloc pages, activate
|
|
||||||
3. Prefill → Group by (prompt_len, start_pos), run full forward
|
|
||||||
4. Decode → Pick largest same-position group, single-token forward
|
|
||||||
```
|
|
||||||
|
|
||||||
## Sampling (Strategy Pattern)
|
|
||||||
|
|
||||||
```
|
|
||||||
BaseSamplingStrategy (ABC)
|
|
||||||
├── TemperatureStrategy
|
|
||||||
├── TopKStrategy
|
|
||||||
├── TopPStrategy
|
|
||||||
└── SamplingPipeline
|
|
||||||
```
|
|
||||||
|
|
||||||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
|
||||||
`sample()` is a convenience shortcut for one-shot usage.
|
|
||||||
|
|
||||||
## Protocol Handlers (Strategy Pattern)
|
|
||||||
|
|
||||||
```python
|
|
||||||
class ProtocolHandler: # concrete orchestrator
|
|
||||||
def __init__(self, request, engine, builder): ...
|
|
||||||
async def handle(self):
|
|
||||||
prompt, ctx, stops = builder.prepare(request, engine)
|
|
||||||
agen = engine.generate_async(prompt, ...)
|
|
||||||
if stream: self._handle_stream(agen, ctx, stops)
|
|
||||||
else: return await self._handle_non_stream(agen, ctx, stops)
|
|
||||||
```
|
|
||||||
|
|
||||||
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
|
||||||
|
|
||||||
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
|
||||||
|
|
||||||
Adding a protocol = one builder file, no handler subclassing needed.
|
|
||||||
|
|
||||||
## Engine & GenerateResult
|
|
||||||
|
|
||||||
```
|
|
||||||
InferenceEngine
|
|
||||||
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
|
||||||
├── generate_with_request(req) → same
|
|
||||||
├── generate_async(prompt, ...) → AsyncGenerator
|
|
||||||
├── get_stats() → Dict
|
|
||||||
└── shutdown()
|
|
||||||
```
|
|
||||||
|
|
||||||
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
|
||||||
|
|
||||||
## HTTP API
|
|
||||||
|
|
||||||
```
|
|
||||||
POST /v1/chat/completions OpenAI
|
|
||||||
POST /v1/messages Anthropic
|
|
||||||
GET /health {"status":"ok","model_loaded":true}
|
|
||||||
GET /stats scheduler statistics
|
|
||||||
```
|
|
||||||
|
|
||||||
### OpenAI
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
|
|
||||||
```
|
|
||||||
|
|
||||||
Response:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"id": "chatcmpl-abc123",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1717000000,
|
|
||||||
"model": "astrai",
|
|
||||||
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
|
||||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
|
|
||||||
|
|
||||||
### Anthropic
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl -X POST http://localhost:8000/v1/messages \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
|
|
||||||
```
|
|
||||||
|
|
||||||
Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
|
||||||
|
|
||||||
### GenerationRequest Parameters
|
|
||||||
|
|
||||||
| Param | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `messages` | List[dict] | required | Chat messages (role, content) |
|
|
||||||
| `top_k` | int | 50 | Top-k count |
|
|
||||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
|
||||||
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
|
|
||||||
| `max_tokens` | Optional[int] | None | Max generation length |
|
|
||||||
| `stream` | bool | False | Stream output |
|
|
||||||
|
|
||||||
## Engine API
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Non-streaming
|
|
||||||
engine.generate("Hello", stream=False) # -> str
|
|
||||||
engine.generate(["A", "B"], stream=False) # -> List[str]
|
|
||||||
|
|
||||||
# Streaming
|
|
||||||
engine.generate("Hello", stream=True) # -> Generator[str]
|
|
||||||
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
|
||||||
|
|
||||||
# Async
|
|
||||||
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
|
|
||||||
print(token)
|
|
||||||
```
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
|
||||||
|
|
@ -0,0 +1,89 @@
|
||||||
|
## 模型介绍
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### 1. 模型搭建
|
||||||
|
|
||||||
|
本模型采用Transformer架构, 使用GQA(q_head=24, kv_head=4) 机制,相较于传统的MHA可以节省KV cache 的显存占用(但是目前没有做KV cache),通过堆叠24层Transformer实现模型的搭建, 参数量为1.0b。Transformer 是自回归模型, 是通过计算前面所有的token的关系得到下一个token的概率分布
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
什么是自回归模型呢, 在把句子拆分成token之后, 模型会预测下一个token的概率分布。这意味着模型会根据给定的上下文(即已经出现的tokens序列),计算出下一个可能的token及其对应的概率。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### 1. 自回归
|
||||||
|
|
||||||
|
假设我们有一个句子被拆分成如下tokens列表:
|
||||||
|
|
||||||
|
```
|
||||||
|
["你好", ",", "今天", "天气"]
|
||||||
|
```
|
||||||
|
|
||||||
|
接下来,模型会基于这个序列预测下一个可能出现的token。这通常以概率分布的形式给出,比如:
|
||||||
|
|
||||||
|
```
|
||||||
|
-> {"token": "不错", "probability": 0.4}
|
||||||
|
-> {"token": "晴朗", "probability": 0.2}
|
||||||
|
-> ......
|
||||||
|
```
|
||||||
|
|
||||||
|
这里,“不错”和“晴朗”是两个可能跟随在“天气”之后的tokens,并且给出了每个token成为下一个token的可能性大小。
|
||||||
|
|
||||||
|
之后,我们通过采样(通过top_k, top_p, temperature参数调整采样后的结果)得到下一个token并且将下一个token加入序列作为输入
|
||||||
|
|
||||||
|
```
|
||||||
|
["你好", ",", "今天", "天气", "不错"]
|
||||||
|
```
|
||||||
|
|
||||||
|
之后都是在重复这个流程, 直到遇到控制流程结束的token(<|end_of_seqence|>)模型停止处理(一般模型都会设置控制token, 不然模型会一直输出到显存爆炸)。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### 2. 因果掩码
|
||||||
|
|
||||||
|
transformer 中采用注意力机制,输入的形状一般为[bsz, seq_len], 输出为[bsz, seq_len, n_dim]。为了实现预测下一个token, 模型的输入和预测的target必须错开一个位置, 在训练的时候我们也采用错开一个位置的方法
|
||||||
|
|
||||||
|
```
|
||||||
|
sequence : [[1, 2, 3, 4, 5, 6]]
|
||||||
|
input_ids: [[1, 2, 3, 4, 5]]
|
||||||
|
target_ids: [[2, 3, 4, 5, 6]]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
注意力得分计算的公式为
|
||||||
|
|
||||||
|
|
||||||
|
$$ e_{ij} = \frac{q_i^T k_j}{\sqrt{d_k}} + mask_{ij} $$
|
||||||
|
$$ s_{ij} = \text{softmax}_j(e_{ij}) $$
|
||||||
|
|
||||||
|
|
||||||
|
其中注意力得分代表了模型对两个token之间相似程度的关注程度
|
||||||
|
|
||||||
|
对于decoder only结构的模型, 为了防止模型从未来的位置偷到信息, 在注意力的计算过程中需要增加掩码,我们需要在注意力得分计算之前应用一个掩码。这个掩码通常是一个下三角矩阵,对于长度为n的序列,它的形状是[n, n]。下面以一个长度为5的序列为例,展示如何创建这样的因果掩码矩阵:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[0, -inf, -inf, -inf, -inf],
|
||||||
|
[0, 0, -inf, -inf, -inf],
|
||||||
|
[0, 0, 0, -inf, -inf],
|
||||||
|
[0, 0, 0, 0, -inf],
|
||||||
|
[0, 0, 0, 0, 0]]
|
||||||
|
```
|
||||||
|
|
||||||
|
在这个矩阵中,0表示可以注意到的位置,而-inf表示应该被掩盖(即不应注意到)的位置。因为这个矩阵保证了注意力得分中 $j > i$ 的部分通过softmax 之后由`-inf` 变成0, 也就是模型不能看到未来的信息
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### 3. 旋转位置编码
|
||||||
|
|
||||||
|
旋转位置编码(Rotary Position Embedding, RoPE)是一种为了解决Transformer模型中缺乏对序列位置信息直接建模的问题而设计的位置编码方法。与传统的位置编码(如正弦和余弦函数的位置编码)不同,RoPE通过将位置信息直接嵌入到查询(Query, Q)和键(Key, K)向量中来实现,使得模型能够更自然地处理序列中的相对位置关系。
|
||||||
|
|
||||||
|
|
||||||
|
$$ q_i = R_i W_q x_i $$
|
||||||
|
$$ k_j = R_j W_k x_j $$
|
||||||
|
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_i^T R_j W_k x_j = x_i^T W_q^T R_{j-i} W_k x_j $$
|
||||||
|
|
||||||
|
其中的 $R_{j-i}$ 只依赖于相对位置 $j - i$, 它控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
## kv_cache 实现
|
||||||
|
|
||||||
|
根据注意力的计算公式
|
||||||
|
|
||||||
|
$$
|
||||||
|
\begin{align*}
|
||||||
|
o_i &= \sum_j s_{ij} v_{j} \newline
|
||||||
|
s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
|
||||||
|
\end{align*}
|
||||||
|
$$
|
||||||
|
|
||||||
|
由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说下标 $i$ 是确定的, 是序列最后一个元素, 我们求的是 $o_n$
|
||||||
|
|
||||||
|
$$
|
||||||
|
\begin{align*}
|
||||||
|
o_n &= \sum_j s_{j}v_{j} \newline
|
||||||
|
s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
|
||||||
|
\end{align*}
|
||||||
|
$$
|
||||||
|
|
||||||
|
如果我们把式子展开
|
||||||
|
|
||||||
|
$$
|
||||||
|
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
|
||||||
|
$$
|
||||||
|
|
||||||
|
以上表达式只有 $k$ 和 $v$ 带有随序列长度变化的下标 $j$, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的, 即上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的。同时缓存的时候应该注意位置编码的计算应该在写入kvcache之前进行,否则会存在位置编码的计算错误
|
||||||
|
|
@ -1,100 +0,0 @@
|
||||||
# Parameter Documentation
|
|
||||||
|
|
||||||
## Training Parameters
|
|
||||||
|
|
||||||
### Basic Parameters
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
|
||||||
| `--param_path` | Model parameters or checkpoint path | required |
|
|
||||||
| `--n_epoch` | Total training epochs | 1 |
|
|
||||||
| `--batch_per_device` | Batch size per device | 1 |
|
|
||||||
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
|
||||||
|
|
||||||
### Learning Rate Scheduling
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
|
|
||||||
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
|
||||||
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
|
||||||
|
|
||||||
### Optimizer (AdamW)
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
|
||||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
|
||||||
|
|
||||||
### Data Loading
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--window_size` | Max input sequence length | model config `max_len` |
|
|
||||||
| `--stride` | Stride for sliding window over sequences | None |
|
|
||||||
| `--random_seed` | Random seed for reproducibility | 3407 |
|
|
||||||
| `--num_workers` | DataLoader worker processes | 4 |
|
|
||||||
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
|
|
||||||
|
|
||||||
### Checkpoint & Resume
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
|
|
||||||
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
|
||||||
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
|
|
||||||
| `--start_batch` | Resume from batch iteration | 0 |
|
|
||||||
|
|
||||||
### Distributed Training
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
|
||||||
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
|
||||||
| `--device_type` | Device type | cuda |
|
|
||||||
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
|
||||||
|
|
||||||
### Strategy-specific
|
|
||||||
|
|
||||||
| Parameter | Description | Default | Used by |
|
|
||||||
|-----------|-------------|---------|---------|
|
|
||||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
|
||||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
|
|
||||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
|
||||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
|
||||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
|
||||||
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
|
|
||||||
|
|
||||||
### Usage Example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
|
||||||
--nprocs=4 \
|
|
||||||
--parallel_mode=ddp \
|
|
||||||
--train_type=seq \
|
|
||||||
--data_root_path=/path/to/dataset \
|
|
||||||
--param_path=/path/to/model \
|
|
||||||
--batch_per_device=4 \
|
|
||||||
--grad_accum_steps=8 \
|
|
||||||
--warmup_ratio=0.05 \
|
|
||||||
--max_lr=1e-4 \
|
|
||||||
--max_grad_norm=1.0 \
|
|
||||||
--adamw_beta1=0.9 \
|
|
||||||
--adamw_beta2=0.95 \
|
|
||||||
--adamw_weight_decay=0.01 \
|
|
||||||
--window_size=2048 \
|
|
||||||
--ckpt_interval=10000 \
|
|
||||||
--ckpt_dir=./checkpoint \
|
|
||||||
--random_seed=3407 \
|
|
||||||
--label_smoothing=0.05 \
|
|
||||||
> out.log 2> err.log &
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-24
|
|
||||||
|
|
@ -1,283 +0,0 @@
|
||||||
# Preprocessing Pipeline
|
|
||||||
|
|
||||||
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
|
|
||||||
|
|
||||||
## Philosophy
|
|
||||||
|
|
||||||
| Component | Responsibility |
|
|
||||||
|-----------|---------------|
|
|
||||||
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
|
||||||
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
|
||||||
|
|
||||||
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### SFT Chat
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"version": 1,
|
|
||||||
"input": {
|
|
||||||
"type": "chat",
|
|
||||||
"messages_key": "messages"
|
|
||||||
},
|
|
||||||
"mask": {
|
|
||||||
"system": "mask",
|
|
||||||
"user": "mask",
|
|
||||||
"assistant": "train"
|
|
||||||
},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 2048,
|
|
||||||
"deduplicate": true
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"domain_key": "source",
|
|
||||||
"storage_format": "bin",
|
|
||||||
"max_tokens_per_shard": 100000000
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
|
|
||||||
|
|
||||||
### Instruction Tuning
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"version": 1,
|
|
||||||
"input": {
|
|
||||||
"type": "instruction",
|
|
||||||
"prompt_key": "instruction",
|
|
||||||
"response_key": "output"
|
|
||||||
},
|
|
||||||
"mask": {
|
|
||||||
"prompt": "mask",
|
|
||||||
"response": "train"
|
|
||||||
},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 2048
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"storage_format": "bin"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Mask splits at the prompt/response field boundary.
|
|
||||||
|
|
||||||
### Pretraining
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"version": 1,
|
|
||||||
"input": {
|
|
||||||
"type": "text",
|
|
||||||
"text_key": "content"
|
|
||||||
},
|
|
||||||
"mask": {},
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 2048,
|
|
||||||
"min_chars": 50
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"storage_format": "bin"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
No mask -- train on all tokens.
|
|
||||||
|
|
||||||
### Run
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
|
||||||
```
|
|
||||||
|
|
||||||
## Configuration Reference
|
|
||||||
|
|
||||||
### `input`
|
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
|
||||||
|-------|------|----------|---------|-------------|
|
|
||||||
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
|
|
||||||
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
|
|
||||||
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
|
|
||||||
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
|
|
||||||
| `text_key` | string | no | `"text"` | JSON key for text field |
|
|
||||||
|
|
||||||
### `mask`
|
|
||||||
|
|
||||||
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
|
|
||||||
|
|
||||||
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
|
|
||||||
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
|
|
||||||
|
|
||||||
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
|
|
||||||
For instruction mode, keys are `"prompt"` and `"response"`.
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `mask` | dict | `{}` | Role/field to action mapping |
|
|
||||||
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
|
|
||||||
|
|
||||||
### `preprocessing`
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
|
|
||||||
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
|
|
||||||
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
|
|
||||||
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
|
|
||||||
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
|
|
||||||
|
|
||||||
### `output`
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
|
|
||||||
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
|
|
||||||
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
|
|
||||||
|
|
||||||
## Mask Algorithm
|
|
||||||
|
|
||||||
### Chat Mode (role-span tracking)
|
|
||||||
|
|
||||||
For each message in the `messages` array:
|
|
||||||
|
|
||||||
1. Prepend BOS token (position 0, always masked)
|
|
||||||
2. Render through the chat template for that single message
|
|
||||||
3. Encode the rendered text, record token span `(start, end, role)`
|
|
||||||
4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries
|
|
||||||
5. Fill `loss_mask` from the mask rules
|
|
||||||
|
|
||||||
**Multi-turn example**:
|
|
||||||
|
|
||||||
```
|
|
||||||
Data:
|
|
||||||
[system: "You are helpful."]
|
|
||||||
[user: "What is 2+2?"]
|
|
||||||
[assistant: "4"]
|
|
||||||
[user: "What is 3+3?"]
|
|
||||||
[assistant: "6"]
|
|
||||||
|
|
||||||
Config:
|
|
||||||
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
|
|
||||||
|
|
||||||
Result:
|
|
||||||
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
|
|
||||||
mask: 0 0 0 1 0 1
|
|
||||||
```
|
|
||||||
|
|
||||||
Both assistant turns are trained. All system and user tokens are masked.
|
|
||||||
|
|
||||||
### Instruction Mode (field boundary)
|
|
||||||
|
|
||||||
Encode the prompt and response fields independently, then split the mask at the field boundary.
|
|
||||||
|
|
||||||
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
|
|
||||||
- `"prompt": "train", "response": "mask"` -- the reverse
|
|
||||||
|
|
||||||
### Text Mode (no mask)
|
|
||||||
|
|
||||||
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
|
|
||||||
|
|
||||||
## Output Layout
|
|
||||||
|
|
||||||
### Single-Shard (`bin`)
|
|
||||||
|
|
||||||
```
|
|
||||||
output_dir/
|
|
||||||
__default__/ # when domain_key is null
|
|
||||||
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
|
|
||||||
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
|
|
||||||
loss_mask.bin # int64 raw bytes
|
|
||||||
wiki/ # when domain_key="source" and item["source"]="wiki"
|
|
||||||
meta.json
|
|
||||||
sequence.bin
|
|
||||||
loss_mask.bin
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multi-Shard (`bin`)
|
|
||||||
|
|
||||||
When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories:
|
|
||||||
|
|
||||||
```
|
|
||||||
output_dir/
|
|
||||||
__default__/
|
|
||||||
shard_0000/
|
|
||||||
meta.json
|
|
||||||
sequence.bin
|
|
||||||
loss_mask.bin
|
|
||||||
shard_0001/
|
|
||||||
meta.json
|
|
||||||
sequence.bin
|
|
||||||
loss_mask.bin
|
|
||||||
```
|
|
||||||
|
|
||||||
`MmapStore` automatically discovers and merges all shards under the domain directory.
|
|
||||||
|
|
||||||
### H5 Output
|
|
||||||
|
|
||||||
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
|
|
||||||
|
|
||||||
```
|
|
||||||
output_dir/
|
|
||||||
__default__/
|
|
||||||
data_0000.h5 # each H5 contains key→dataset groups
|
|
||||||
data_0001.h5
|
|
||||||
wiki/
|
|
||||||
data_0000.h5
|
|
||||||
```
|
|
||||||
|
|
||||||
## Python API Usage
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
|
||||||
|
|
||||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
|
||||||
Pipeline(
|
|
||||||
config,
|
|
||||||
["data_part1.jsonl", "data_part2.jsonl"],
|
|
||||||
output_dir="output/",
|
|
||||||
tokenizer_path="params"
|
|
||||||
).run()
|
|
||||||
```
|
|
||||||
|
|
||||||
Or from the CLI:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
|
||||||
```
|
|
||||||
|
|
||||||
## Extension
|
|
||||||
|
|
||||||
Register a custom builder for new formats:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("my_format")
|
|
||||||
class MyFormatBuilder(BaseMaskBuilder):
|
|
||||||
def build(self, item: dict, config, tokenizer) -> dict | None:
|
|
||||||
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
|
|
||||||
# Return None to skip this item
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
Then set `"input": {"type": "my_format"}` in your config.
|
|
||||||
|
|
||||||
## Compared to Old Pipeline
|
|
||||||
|
|
||||||
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|
|
||||||
|---|---|
|
|
||||||
| Configured via constructor arguments | Configured via JSON file |
|
|
||||||
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
|
|
||||||
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
|
|
||||||
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
|
|
||||||
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
|
||||||
|
|
@ -1,201 +0,0 @@
|
||||||
# Training
|
|
||||||
|
|
||||||
### Autoregression
|
|
||||||
|
|
||||||
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
|
||||||
|
|
||||||
### Causal Mask
|
|
||||||
|
|
||||||
```
|
|
||||||
sequence : [[1, 2, 3, 4, 5, 6]]
|
|
||||||
input_ids: [[1, 2, 3, 4, 5]]
|
|
||||||
target_ids: [[2, 3, 4, 5, 6]]
|
|
||||||
```
|
|
||||||
|
|
||||||
A lower-triangular (causal) mask prevents each position from attending to future positions (`0` keeps a position, `-inf` blocks it):
|
|
||||||
|
|
||||||
```
|
|
||||||
[[0, -inf, -inf, -inf, -inf],
|
|
||||||
[0, 0, -inf, -inf, -inf],
|
|
||||||
[0, 0, 0, -inf, -inf],
|
|
||||||
[0, 0, 0, 0, -inf],
|
|
||||||
[0, 0, 0, 0, 0]]
|
|
||||||
```
|
|
||||||
|
|
||||||
### Rotary Position Embedding (RoPE)
|
|
||||||
|
|
||||||
RoPE embeds position into Q/K vectors via complex rotation:
|
|
||||||
|
|
||||||
$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_i^T R_j W_k x_j = x_i^T W_q^T R_{j-i} W_k x_j $$
|
|
||||||
|
|
||||||
The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` multiplies Q/K as complex numbers.
|
|
||||||
|
|
||||||
## Training Loop
|
|
||||||
|
|
||||||
Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches.
|
|
||||||
|
|
||||||
```
|
|
||||||
on_train_begin
|
|
||||||
model.train()
|
|
||||||
on_epoch_begin
|
|
||||||
for batch in dataloader:
|
|
||||||
on_batch_begin
|
|
||||||
with executor.accumulate(model):
|
|
||||||
loss = strategy.compute_loss(batch)
|
|
||||||
context.loss = loss.item()
|
|
||||||
stand_loss = loss / executor.grad_accum_steps
|
|
||||||
executor.backward(stand_loss)
|
|
||||||
context.iteration += 1
|
|
||||||
on_batch_end
|
|
||||||
|
|
||||||
if executor.sync_gradients:
|
|
||||||
on_optimizer_step
|
|
||||||
optimizer.step()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
if scheduler:
|
|
||||||
scheduler.step()
|
|
||||||
on_epoch_end
|
|
||||||
on_train_end
|
|
||||||
```
|
|
||||||
|
|
||||||
### Callback Lifecycle
|
|
||||||
|
|
||||||
| Hook | Fires | Default callback |
|
|
||||||
|------|-------|-----------------|
|
|
||||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
|
||||||
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
|
|
||||||
| `on_batch_begin` | Every batch | — |
|
|
||||||
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
|
||||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
|
||||||
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
|
|
||||||
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
|
|
||||||
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
|
|
||||||
|
|
||||||
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
|
||||||
|
|
||||||
## Strategies
|
|
||||||
|
|
||||||
### SEQ (Pre-training)
|
|
||||||
|
|
||||||
Next-token cross-entropy with optional label smoothing:
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
|
||||||
$$
|
|
||||||
|
|
||||||
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
|
|
||||||
|
|
||||||
### SFT (Supervised Fine-Tuning)
|
|
||||||
|
|
||||||
Masked cross-entropy (`ignore_index=-100`) over response tokens:
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
|
||||||
$$
|
|
||||||
|
|
||||||
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
|
|
||||||
|
|
||||||
### DPO (Direct Preference Optimization)
|
|
||||||
|
|
||||||
Frozen reference model, preference margin via log-ratio:
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
|
||||||
$$
|
|
||||||
|
|
||||||
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
|
||||||
|
|
||||||
### GRPO (Group Relative Policy Optimization)
|
|
||||||
|
|
||||||
On-policy PPO with group-normalized advantages:
|
|
||||||
|
|
||||||
$$
|
|
||||||
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
|
|
||||||
$$
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
|
||||||
$$
|
|
||||||
|
|
||||||
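Parameters: `group_size=4`, `clip_eps=0.2` (the $\epsilon$ inside the clip term of the loss, not the small stabilizer $\epsilon$ in the advantage denominator), `kl_coef=0.01` (the weight $\lambda$), `sync_interval=200`, `reduction="mean"`.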
|
|
||||||
|
|
||||||
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
|
||||||
|
|
||||||
## LR Schedulers
|
|
||||||
|
|
||||||
| Type | Class | Description |
|
|
||||||
|------|-------|-------------|
|
|
||||||
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
|
||||||
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
|
||||||
|
|
||||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
|
||||||
|
|
||||||
## Gradient Checkpointing
|
|
||||||
|
|
||||||
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
|
||||||
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
|
|
||||||
```
|
|
||||||
|
|
||||||
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
|
|
||||||
|
|
||||||
## Checkpoint
|
|
||||||
|
|
||||||
```
|
|
||||||
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
|
||||||
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
|
||||||
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
|
|
||||||
```
|
|
||||||
|
|
||||||
Optimizer and scheduler state are persisted by default via `Checkpoint.extra`.
|
|
||||||
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
|
|
||||||
|
|
||||||
## TrainContextBuilder (Builder Pattern)
|
|
||||||
|
|
||||||
```python
|
|
||||||
context = (
|
|
||||||
TrainContextBuilder(config)
|
|
||||||
.with_resume_dir(resume_dir)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
|
||||||
```
|
|
||||||
|
|
||||||
- Loads checkpoint weights if a resume directory is provided (via `with_resume_dir`)
|
|
||||||
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
|
||||||
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
|
||||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
|
||||||
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
|
|
||||||
|
|
||||||
## Training CLI
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
|
||||||
--nprocs=4 \
|
|
||||||
--parallel_mode=ddp \
|
|
||||||
--train_type=seq \
|
|
||||||
--data_root_path=/path/to/dataset \
|
|
||||||
--param_path=/path/to/model \
|
|
||||||
--batch_per_device=4 \
|
|
||||||
--grad_accum_steps=8 \
|
|
||||||
--warmup_ratio=0.05 \
|
|
||||||
--max_lr=1e-4 \
|
|
||||||
--max_grad_norm=1.0 \
|
|
||||||
--adamw_beta1=0.9 \
|
|
||||||
--adamw_beta2=0.95 \
|
|
||||||
--adamw_weight_decay=0.01 \
|
|
||||||
--window_size=2048 \
|
|
||||||
--ckpt_interval=10000 \
|
|
||||||
--ckpt_dir=./checkpoint \
|
|
||||||
--random_seed=3407 \
|
|
||||||
--label_smoothing=0.05 \
|
|
||||||
> out.log 2> err.log &
|
|
||||||
```
|
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 281 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 590 KiB |
|
|
@ -1,34 +0,0 @@
|
||||||
__version__ = "1.3.7"
|
|
||||||
__author__ = "ViperEkura"
|
|
||||||
|
|
||||||
from astrai.config import (
|
|
||||||
AutoRegressiveLMConfig,
|
|
||||||
EncoderConfig,
|
|
||||||
TrainConfig,
|
|
||||||
)
|
|
||||||
from astrai.dataset import DatasetFactory
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
from astrai.inference import (
|
|
||||||
GenerationRequest,
|
|
||||||
InferenceEngine,
|
|
||||||
)
|
|
||||||
from astrai.model import AutoModel, AutoRegressiveLM
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AutoRegressiveLM",
|
|
||||||
"AutoRegressiveLMConfig",
|
|
||||||
"EncoderConfig",
|
|
||||||
"TrainConfig",
|
|
||||||
"DatasetFactory",
|
|
||||||
"AutoTokenizer",
|
|
||||||
"GenerationRequest",
|
|
||||||
"InferenceEngine",
|
|
||||||
"Trainer",
|
|
||||||
"CallbackFactory",
|
|
||||||
"StrategyFactory",
|
|
||||||
"SchedulerFactory",
|
|
||||||
"BaseFactory",
|
|
||||||
"AutoModel",
|
|
||||||
]
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
from astrai.config.model_config import (
|
|
||||||
AutoRegressiveLMConfig,
|
|
||||||
BaseModelConfig,
|
|
||||||
ConfigFactory,
|
|
||||||
EncoderConfig,
|
|
||||||
)
|
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
OutputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
ProcessingConfig,
|
|
||||||
)
|
|
||||||
from astrai.config.train_config import TrainConfig
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseModelConfig",
|
|
||||||
"AutoRegressiveLMConfig",
|
|
||||||
"EncoderConfig",
|
|
||||||
"ConfigFactory",
|
|
||||||
"TrainConfig",
|
|
||||||
"InputConfig",
|
|
||||||
"OutputConfig",
|
|
||||||
"PipelineConfig",
|
|
||||||
"ProcessingConfig",
|
|
||||||
]
|
|
||||||
|
|
@ -1,98 +0,0 @@
|
||||||
import json
|
|
||||||
from dataclasses import MISSING, dataclass, fields
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Optional, Self, Union, get_type_hints
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseConfig:
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
d = {}
|
|
||||||
for fld in fields(self):
|
|
||||||
v = getattr(self, fld.name)
|
|
||||||
if isinstance(v, (str, int, float, bool)):
|
|
||||||
d[fld.name] = v
|
|
||||||
elif v is None:
|
|
||||||
d[fld.name] = None
|
|
||||||
elif isinstance(v, (dict, list, tuple)):
|
|
||||||
try:
|
|
||||||
val = list(v) if isinstance(v, tuple) else v
|
|
||||||
json.dumps(val)
|
|
||||||
d[fld.name] = val
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
pass
|
|
||||||
elif isinstance(v, BaseConfig):
|
|
||||||
d[fld.name] = v.to_dict()
|
|
||||||
elif hasattr(v, "__dataclass_fields__"):
|
|
||||||
sub = {}
|
|
||||||
for f in fields(v):
|
|
||||||
a = getattr(v, f.name)
|
|
||||||
sub[f.name] = list(a) if isinstance(a, tuple) else a
|
|
||||||
d[fld.name] = sub
|
|
||||||
return d
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, d: Dict[str, Any]) -> Self:
|
|
||||||
hints = get_type_hints(cls)
|
|
||||||
inst = cls.__new__(cls)
|
|
||||||
for fld in fields(cls):
|
|
||||||
if fld.name in d:
|
|
||||||
v = d[fld.name]
|
|
||||||
target = cls._unwrap_optional(hints.get(fld.name))
|
|
||||||
if target is not None:
|
|
||||||
try:
|
|
||||||
v = cls._coerce(v, target)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
pass
|
|
||||||
object.__setattr__(inst, fld.name, v)
|
|
||||||
elif fld.default is not MISSING:
|
|
||||||
object.__setattr__(inst, fld.name, fld.default)
|
|
||||||
elif fld.default_factory is not MISSING:
|
|
||||||
object.__setattr__(inst, fld.name, fld.default_factory())
|
|
||||||
else:
|
|
||||||
object.__setattr__(inst, fld.name, None)
|
|
||||||
return inst
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unwrap_optional(tp) -> Optional[type]:
|
|
||||||
if tp is None:
|
|
||||||
return None
|
|
||||||
origin = getattr(tp, "__origin__", None)
|
|
||||||
if origin is not None:
|
|
||||||
args = getattr(tp, "__args__", ())
|
|
||||||
non_none = [a for a in args if a is not type(None)]
|
|
||||||
return non_none[0] if non_none else None
|
|
||||||
return tp
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _coerce(value: Any, target_type: type) -> Any:
|
|
||||||
if target_type is bool and isinstance(value, bool):
|
|
||||||
return value
|
|
||||||
if (
|
|
||||||
target_type is int
|
|
||||||
and isinstance(value, (int, float))
|
|
||||||
and not isinstance(value, bool)
|
|
||||||
):
|
|
||||||
return int(value)
|
|
||||||
if (
|
|
||||||
target_type is float
|
|
||||||
and isinstance(value, (int, float))
|
|
||||||
and not isinstance(value, bool)
|
|
||||||
):
|
|
||||||
return float(value)
|
|
||||||
if target_type is str and isinstance(value, str):
|
|
||||||
return value
|
|
||||||
if isinstance(value, target_type):
|
|
||||||
return value
|
|
||||||
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
|
|
||||||
return target_type.from_dict(value)
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_json(cls, path: Union[str, Path]) -> Self:
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
|
||||||
return cls.from_dict(json.load(f))
|
|
||||||
|
|
||||||
def to_json(self, path: Union[str, Path]):
|
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
@ -1,92 +0,0 @@
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, Optional, Self
|
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigFactory(BaseFactory[BaseConfig]):
    """Registry that resolves a config class from a ``model_type`` string."""

    @classmethod
    def load(cls, raw: Dict[str, Any]) -> BaseConfig:
        """Instantiate the config class registered for ``raw["model_type"]``.

        A missing or falsy ``model_type`` falls back to
        ``"autoregressive_lm"``.
        """
        model_type = raw.get("model_type")
        if not model_type:
            model_type = "autoregressive_lm"
        return cls.get_component_class(model_type).from_dict(raw)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class BaseModelConfig(BaseConfig):
    """Base model config carrying ``model_type`` plus JSON file I/O."""

    # Discriminator string; ``None`` when unspecified.
    model_type: Optional[str] = None

    @classmethod
    def from_file(cls, config_path: str) -> Self:
        """Load a config from the JSON file at ``config_path``.

        The file is read as UTF-8 (matching ``BaseConfig.from_json``) rather
        than with the platform's locale-dependent default encoding.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            raw: Dict[str, Any] = json.load(f)
        return cls.from_dict(raw)

    def to_file(self, config_path: str):
        """Write this config to ``config_path`` as JSON, omitting ``None`` fields."""
        d = self.to_dict()
        config_dict = {k: v for k, v in d.items() if v is not None}
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
@ConfigFactory.register("autoregressive_lm")
class AutoRegressiveLMConfig(BaseModelConfig):
    """Autoregressive language-model configuration.

    Registered with ``ConfigFactory`` under ``"autoregressive_lm"``.  Every
    field defaults to ``None`` except ``attn_type`` (``"gqa"``) and
    ``ffn_type`` (``"mlp"``).
    """

    vocab_size: Optional[int] = None
    dim: Optional[int] = None
    n_layers: Optional[int] = None
    norm_eps: Optional[float] = None
    dim_ffn: Optional[int] = None
    tie_weight: Optional[bool] = None

    max_len: Optional[int] = None
    rope_theta: Optional[float] = None
    rope_scaling: Optional[dict] = None

    attn_type: str = "gqa"
    n_heads: Optional[int] = None
    n_kv_heads: Optional[int] = None
    use_qk_norm: Optional[bool] = None
    use_gated_attention: Optional[bool] = None

    kv_lora_rank: Optional[int] = None
    qk_nope_head_dim: Optional[int] = None
    qk_rope_head_dim: Optional[int] = None

    ffn_type: str = "mlp"
    n_routed_experts: Optional[int] = None
    n_shared_experts: Optional[int] = None
    n_activated_experts: Optional[int] = None
    topk_method: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
@ConfigFactory.register("embedding")
class EncoderConfig(BaseModelConfig):
    """Embedding-encoder model configuration.

    Registered with ``ConfigFactory`` under ``"embedding"``.  Every field
    defaults to ``None``.
    """

    vocab_size: Optional[int] = None
    dim: Optional[int] = None
    n_layers: Optional[int] = None
    norm_eps: Optional[float] = None
    dim_ffn: Optional[int] = None

    max_len: Optional[int] = None
    rope_theta: Optional[float] = None
    rope_scaling: Optional[dict] = None

    n_heads: Optional[int] = None
    n_kv_heads: Optional[int] = None
    use_qk_norm: Optional[bool] = None
    use_gated_attention: Optional[bool] = None

    pooling_type: Optional[str] = None
    normalize_embeddings: Optional[bool] = None
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
||||||
"""Pipeline configuration for JSONL preprocessing."""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class InputConfig(BaseConfig):
    """Input section of the preprocessing pipeline config."""

    # Optional list of section dictionaries; ``None`` when not configured.
    sections: Optional[List[Dict]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ProcessingConfig(BaseConfig):
    """Processing limits of the preprocessing pipeline config."""

    max_seq_len: int = 2048
    min_chars: int = 50
    max_chars: int = 2_000_000
    # ``None`` by default; NOTE(review): presumably means "no item cap" - confirm.
    max_items: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class OutputConfig(BaseConfig):
    """Output section of the preprocessing pipeline config."""

    domain_key: Optional[str] = None
    storage_format: str = "bin"
    max_tokens_per_shard: int = 100_000_000
    # A fresh empty dict per instance (mutable default via factory).
    dtype: Dict[str, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class PipelineConfig(BaseConfig):
    """Top-level preprocessing pipeline config.

    Aggregates the ``input``, ``preprocessing`` and ``output`` sub-configs
    together with the mask rules; each sub-config and the ``mask`` mapping
    is created fresh per instance through a default factory.
    """

    version: int = 1
    input: InputConfig = field(default_factory=InputConfig)
    mask: Dict[str, str] = field(default_factory=dict)
    mask_default: str = "mask"
    preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
    output: OutputConfig = field(default_factory=OutputConfig)
|
|
||||||
|
|
@ -1,140 +0,0 @@
|
||||||
from dataclasses import dataclass, field, fields
|
|
||||||
from typing import Callable, List, Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
|
||||||
from astrai.model.components.lora import LoRAConfig
|
|
||||||
|
|
||||||
|
|
||||||
def required(**kw):
    """Return field metadata flagged as required, merged with ``kw``.

    ``"required": True`` is inserted first; keys in ``kw`` are applied
    afterwards and therefore win on conflict.
    """
    meta = {"required": True}
    meta.update(kw)
    return meta
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class TrainConfig(BaseConfig):
    """Training configuration.

    Fields whose metadata is built with ``required(...)`` default to ``None``
    and are rejected by ``validate`` (run from ``__post_init__``) while they
    are still ``None``.
    """

    # basic setting
    model_fn: Callable[[], nn.Module] = field(
        default=None,
        metadata=required(help="Model factory for training."),
    )
    strategy: str = field(
        default=None,
        metadata=required(help="Training strategy."),
    )
    dataset: Dataset = field(
        default=None,
        metadata=required(help="Dataset for training."),
    )
    optimizer_fn: Callable[[nn.Module], Optimizer] = field(
        default=None,
        metadata=required(help="Optimizer factory for training."),
    )
    scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
        default=None,
        metadata=required(help="Scheduler factory for training."),
    )
    n_epoch: int = field(
        default=1,
        metadata={"help": "Number of epochs for training."},
    )
    batch_per_device: int = field(
        default=4,
        metadata={"help": "Batch size per device."},
    )
    grad_accum_steps: int = field(
        default=1,
        metadata={"help": "Number of iterations between steps."},
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Maximum gradient norm."},
    )
    gradient_checkpointing_modules: list = field(
        default_factory=list,
        metadata={"help": "Module types to enable activation checkpointing for."},
    )

    # checkpoint setting
    start_epoch: int = field(
        default=0,
        metadata={"help": "Start epoch for training."},
    )
    start_batch: int = field(
        default=0,
        metadata={"help": "Start batch iteration for training."},
    )
    ckpt_dir: str = field(
        default="./checkpoint",
        metadata={"help": "Checkpoint directory."},
    )
    ckpt_interval: int = field(
        default=5000,
        metadata={"help": "Number of iterations between checkpoints."},
    )

    # lora setting
    lora: Optional[LoRAConfig] = field(
        default=None,
        metadata={"help": "LoRA config. None means full fine-tuning."},
    )

    # metric setting
    log_dir: str = field(
        default="./checkpoint/logs",
        metadata={"help": "Directory for metric logs."},
    )
    log_interval: int = field(
        default=100,
        metadata={"help": "Number of batch iterations between metric logs."},
    )
    metrics: List[str] = field(
        default_factory=lambda: ["loss", "lr"],
        metadata={"help": "Metrics to record during training."},
    )

    # dataloader setting
    random_seed: int = field(
        default=3407,
        metadata={"help": "Random seed."},
    )
    num_workers: int = field(
        default=0,
        metadata={"help": "Number of workers for dataloader."},
    )
    prefetch_factor: Optional[int] = field(
        default=None,
        metadata={"help": "Prefetch factor for dataloader."},
    )
    pin_memory: bool = field(
        default=False,
        metadata={"help": "Pin memory for dataloader."},
    )

    # distributed training
    nprocs: int = field(
        default=1,
        metadata={"help": "Number of processes for distributed training."},
    )
    backend: str = field(
        default="nccl",
        metadata={"help": "Distributed training backend."},
    )
    master_addr: str = field(
        default="localhost",
        metadata={"help": "Master address for distributed training."},
    )
    master_port: str = field(
        default="29500",
        metadata={"help": "Master port for distributed training."},
    )
    parallel_mode: str = field(
        default="none",
        metadata={"help": "Parallel strategy: none, ddp, fsdp."},
    )
    start_method: str = field(
        default="spawn",
        metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
    )

    # others
    device_type: str = field(
        default="cuda",
        metadata={"help": "Device type for distributed training."},
    )
    val_dataset: Optional[Dataset] = field(
        default=None,
        metadata={"help": "Dataset for validation."},
    )
    val_step: int = field(
        default=1000,
        metadata={"help": "Number of optimizer steps between validation runs."},
    )

    executor_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
    )
    extra_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Other arguments."},
    )

    def __post_init__(self):
        """Validate the config as soon as the dataclass is constructed."""
        self.validate()

    def validate(self):
        """Raise ``ValueError`` for the first required field that is ``None``."""
        for fld in fields(self):
            if not fld.metadata.get("required"):
                continue
            if getattr(self, fld.name) is None:
                raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
|
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
from astrai.dataset.dataset import (
|
|
||||||
BaseDataset,
|
|
||||||
DatasetFactory,
|
|
||||||
)
|
|
||||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
|
||||||
from astrai.dataset.storage import (
|
|
||||||
H5Store,
|
|
||||||
MmapStore,
|
|
||||||
Store,
|
|
||||||
StoreFactory,
|
|
||||||
detect_format,
|
|
||||||
load_bin,
|
|
||||||
load_h5,
|
|
||||||
save_bin,
|
|
||||||
save_h5,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseDataset",
|
|
||||||
"DatasetFactory",
|
|
||||||
"Store",
|
|
||||||
"StoreFactory",
|
|
||||||
"H5Store",
|
|
||||||
"MmapStore",
|
|
||||||
"detect_format",
|
|
||||||
"save_h5",
|
|
||||||
"load_h5",
|
|
||||||
"save_bin",
|
|
||||||
"load_bin",
|
|
||||||
"ResumableDistributedSampler",
|
|
||||||
]
|
|
||||||
|
|
@ -1,308 +0,0 @@
|
||||||
"""Dataset implementations with factory pattern for training."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
from astrai.dataset.storage import (
|
|
||||||
Store,
|
|
||||||
StoreFactory,
|
|
||||||
detect_format,
|
|
||||||
)
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
|
||||||
"""Abstract base class for all dataset types.
|
|
||||||
|
|
||||||
Implements common functionality for window-based data fetching.
|
|
||||||
Uses a storage abstraction for format-agnostic data loading.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__()
|
|
||||||
self.window_size = window_size
|
|
||||||
self.stride = stride
|
|
||||||
self.storage: Optional[Store] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def required_keys(self) -> List[str]:
|
|
||||||
"""Return required storage keys for this dataset type.
|
|
||||||
|
|
||||||
Subclasses should override to specify expected keys.
|
|
||||||
"""
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _validate_keys(self):
|
|
||||||
if not self.required_keys:
|
|
||||||
return
|
|
||||||
actual_keys = set(self.storage.keys)
|
|
||||||
missing = [k for k in self.required_keys if k not in actual_keys]
|
|
||||||
if missing:
|
|
||||||
raise KeyError(
|
|
||||||
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
|
|
||||||
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
|
|
||||||
f"Missing: {missing}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, load_path: str, storage_type: Optional[str] = None):
|
|
||||||
"""Load dataset from the given path.
|
|
||||||
|
|
||||||
Auto-detects the storage format if not specified.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
load_path: Path to the data directory or file
|
|
||||||
storage_type: Force a specific storage type ("h5", "bin"),
|
|
||||||
or None for auto-detection
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If the loaded storage is missing required keys.
|
|
||||||
"""
|
|
||||||
if storage_type is None:
|
|
||||||
storage_type = detect_format(load_path)
|
|
||||||
self.storage = StoreFactory.create(storage_type)
|
|
||||||
self._load_path = load_path
|
|
||||||
self.storage.load(load_path)
|
|
||||||
self._validate_keys()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def count(self) -> int:
|
|
||||||
"""Return the total number of raw elements (tokens) in the dataset."""
|
|
||||||
if self.storage is None:
|
|
||||||
return 0
|
|
||||||
return len(self.storage)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def keys(self) -> List[str]:
|
|
||||||
"""Return the available data keys."""
|
|
||||||
if self.storage is None:
|
|
||||||
return []
|
|
||||||
return self.storage.keys
|
|
||||||
|
|
||||||
def get_index(self, index: int) -> tuple:
|
|
||||||
"""Calculate begin and end indices for a sample.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index: Sample index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (begin_idx, end_idx)
|
|
||||||
"""
|
|
||||||
if self.storage is None:
|
|
||||||
raise RuntimeError("Dataset not loaded, call load() first")
|
|
||||||
total = len(self.storage)
|
|
||||||
if total <= self.window_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"Data too short: {total} tokens <= window_size {self.window_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
begin_idx = min(index * self.stride, total - 1 - self.window_size)
|
|
||||||
end_idx = min(begin_idx + self.window_size, total - 1)
|
|
||||||
|
|
||||||
return begin_idx, end_idx
|
|
||||||
|
|
||||||
@abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]:
    """Return the sample at ``index``.

    Abstract hook: every concrete dataset must override this method.
    """
    raise NotImplementedError()
|
|
||||||
|
|
||||||
def __len__(self) -> int:
    """Number of samples available: 0 when unloaded or when the data is too short."""
    store = self.storage
    if store is None:
        return 0
    # Highest admissible begin index; negative means not even one window fits.
    max_begin = len(store) - 1 - self.window_size
    if max_begin < 0:
        return 0
    return max_begin // self.stride + 1
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetFactory(BaseFactory["BaseDataset"]):
    """Factory that builds dataset instances from a registered type name.

    Dataset classes register themselves through the ``register`` decorator;
    the default types (seq, sft, dpo, grpo) are registered when their
    decorated classes are defined.

    Example usage:
        @DatasetFactory.register("custom")
        class CustomDataset(BaseDataset):
            ...

        dataset = DatasetFactory.create("custom", window_size, stride)
    """

    @classmethod
    def _validate_component(cls, dataset_cls: type):
        """Reject classes that do not derive from BaseDataset."""
        if issubclass(dataset_cls, BaseDataset):
            return
        raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")

    @classmethod
    def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
        """Instantiate the dataset registered under ``train_type``.

        Args:
            train_type: Type of training ("seq", "sft", "dpo", "grpo")
            window_size: Window size for data sampling
            stride: Stride between consecutive samples

        Returns:
            Dataset instance
        """
        return super().create(train_type, window_size, stride)

    @classmethod
    def load(
        cls,
        train_type: str,
        load_path: str,
        window_size: int,
        stride: Optional[int] = None,
        storage_type: Optional[str] = None,
    ) -> "BaseDataset":
        """Build a dataset and load its data in a single call.

        Args:
            train_type: Type of training dataset
            load_path: Path to the data file
            window_size: Window size for data sampling
            stride: Stride between consecutive samples (default: same as window_size)
            storage_type: Storage type ("h5", "bin") or None for auto-detection

        Returns:
            Loaded dataset instance
        """
        effective_stride = window_size if stride is None else stride
        dataset = cls.create(train_type, window_size, effective_stride)
        dataset.load(load_path, storage_type=storage_type)
        return dataset

    @classmethod
    def available_types(cls) -> list:
        """Names of all registered dataset types."""
        return cls.list_registered()
|
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("seq")
class SEQDataset(BaseDataset):
    """Next-token prediction dataset: targets are the inputs shifted by one position."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)

    @property
    def required_keys(self) -> List[str]:
        """Storage keys this dataset reads."""
        return ["sequence"]

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Read ``[begin_idx, end_idx)`` of the "sequence" key from storage."""
        return self.storage.fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        """Return ``input_ids`` and the one-step-shifted ``target_ids`` as long tensors."""
        start, stop = self.get_index(index)
        inputs = self._fetch_data(start, stop).to(dtype=torch.long)
        targets = self._fetch_data(start + 1, stop + 1).to(dtype=torch.long)
        return {"input_ids": inputs, "target_ids": targets}
|
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("sft")
class SFTDataset(BaseDataset):
    """Supervised fine-tuning dataset that also yields a loss mask for the targets."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)

    @property
    def required_keys(self) -> List[str]:
        """Storage keys this dataset reads."""
        return ["sequence", "loss_mask"]

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` of ``key`` from storage."""
        return self.storage.fetch(begin_idx, end_idx, key)

    def __getitem__(self, index):
        """Return inputs, shifted targets and the mask aligned with the targets."""
        start, stop = self.get_index(index)
        shifted = (start + 1, stop + 1)

        inputs = self._fetch_data(start, stop, "sequence").to(dtype=torch.long)
        targets = self._fetch_data(*shifted, "sequence").to(dtype=torch.long)
        # The mask is read over the same shifted range as the targets.
        mask = self._fetch_data(*shifted, "loss_mask").to(dtype=torch.bool)

        return {"input_ids": inputs, "target_ids": targets, "loss_mask": mask}
|
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("dpo")
class DPODataset(BaseDataset):
    """Direct Preference Optimization dataset yielding chosen/rejected pairs with masks."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)

    @property
    def required_keys(self) -> List[str]:
        """Storage keys this dataset reads."""
        return ["chosen", "rejected", "chosen_mask", "rejected_mask"]

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` of ``key`` from storage."""
        return self.storage.fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        """Return the four DPO tensors for sample ``index``, all over the same range."""
        start, stop = self.get_index(index)
        # Insertion order fixes both the fetch order and the output key order.
        target_dtypes = {
            "chosen": torch.long,
            "rejected": torch.long,
            "chosen_mask": torch.bool,
            "rejected_mask": torch.bool,
        }
        return {
            key: self._fetch_data(start, stop, key).to(dtype=dtype)
            for key, dtype in target_dtypes.items()
        }
|
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("grpo")
class GRPODataset(BaseDataset):
    """Group Relative Policy Optimization dataset: prompts, responses, masks, rewards."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)

    @property
    def required_keys(self) -> List[str]:
        """Storage keys this dataset reads."""
        return ["prompts", "responses", "masks", "rewards"]

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` of ``key`` from storage."""
        return self.storage.fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the GRPO tensors for sample ``index``.

        Rewards are returned with whatever dtype the storage holds; the other
        entries are cast to long (tokens) or bool (masks).
        """
        start, stop = self.get_index(index)
        return {
            "prompts": self._fetch_data(start, stop, "prompts").to(dtype=torch.long),
            "responses": self._fetch_data(start, stop, "responses").to(
                dtype=torch.long
            ),
            "masks": self._fetch_data(start, stop, "masks").to(dtype=torch.bool),
            "rewards": self._fetch_data(start, stop, "rewards"),
        }
|
|
||||||
|
|
@ -1,264 +0,0 @@
|
||||||
"""Storage backends for different data formats.
|
|
||||||
|
|
||||||
Layers:
|
|
||||||
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
|
|
||||||
return Dict[str, List[Tensor]] — format-specific, no state
|
|
||||||
- Store (ABC): central abstraction, normalizes multi-segment into
|
|
||||||
Dict[str, List[Tensor]] per key via _normalize(),
|
|
||||||
fetch() uses bisect across segments — no forced concat
|
|
||||||
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
|
||||||
|
|
||||||
Key properties:
|
|
||||||
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
|
||||||
datasets larger than RAM
|
|
||||||
- Explicit length: _length = min(total elements across keys), set at load,
|
|
||||||
__len__ returns O(1)
|
|
||||||
- Zero-copy mmap: MmapStore wraps np.memmap (see load_bin), all DataLoader
|
|
||||||
workers share OS page-cache pages
|
|
||||||
"""
|
|
||||||
|
|
||||||
import bisect
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Union
|
|
||||||
|
|
||||||
import h5py
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write ``tensor_group`` to ``<file_path>/<file_name>.h5``.

    Each key becomes an HDF5 group; its tensors are stored inside it as
    datasets named ``data_0``, ``data_1``, ... in list order. The target
    directory is created when missing and an existing file is overwritten.
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(target, "w") as h5:
        for key, tensors in tensor_group.items():
            group = h5.create_group(key)
            for position, tensor in enumerate(tensors):
                group.create_dataset(f"data_{position}", data=tensor.cpu().numpy())
|
|
||||||
|
|
||||||
|
|
||||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Read every HDF5 file at ``file_path`` into per-key tensor segments.

    Args:
        file_path: A single ``.h5``/``.hdf5`` file, or a directory that is
            searched recursively for such files.
        share_memory: When true, each tensor is moved to shared memory via
            ``Tensor.share_memory_()``.

    Returns:
        Mapping from top-level group name to the list of tensors read from
        that group, accumulated across all files.
    """

    def _segment_order(name: str):
        # h5py iterates group members in name order, which puts "data_10"
        # before "data_2". Sort on the numeric suffix so segments come back
        # in the order save_h5 wrote them; other names keep name order.
        suffix = name.rpartition("_")[2]
        if suffix.isdecimal():
            return (0, int(suffix), name)
        return (1, 0, name)

    root_path = Path(file_path)
    if root_path.is_file():
        # rglob() on a file yields nothing, so a direct file path would
        # otherwise load as an empty dataset.
        h5_files = [root_path]
    else:
        # Sorted so the segment order does not depend on directory listing order.
        h5_files = sorted(
            list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
        )

    tensor_group: Dict[str, List[Tensor]] = {}
    for h5_file in h5_files:
        with h5py.File(h5_file, "r") as f:
            for key in f.keys():
                grp = f[key]
                segments = tensor_group.setdefault(key, [])
                for dset_name in sorted(grp.keys(), key=_segment_order):
                    # dset[:] copies the dataset into memory before the file closes.
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    segments.append(tensor)

    return tensor_group
|
|
||||||
|
|
||||||
|
|
||||||
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    """Write ``tensor_group`` as raw ``<key>.bin`` files plus a ``meta.json`` index.

    The tensors of each key are concatenated along dim 0 and dumped as one
    raw array; ``meta.json`` records the resulting shape and dtype name per
    key. The target directory is created when missing.
    """
    os.makedirs(file_path, exist_ok=True)

    meta = {}
    for key, tensors in tensor_group.items():
        merged = torch.cat(tensors, dim=0)
        dtype_name = str(merged.dtype).split(".")[-1]  # "torch.int64" -> "int64"
        meta[key] = {"shape": list(merged.shape), "dtype": dtype_name}
        merged.cpu().numpy().tofile(os.path.join(file_path, f"{key}.bin"))

    meta_file = os.path.join(file_path, "meta.json")
    with open(meta_file, "w") as f:
        json.dump(meta, f)
|
|
||||||
|
|
||||||
|
|
||||||
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Memory-map the ``<key>.bin`` files listed in ``<file_path>/meta.json``.

    Args:
        file_path: Directory containing ``meta.json`` and one ``.bin`` file
            per key.

    Returns:
        Mapping from key to a one-element list holding the tensor that wraps
        the memory-mapped array for that key.
    """
    with open(os.path.join(file_path, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)

    segments: Dict[str, List[Tensor]] = {}
    for key, info in meta.items():
        # Copy-on-write mapping: the file is opened read-only, so read-only
        # datasets load fine and in-place tensor ops can never alter the
        # data on disk. The array is still writable in memory, which
        # torch.from_numpy expects.
        arr = np.memmap(
            os.path.join(file_path, f"{key}.bin"),
            dtype=info["dtype"],
            mode="c",
            shape=tuple(info["shape"]),
        )
        segments[key] = [torch.from_numpy(arr)]
    return segments
|
|
||||||
|
|
||||||
|
|
||||||
def detect_format(load_path: str) -> str:
    """Guess the storage format from what is found at ``load_path``.

    A file path is classified by its suffix. A directory is searched
    recursively: HDF5 files win over ``.bin`` files, and ``.bin`` files only
    count when a ``meta.json`` is present as well.

    Args:
        load_path: Directory or file path

    Returns:
        Format string ("h5" or "bin")

    Raises:
        ValueError: If ``load_path`` is a file with an unsupported suffix
        FileNotFoundError: If no supported data files are found
    """
    root = Path(load_path)

    if root.is_file():
        suffix = root.suffix.lower()
        if suffix not in (".h5", ".hdf5"):
            raise ValueError(f"Unsupported file format: {suffix}")
        return "h5"

    for pattern in ("*.h5", "*.hdf5"):
        if any(True for _ in root.rglob(pattern)):
            return "h5"

    if any(True for _ in root.rglob("*.bin")):
        meta_here = (root / "meta.json").exists()
        if meta_here or any(True for _ in root.rglob("meta.json")):
            return "bin"

    raise FileNotFoundError(f"No supported data files found at {load_path}")
|
|
||||||
|
|
||||||
|
|
||||||
class Store(ABC):
    """String keys -> segmented tensors with ``fetch(begin, end, keys)``.

    Each key maps to one or more tensor segments (no forced concatenation).
    ``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
    total element count across all keys.

    Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
    via ``_normalize()``.
    """

    def __init__(self):
        # key -> list of segments, kept exactly as loaded
        self._data: Dict[str, List[Tensor]] = {}
        # key -> running totals of segment lengths along dim 0
        self._cum: Dict[str, List[int]] = {}
        # smallest total element count over all keys
        self._length: int = 0

    @abstractmethod
    def load(self, path: str) -> None:
        """Populate the store from ``path``; implemented by each backend."""
        raise NotImplementedError

    @property
    def keys(self) -> List[str]:
        """Names of the keys currently held."""
        return list(self._data.keys())

    def __len__(self) -> int:
        return self._length

    def fetch(
        self,
        begin: int,
        end: int,
        keys: Union[str, List[str]],
    ):
        """Return the slice ``[begin, end)`` for one key or for several keys.

        Args:
            begin: First element index, ``0 <= begin < len(self)``.
            end: Exclusive end index, ``0 <= end <= len(self)``.
            keys: A single key (returns a tensor) or a list of keys
                (returns a dict mapping each key to its tensor).

        Raises:
            RuntimeError: If nothing has been loaded.
            ValueError: If ``begin`` or ``end`` is out of bounds.
        """
        if not self._data:
            raise RuntimeError("Store not loaded")
        if not (0 <= begin < self._length and 0 <= end <= self._length):
            raise ValueError(
                f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
            )
        if isinstance(keys, str):
            return self._fetch_key(keys, begin, end)
        return {k: self._fetch_key(k, begin, end) for k in keys}

    def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
        """Fetch slice [begin, end) across potentially multiple segments.

        An empty range (``end <= begin``) yields an empty tensor wherever it
        falls, including exactly on a segment boundary.
        """
        segments = self._data[key]
        cum = self._cum[key]
        # First segment containing ``begin`` and last segment containing ``end - 1``.
        seg_start = bisect.bisect_right(cum, begin)
        seg_end = bisect.bisect_left(cum, end)

        results = []
        for i in range(seg_start, seg_end + 1):
            prev = cum[i - 1] if i > 0 else 0
            s = max(begin - prev, 0)
            e = min(end - prev, segments[i].shape[0])
            results.append(segments[i][s:e])

        if not results:
            # Empty or reversed range on a segment boundary selects no
            # segment; torch.cat rejects an empty list, so return an empty
            # slice with the dtype and trailing shape of the nearest segment.
            anchor = segments[min(seg_start, len(segments) - 1)]
            return anchor[:0]
        return results[0] if len(results) == 1 else torch.cat(results, dim=0)

    def _normalize(self, raw: Dict[str, List[Tensor]]):
        """Register segments and pre-compute cumulative lengths.

        Does NOT concatenate — segments are kept as-is to avoid OOM on
        large datasets. Sets ``self._length`` to the minimum total
        element count across all keys.
        """
        for key, tensors in raw.items():
            self._data[key] = tensors
            cum = []
            total = 0
            for t in tensors:
                total += t.shape[0]
                cum.append(total)
            self._cum[key] = cum
        self._length = (
            min((cum[-1] if cum else 0) for cum in self._cum.values())
            if self._cum
            else 0
        )
|
|
||||||
|
|
||||||
|
|
||||||
class StoreFactory(BaseFactory["Store"]):
    """Factory that builds Store instances from a registered type name.

    Example::

        @StoreFactory.register("custom")
        class CustomStore(Store):
            ...
    """

    @classmethod
    def _validate_component(cls, store_cls: type):
        """Reject classes that do not derive from Store."""
        if issubclass(store_cls, Store):
            return
        raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("h5")
class H5Store(Store):
    """Storage backend for HDF5 files (pre-tokenized data)."""

    def load(self, path: str):
        """Read the HDF5 data at ``path`` with ``load_h5`` and register its segments."""
        raw_segments = load_h5(path)
        self._normalize(raw_segments)
|
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("bin")
class MmapStore(Store):
    """Storage backend for raw binary files loaded through ``load_bin``.

    Every directory under the load path that contains a ``meta.json`` is
    read with ``load_bin``, which memory-maps one ``.bin`` file per key;
    the resulting tensors are collected as segments of that key.

    Format on disk::

        data_root/
            meta.json        # {key: {shape, dtype}, ...}
            <key>.bin        # raw numpy array, one per key
    """

    def load(self, path: str):
        """Collect the segments of every ``meta.json`` directory under ``path``.

        Raises:
            FileNotFoundError: If no ``meta.json`` exists under ``path``.
        """
        self._mmap_refs = []
        meta_files = list(Path(path).rglob("meta.json"))
        if not meta_files:
            raise FileNotFoundError(f"No meta.json found under {path}")

        merged: Dict[str, List[Tensor]] = {}
        for meta_file in meta_files:
            for key, tensors in load_bin(str(meta_file.parent)).items():
                merged.setdefault(key, []).extend(tensors)

        self._normalize(merged)
        # Keep an extra reference to every registered segment.
        for segment_list in self._data.values():
            self._mmap_refs += segment_list
|
|
||||||
|
|
@ -1,226 +0,0 @@
|
||||||
"""Base factory class for extensible component registration."""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
from abc import ABC
|
|
||||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class Registry:
    """Name-indexed registry of component classes with optional metadata.

    Every entry stores the component class together with a category and a
    priority. The registry supports lookup by name and listings filtered by
    category or ordered by priority.
    """

    def __init__(self):
        # name -> (component_cls, category, priority)
        self._entries = {}

    def register(
        self,
        name: str,
        component_cls: Type,
        category: Optional[str] = None,
        priority: int = 0,
    ):
        """Add ``component_cls`` under ``name``; duplicate names raise ValueError."""
        if self.contains(name):
            raise ValueError(f"Component '{name}' is already registered")
        self._entries[name] = (component_cls, category, priority)

    def get(self, name: str) -> Type:
        """Return the class registered under ``name``; KeyError when unknown."""
        component_cls, _category, _priority = self.get_with_metadata(name)
        return component_cls

    def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
        """Return the ``(class, category, priority)`` entry; KeyError when unknown."""
        if name not in self._entries:
            raise KeyError(f"Component '{name}' not found in registry")
        return self._entries[name]

    def contains(self, name: str) -> bool:
        """Tell whether ``name`` has been registered."""
        return name in self._entries

    def list_names(self) -> List[str]:
        """All registered names in sorted order."""
        return sorted(self._entries)

    def list_by_category(self, category: str) -> List[str]:
        """Sorted names of the components whose category equals ``category``."""
        matching = [
            name
            for name, entry in self._entries.items()
            if entry[1] == category
        ]
        matching.sort()
        return matching

    def list_by_priority(self, reverse: bool = False) -> List[str]:
        """Names ordered by priority, ascending unless ``reverse`` is true."""

        def priority_of(name):
            return self._entries[name][2]

        return sorted(self._entries, key=priority_of, reverse=reverse)

    def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
        """Shallow copy of the entry table."""
        return dict(self._entries)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFactory(ABC, Generic[T]):
    """Generic base for factories that create components by registered name.

    Each subclass receives its own ``Registry``; component classes are added
    to it with the ``register`` decorator and instantiated with ``create``.

    Example usage:
        class MyFactory(BaseFactory[MyBaseClass]):
            pass

        @MyFactory.register("custom")
        class CustomComponent(MyBaseClass):
            ...

        component = MyFactory.create("custom", *args, **kwargs)
    """

    _registry: Registry

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # A fresh registry per subclass keeps factories independent.
        cls._registry = Registry()

    @classmethod
    def register(
        cls, name: str, category: Optional[str] = None, priority: int = 0
    ) -> Callable[[Type[T]], Type[T]]:
        """Build a class decorator that registers the decorated component.

        Args:
            name: Registration name for the component
            category: Optional category for grouping components
            priority: Priority for ordering (default 0)

        Returns:
            Decorator function that registers the component class

        Raises:
            TypeError: If the decorated class doesn't inherit from the base type
        """

        def decorator(component_cls: Type[T]) -> Type[T]:
            # Validate first so that rejected classes are never registered.
            cls._validate_component(component_cls)
            cls._registry.register(
                name, component_cls, category=category, priority=priority
            )
            return component_cls

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> T:
        """Instantiate the component registered under ``name``.

        Keyword arguments that the component's ``__init__`` does not name are
        dropped, unless ``__init__`` itself accepts ``**kwargs``; components
        therefore need no catch-all parameter for options meant for others.

        Args:
            name: Registered name of the component
            *args: Positional arguments passed to component constructor
            **kwargs: Keyword arguments passed to component constructor

        Returns:
            Component instance

        Raises:
            ValueError: If the component name is not registered
        """
        registry = cls._registry
        if not registry.contains(name):
            raise ValueError(
                f"Unknown component: '{name}'. "
                f"Supported types: {registry.list_names()}"
            )

        component_cls = registry.get(name)
        init_params = list(inspect.signature(component_cls.__init__).parameters.values())
        takes_var_kwargs = any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in init_params
        )
        if not takes_var_kwargs:
            accepted = {param.name for param in init_params if param.name != "self"}
            kwargs = {key: value for key, value in kwargs.items() if key in accepted}
        return component_cls(*args, **kwargs)

    @classmethod
    def _validate_component(cls, component_cls: Type[T]):
        """Hook for checking a class before it is registered.

        The base implementation accepts everything; subclasses override it
        to add their own checks.

        Args:
            component_cls: Component class to validate

        Raises:
            TypeError: If the component class is invalid
        """
        pass

    @classmethod
    def get_component_class(cls, name: str) -> Type[T]:
        """Look up a registered component class without instantiating it.

        Args:
            name: Registered name of the component

        Returns:
            The component class itself

        Raises:
            ValueError: If the component name is not registered
        """
        registry = cls._registry
        if registry.contains(name):
            return registry.get(name)
        raise ValueError(
            f"Unknown component: '{name}'. "
            f"Supported types: {registry.list_names()}"
        )

    @classmethod
    def list_registered(cls) -> list:
        """Return the names of all registered components.

        Returns:
            List of registered component names
        """
        return cls._registry.list_names()

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Tell whether ``name`` is a registered component name.

        Args:
            name: Component name to check

        Returns:
            True if registered, False otherwise
        """
        return cls._registry.contains(name)

    @classmethod
    def list_by_category(cls, category: str) -> List[str]:
        """Registered component names that belong to ``category``."""
        return cls._registry.list_by_category(category)

    @classmethod
    def list_by_priority(cls, reverse: bool = False) -> List[str]:
        """Registered component names ordered by priority."""
        return cls._registry.list_by_priority(reverse)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Registry", "BaseFactory"]
|
|
||||||
|
|
@ -1,85 +0,0 @@
|
||||||
"""Inference module for continuous batching.
|
|
||||||
|
|
||||||
Layers:
|
|
||||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
|
||||||
- api/: HTTP orchestration (ProtocolHandler, server)
|
|
||||||
- protocols/: Response builders (OpenAI, Anthropic)
|
|
||||||
- transport/: SSE transport utilities
|
|
||||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
|
||||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from astrai.inference.api import (
|
|
||||||
AnthropicMessage,
|
|
||||||
ChatCompletionRequest,
|
|
||||||
ChatMessage,
|
|
||||||
GenContext,
|
|
||||||
MessagesRequest,
|
|
||||||
ProtocolHandler,
|
|
||||||
StopChecker,
|
|
||||||
app,
|
|
||||||
run_server,
|
|
||||||
)
|
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
|
||||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
||||||
from astrai.inference.core import (
|
|
||||||
STOP,
|
|
||||||
Allocator,
|
|
||||||
Executor,
|
|
||||||
InferenceScheduler,
|
|
||||||
KVCache,
|
|
||||||
KvcacheView,
|
|
||||||
PagePool,
|
|
||||||
PrefixCache,
|
|
||||||
Storage,
|
|
||||||
Task,
|
|
||||||
TaskManager,
|
|
||||||
TaskStatus,
|
|
||||||
TaskTable,
|
|
||||||
page_hash,
|
|
||||||
)
|
|
||||||
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
|
||||||
from astrai.inference.sample import (
|
|
||||||
BaseSamplingStrategy,
|
|
||||||
SamplingPipeline,
|
|
||||||
TemperatureStrategy,
|
|
||||||
TopKStrategy,
|
|
||||||
TopPStrategy,
|
|
||||||
sample,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"InferenceEngine",
|
|
||||||
"GenerationRequest",
|
|
||||||
"InferenceScheduler",
|
|
||||||
"Executor",
|
|
||||||
"STOP",
|
|
||||||
"Task",
|
|
||||||
"TaskManager",
|
|
||||||
"TaskStatus",
|
|
||||||
"Allocator",
|
|
||||||
"KVCache",
|
|
||||||
"KvcacheView",
|
|
||||||
"PagePool",
|
|
||||||
"PrefixCache",
|
|
||||||
"Storage",
|
|
||||||
"TaskTable",
|
|
||||||
"page_hash",
|
|
||||||
"sample",
|
|
||||||
"BaseSamplingStrategy",
|
|
||||||
"TemperatureStrategy",
|
|
||||||
"TopKStrategy",
|
|
||||||
"TopPStrategy",
|
|
||||||
"SamplingPipeline",
|
|
||||||
"ProtocolHandler",
|
|
||||||
"StopChecker",
|
|
||||||
"GenContext",
|
|
||||||
"OpenAIResponseBuilder",
|
|
||||||
"AnthropicResponseBuilder",
|
|
||||||
"ChatMessage",
|
|
||||||
"ChatCompletionRequest",
|
|
||||||
"AnthropicMessage",
|
|
||||||
"MessagesRequest",
|
|
||||||
"app",
|
|
||||||
"run_server",
|
|
||||||
]
|
|
||||||
|
|
@ -1,23 +0,0 @@
|
||||||
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
|
||||||
|
|
||||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
|
||||||
from astrai.inference.api.server import (
|
|
||||||
AnthropicMessage,
|
|
||||||
ChatCompletionRequest,
|
|
||||||
ChatMessage,
|
|
||||||
MessagesRequest,
|
|
||||||
app,
|
|
||||||
run_server,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ProtocolHandler",
|
|
||||||
"StopChecker",
|
|
||||||
"GenContext",
|
|
||||||
"AnthropicMessage",
|
|
||||||
"ChatCompletionRequest",
|
|
||||||
"ChatMessage",
|
|
||||||
"MessagesRequest",
|
|
||||||
"app",
|
|
||||||
"run_server",
|
|
||||||
]
|
|
||||||
|
|
@ -1,141 +0,0 @@
|
||||||
"""Anthropic message completion response builder."""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from astrai.inference.api.protocol import (
|
|
||||||
GenContext,
|
|
||||||
ResponseBuilder,
|
|
||||||
StopInfo,
|
|
||||||
sse_event,
|
|
||||||
)
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "text":
|
|
||||||
return block.get("text", "")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicResponseBuilder(ResponseBuilder):
|
|
||||||
def prepare(
    self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
    """Turn a request into a prompt, a generation context and stop sequences.

    The optional ``system`` field becomes a leading system message; every
    request message with non-empty extracted text follows in order. The
    prompt is rendered by the engine tokenizer's chat template.

    Returns:
        Tuple of (prompt, context, stop_sequences).
    """
    system = getattr(request, "system", None)
    chat: List[Dict[str, str]] = (
        [{"role": "system", "content": system}] if system else []
    )
    for message in request.messages:
        text = _extract_text(message.content)
        if not text:
            continue  # messages without text are left out of the prompt
        chat.append({"role": message.role, "content": text})

    prompt = engine.tokenizer.apply_chat_template(chat, tokenize=False)
    ctx = GenContext(
        resp_id=f"msg_{uuid.uuid4().hex[:24]}",
        created=int(time.time()),
        model=request.model,
        prompt_tokens=0,
    )
    stop_sequences = getattr(request, "stop_sequences", None) or []
    return prompt, ctx, stop_sequences
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
    """Return the two opening SSE events: ``message_start`` then ``content_block_start``."""
    message_start = {
        "type": "message_start",
        "message": {
            "id": ctx.resp_id,
            "type": "message",
            "role": "assistant",
            "model": ctx.model,
            "content": [],
            "usage": {"input_tokens": ctx.prompt_tokens},
        },
    }
    block_start = {
        "type": "content_block_start",
        "index": 0,
        "content_block": {"type": "text", "text": ""},
    }
    return [
        sse_event(message_start, event="message_start"),
        sse_event(block_start, event="content_block_start"),
    ]
|
|
||||||
|
|
||||||
def format_chunk(self, token: str) -> str:
|
|
||||||
return sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": token},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
|
||||||
events: List[str] = []
|
|
||||||
if stop.matched:
|
|
||||||
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
|
||||||
unyielded = trimmed[len(stop.yielded) :]
|
|
||||||
if unyielded:
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": unyielded},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{"type": "content_block_stop", "index": 0},
|
|
||||||
event="content_block_stop",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": {
|
|
||||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
|
||||||
"stop_sequence": stop.matched,
|
|
||||||
},
|
|
||||||
"usage": {"output_tokens": ctx.completion_tokens},
|
|
||||||
},
|
|
||||||
event="message_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
|
||||||
return events
|
|
||||||
|
|
||||||
def format_response(
|
|
||||||
self, ctx: GenContext, content: str, stop: StopInfo
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
if stop.matched:
|
|
||||||
content = content[: content.rfind(stop.matched)]
|
|
||||||
return {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [{"type": "text", "text": content}],
|
|
||||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
|
||||||
"stop_sequence": stop.matched,
|
|
||||||
"usage": {
|
|
||||||
"input_tokens": ctx.prompt_tokens,
|
|
||||||
"output_tokens": ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
@ -1,140 +0,0 @@
|
||||||
"""OpenAI chat completion response builder."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from astrai.inference.api.protocol import (
|
|
||||||
GenContext,
|
|
||||||
ResponseBuilder,
|
|
||||||
StopInfo,
|
|
||||||
sse_event,
|
|
||||||
)
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_UNSUPPORTED_PARAMS = (
|
|
||||||
"n",
|
|
||||||
"presence_penalty",
|
|
||||||
"frequency_penalty",
|
|
||||||
"logit_bias",
|
|
||||||
"user",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIResponseBuilder(ResponseBuilder):
    """Response builder producing OpenAI chat-completion-style payloads.

    ``prepare`` stores the response id, model name and creation timestamp on
    the instance; the ``format_*`` methods read them, so ``prepare`` must be
    called first and one builder instance serves one request.
    """

    @staticmethod
    def _cut_at_stop(text: str, matched: str) -> str:
        """Return ``text`` truncated at the first occurrence of ``matched``.

        ``text`` is returned unchanged if ``matched`` does not occur in it.
        """
        idx = text.find(matched)
        return text if idx < 0 else text[:idx]

    def _chunk(
        self, delta: Dict[str, Any], finish_reason: Any, created: int
    ) -> str:
        """Serialize one ``chat.completion.chunk`` SSE event."""
        return sse_event(
            {
                "id": self._resp_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": self._model,
                "choices": [
                    {"index": 0, "delta": delta, "finish_reason": finish_reason}
                ],
            }
        )

    def prepare(
        self, request: BaseModel, engine: InferenceEngine
    ) -> Tuple[str, GenContext, List[str]]:
        """Build the prompt, generation context and stop sequences.

        Logs one warning for every field in ``_UNSUPPORTED_PARAMS`` that is
        set to a non-``None``, non-default value.

        Returns:
            ``(prompt, ctx, stop_sequences)``. ``ctx.prompt_tokens`` is set to
            0 here; ``stop_sequences`` is always a list.
        """
        messages = [{"role": m.role, "content": m.content} for m in request.messages]
        prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)

        self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        self._model = request.model
        # Remembered so every streamed chunk reports the same timestamp.
        self._created = int(time.time())

        # Loop-invariant: field metadata of the request model, if it has any.
        fields = getattr(type(request), "model_fields", {})
        for param in _UNSUPPORTED_PARAMS:
            value = getattr(request, param, None)
            default = fields[param].default if param in fields else None
            if value is not None and value != default:
                logger.warning(
                    "ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
                    param,
                    value,
                )

        ctx = GenContext(
            resp_id=self._resp_id,
            created=self._created,
            model=self._model,
            prompt_tokens=0,
        )
        stop = request.stop
        if stop is None:
            stop_sequences: List[str] = []
        elif isinstance(stop, str):
            stop_sequences = [stop]
        else:
            stop_sequences = list(stop)
        return prompt, ctx, stop_sequences

    def format_stream_start(self, ctx: GenContext) -> List[str]:
        """Return the opening chunk that announces the assistant role."""
        return [self._chunk({"role": "assistant"}, None, ctx.created)]

    def format_chunk(self, token: str) -> str:
        """Return a chunk event whose delta carries ``token`` as content."""
        return self._chunk({"content": token}, None, self._created)

    def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
        """Return the events that close the stream.

        When a stop sequence matched, text that precedes it but was not yet
        streamed is flushed first. Then come the ``finish_reason="stop"``
        chunk and a bare usage event with the token counts.
        """
        events: List[str] = []
        if stop.matched:
            trimmed = self._cut_at_stop(stop.body, stop.matched)
            unyielded = trimmed[len(stop.yielded) :]
            if unyielded:
                events.append(self.format_chunk(unyielded))
        events.append(self._chunk({}, "stop", ctx.created))
        events.append(
            sse_event(
                {
                    "prompt_tokens": ctx.prompt_tokens,
                    "completion_tokens": ctx.completion_tokens,
                    "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
                }
            )
        )
        return events

    def format_response(
        self, ctx: GenContext, content: str, stop: StopInfo
    ) -> Dict[str, Any]:
        """Return the non-streaming ``chat.completion`` body.

        ``content`` is truncated at the matched stop sequence, if any, so the
        stop sequence itself is not part of the returned message.
        """
        if stop.matched:
            content = self._cut_at_stop(content, stop.matched)
        return {
            "id": self._resp_id,
            "object": "chat.completion",
            "created": ctx.created,
            "model": self._model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": ctx.prompt_tokens,
                "completion_tokens": ctx.completion_tokens,
                "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
            },
        }
|
|
||||||
|
|
@ -1,182 +0,0 @@
|
||||||
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
|
||||||
|
|
||||||
ProtocolHandler orchestrates the async generation loop and delegates
|
|
||||||
protocol-specific formatting to a ResponseBuilder.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
|
|
||||||
|
|
||||||
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
    """Serialize ``data`` as one complete Server-Sent Events message.

    Args:
        data: JSON-serializable payload, written on the ``data:`` line.
            Non-ASCII characters are emitted verbatim.
        event: Optional event name; when truthy an ``event:`` line precedes
            the ``data:`` line.

    Returns:
        The message text, terminated by a blank line (``"\\n\\n"``).
    """
    lines: List[str] = []
    if event:
        lines.append(f"event: {event}")
    lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
    # An SSE message is only dispatched once a blank line is seen. Ending on a
    # single "\n" would make consecutive events merge into one on the client.
    return "\n".join(lines) + "\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
def sse_done() -> str:
    """Return the ``[DONE]`` SSE message, including its blank-line terminator."""
    done_message = "data: [DONE]\n\n"
    return done_message
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class GenContext:
    """Metadata for one generation, handed to the builder's format methods."""

    # Response identifier reported to the client.
    resp_id: str
    # Creation time as an integer timestamp.
    created: int
    # Model name echoed back in responses.
    model: str
    # Number of tokens in the prompt.
    prompt_tokens: int
    # Number of generated tokens counted so far; starts at zero.
    completion_tokens: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class StopInfo:
    """Outcome of stop-sequence checking, passed to the closing format methods."""

    # The stop sequence that was found, or None when none matched.
    matched: Optional[str] = None
    # All text accumulated from the generator, including the matching token.
    body: str = ""
    # The portion of the text already sent to the client as chunks.
    yielded: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class StopChecker:
    """Looks for any configured stop sequence inside accumulated text."""

    def __init__(self, sequences: List[str]):
        # Empty strings are discarded: "" is a substring of every text.
        self._sequences = list(filter(None, sequences))

    def check(self, text: str) -> Optional[str]:
        """Return the first configured sequence contained in ``text``.

        Sequences are tried in the order they were supplied; ``None`` is
        returned when none of them occurs in ``text``.
        """
        return next(
            (candidate for candidate in self._sequences if candidate in text),
            None,
        )
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseBuilder(ABC):
    """Abstract formatter for one wire protocol.

    Supporting a new protocol means writing one subclass that implements the
    five abstract methods below.
    """

    @abstractmethod
    def prepare(
        self, request: BaseModel, engine: InferenceEngine
    ) -> Tuple[str, GenContext, List[str]]:
        """Turn a request into ``(prompt, ctx, stop_sequences)``."""

    @abstractmethod
    def format_stream_start(self, ctx: GenContext) -> List[str]:
        """Produce the SSE events sent before any generated token."""

    @abstractmethod
    def format_chunk(self, token: str) -> str:
        """Produce the SSE event for one generated token."""

    @abstractmethod
    def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
        """Produce the SSE events sent after the last generated token."""

    @abstractmethod
    def format_response(
        self, ctx: GenContext, content: str, stop: StopInfo
    ) -> Dict[str, Any]:
        """Produce the JSON body returned in non-streaming mode."""
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolHandler:
    """Runs one generation request and lets a builder format the output.

    Usage::

        handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
        response = await handler.handle()
    """

    def __init__(
        self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
    ):
        self.request = request
        self.engine = engine
        self.builder = builder

    async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
        """Start generation and return a streaming or a complete response.

        The prompt token count is measured by encoding the prompt with the
        engine's tokenizer. ``request.stream`` selects the response mode.
        """
        prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
        ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))

        req = self.request
        token_source = self.engine.generate_async(
            prompt=prompt,
            max_tokens=req.max_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
        )

        if not req.stream:
            return await self._handle_non_stream(token_source, ctx, stop_sequences)
        return self._handle_stream(token_source, ctx, stop_sequences)

    def _handle_stream(
        self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
    ) -> StreamingResponse:
        """Wrap the token generator in an SSE ``StreamingResponse``."""
        checker = StopChecker(stop_sequences)
        builder = self.builder

        async def event_stream():
            for opening in builder.format_stream_start(ctx):
                yield opening

            seen = ""  # everything received from the generator
            sent = ""  # everything already forwarded to the client
            hit = None
            async for token in agen:
                seen += token
                hit = checker.check(seen)
                if hit:
                    # The token completing a stop sequence is neither
                    # forwarded nor counted.
                    break
                ctx.completion_tokens += 1
                yield builder.format_chunk(token)
                sent += token

            outcome = StopInfo(matched=hit, body=seen, yielded=sent)
            for closing in builder.format_stream_end(ctx, outcome):
                yield closing
            yield sse_done()

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    async def _handle_non_stream(
        self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
    ) -> Dict[str, Any]:
        """Consume the generator and return the builder's JSON body.

        The content passed to the builder is the full accumulated text,
        including the token that completed a stop sequence; trimming is left
        to the builder.
        """
        checker = StopChecker(stop_sequences)
        seen = ""
        hit = None

        async for token in agen:
            seen += token
            hit = checker.check(seen)
            if hit:
                break
            ctx.completion_tokens += 1

        outcome = StopInfo(matched=hit, body=seen)
        return self.builder.format_response(ctx, seen, outcome)
|
|
||||||
|
|
@ -1,169 +0,0 @@
|
||||||
"""
|
|
||||||
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
|
|
||||||
|
|
||||||
Protocol-specific formatting is delegated to the builders in ``astrai.inference.api``.
|
|
||||||
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI, HTTPException
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
|
||||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
||||||
from astrai.inference.api.protocol import ProtocolHandler
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
    """One message of an OpenAI-style chat conversation."""

    # Speaker of the message.
    role: str
    # Plain-text message body.
    content: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
    """Request body of the OpenAI-style chat completion endpoint."""

    model: str = "astrai"
    messages: List[ChatMessage]

    # Sampling controls.
    temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1)

    # Output controls.
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = Field(2048, ge=1)

    # Remaining fields of the request schema.
    n: Optional[int] = Field(1, ge=1)
    presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    logit_bias: Optional[Dict[int, float]] = None
    user: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicMessage(BaseModel):
    """One message of an Anthropic-style conversation."""

    # Speaker of the message.
    role: str
    # Either plain text or a list of content-block dicts.
    content: Union[str, List[Dict[str, Any]]]
|
|
||||||
|
|
||||||
|
|
||||||
class MessagesRequest(BaseModel):
    """Request body of the Anthropic-style messages endpoint."""

    model: str = "astrai"
    max_tokens: int = Field(1024, ge=1)
    messages: List[AnthropicMessage]
    # Optional system prompt.
    system: Optional[str] = None

    # Sampling controls.
    temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1)

    # Output controls.
    stream: Optional[bool] = False
    stop_sequences: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the inference engine on startup and shut it down on exit.

    Reads ``app.state.server_config``. An absent config is treated as empty,
    so the engine is built with ``_create_engine``'s defaults. When the config
    has a truthy ``"_test"`` entry no engine is created. Keys starting with an
    underscore are control flags and are not forwarded to ``_create_engine``.

    Raises:
        Exception: Whatever ``_create_engine`` raises, after it is logged.
    """
    config = getattr(app.state, "server_config", None) or {}
    if not config.get("_test", False):
        engine_kwargs = {
            key: value for key, value in config.items() if not key.startswith("_")
        }
        try:
            app.state.engine = _create_engine(**engine_kwargs)
        except Exception as e:
            logger.exception("Failed to load model: %s", e)
            raise
    try:
        yield
    finally:
        # The engine attribute may never have been set (e.g. in "_test" mode);
        # plain attribute access on the state object would then raise.
        engine = getattr(app.state, "engine", None)
        if engine:
            engine.shutdown()
            logger.info("Inference engine shutdown complete")
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_engine(
    param_path: Optional[Path] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    max_batch_size: int = 16,
) -> InferenceEngine:
    """Load tokenizer and model from disk and wrap them in an engine.

    Args:
        param_path: Directory holding the model parameters; defaults to
            ``<project root>/params``.
        device: Device the model is moved to.
        dtype: Data type the model is converted to.
        max_batch_size: Passed through to ``InferenceEngine``.

    Raises:
        FileNotFoundError: If the parameter directory does not exist.
    """
    weights_dir = _project_root / "params" if param_path is None else param_path
    if not weights_dir.exists():
        raise FileNotFoundError(f"Parameter directory not found: {weights_dir}")

    tokenizer = AutoTokenizer.from_pretrained(weights_dir)
    model = AutoModel.from_pretrained(weights_dir)
    model.to(device=device, dtype=dtype)
    logger.info(f"Model loaded on {device} with dtype {dtype}")

    engine = InferenceEngine(
        model=model, tokenizer=tokenizer, max_batch_size=max_batch_size
    )
    logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
    return engine
|
|
||||||
|
|
||||||
|
|
||||||
def _get_engine() -> InferenceEngine:
    """Return the application's inference engine.

    Raises:
        HTTPException: 503 when no engine has been stored on ``app.state``,
            whether the attribute is ``None`` or was never assigned.
    """
    # getattr with a default: an attribute that was never assigned must lead
    # to the 503 below, not to an AttributeError.
    engine = getattr(app.state, "engine", None)
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    return engine
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
async def health():
    """Report liveness and whether an inference engine is available."""
    # Tolerate a state object on which ``engine`` was never assigned, so the
    # health probe answers instead of failing.
    engine = getattr(app.state, "engine", None)
    return {
        "status": "ok",
        "model_loaded": engine is not None,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats")
async def get_stats():
    """Return the engine's statistics (503 if no engine is available)."""
    engine = _get_engine()
    return engine.get_stats()
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
    """Serve an OpenAI-style chat completion request."""
    return await ProtocolHandler(
        request, _get_engine(), OpenAIResponseBuilder()
    ).handle()
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
    """Serve an Anthropic-style messages request."""
    return await ProtocolHandler(
        request, _get_engine(), AnthropicResponseBuilder()
    ).handle()
|
|
||||||
|
|
||||||
|
|
||||||
def run_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
    max_batch_size: int = 16,
):
    """Store the engine settings on the app and serve it with uvicorn.

    ``device``, ``dtype``, ``param_path`` and ``max_batch_size`` are saved as
    ``app.state.server_config`` for the lifespan handler to read; ``host``,
    ``port`` and ``reload`` go to ``uvicorn.run``.
    """
    app.state.server_config = dict(
        device=device,
        dtype=dtype,
        param_path=param_path,
        max_batch_size=max_batch_size,
    )
    # NOTE(review): the app object (not an import string) is handed to uvicorn;
    # confirm that reload=True behaves as intended in this form.
    uvicorn.run(app, host=host, port=port, reload=reload)
|
|
||||||
|
|
@ -1,32 +0,0 @@
|
||||||
"""Inference core: cache, executor, scheduler, task management."""
|
|
||||||
|
|
||||||
from astrai.inference.core.cache import (
|
|
||||||
Allocator,
|
|
||||||
KVCache,
|
|
||||||
KvcacheView,
|
|
||||||
PagePool,
|
|
||||||
PrefixCache,
|
|
||||||
Storage,
|
|
||||||
TaskTable,
|
|
||||||
page_hash,
|
|
||||||
)
|
|
||||||
from astrai.inference.core.executor import Executor
|
|
||||||
from astrai.inference.core.scheduler import InferenceScheduler
|
|
||||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Allocator",
|
|
||||||
"KVCache",
|
|
||||||
"KvcacheView",
|
|
||||||
"PagePool",
|
|
||||||
"PrefixCache",
|
|
||||||
"Storage",
|
|
||||||
"TaskTable",
|
|
||||||
"page_hash",
|
|
||||||
"Executor",
|
|
||||||
"InferenceScheduler",
|
|
||||||
"STOP",
|
|
||||||
"Task",
|
|
||||||
"TaskManager",
|
|
||||||
"TaskStatus",
|
|
||||||
]
|
|
||||||
|
|
@ -1,368 +0,0 @@
|
||||||
import threading
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
    """Hash the tokens of one page into an unsigned 64-bit value.

    The page covers indices ``page_idx * page_size`` up to (not including)
    ``min(start + page_size, len(token_ids))``. The hash is a base-31
    polynomial accumulation masked to 64 bits; an empty range hashes to 0.

    NOTE(review): only the tokens of the selected page feed the hash, so
    tokens of earlier pages do not influence the value. Confirm callers do
    not rely on it to identify a whole prefix.
    """
    mask64 = 0xFFFFFFFFFFFFFFFF
    pos = page_idx * page_size
    stop = min(pos + page_size, len(token_ids))
    digest = 0
    while pos < stop:
        digest = (digest * 31 + token_ids[pos]) & mask64
        pos += 1
    return digest
|
|
||||||
|
|
||||||
|
|
||||||
class Allocator:
    """Bitmask-based page allocator with ref-counting and LRU eviction.

    A page is in one of three states: free (its bit is set in the free
    mask), in use (reference count above zero), or cached (reference count
    zero and listed in the LRU, so it can be revived or evicted).
    All public methods take the internal lock.
    """

    def __init__(self, n_pages: int):
        # Bit i set means page i is free.
        self._free_mask = (1 << n_pages) - 1
        self._refs: List[int] = [0] * n_pages
        # Unreferenced-but-cached pages, least recently used first.
        self._lru: OrderedDict[int, None] = OrderedDict()
        # Optional callback invoked with the page index on LRU eviction.
        self.on_evict: Optional[Callable[[int], None]] = None
        self._lock = threading.Lock()

    def alloc(self) -> int:
        """Hand out a page with reference count 1.

        Free pages are used first (lowest index first). Otherwise the least
        recently used cached page is evicted, after ``on_evict`` is called
        for it while the lock is held.

        Returns:
            The page index, or ``-1`` when no page is free or evictable.
        """
        with self._lock:
            if self._free_mask:
                lsb = self._free_mask & -self._free_mask
                idx = lsb.bit_length() - 1
                self._free_mask ^= lsb
                self._refs[idx] = 1
                return idx
            if self._lru:
                idx, _ = self._lru.popitem(last=False)
                if self.on_evict:
                    self.on_evict(idx)
                self._refs[idx] = 1
                self._free_mask &= ~(1 << idx)
                return idx
            return -1

    def free(self, idx: int, keep_cached: bool = False):
        """Drop one reference to page ``idx``.

        When the count reaches zero the page goes to the LRU if
        ``keep_cached`` is true, otherwise back to the free mask.
        """
        with self._lock:
            self._refs[idx] -= 1
            if self._refs[idx] == 0:
                if keep_cached:
                    self._lru[idx] = None
                else:
                    self._free_mask |= 1 << idx

    def inc_ref(self, idx: int):
        """Add a reference to page ``idx`` and take it out of the LRU."""
        with self._lock:
            self._refs[idx] += 1
            self._lru.pop(idx, None)

    def ref_count(self, idx: int) -> int:
        """Return the current reference count of page ``idx``."""
        with self._lock:
            return self._refs[idx]

    def touch(self, idx: int):
        """Mark page ``idx`` as most recently used, if it is in the LRU.

        Pages that are currently referenced or free are not in the LRU;
        touching one is a no-op rather than an error.
        """
        with self._lock:
            # move_to_end raises KeyError for a missing key, so guard it.
            if idx in self._lru:
                self._lru.move_to_end(idx)
|
|
||||||
|
|
||||||
|
|
||||||
class PrefixCache:
    """Hash-based prefix matching: maps page hashes to physical page indices.

    Two dictionaries are kept in step under one lock: physical page -> hash
    and hash -> physical page.
    """

    def __init__(self, page_size: int):
        self._page_size = page_size
        self._page_to_hash: Dict[int, int] = {}
        self._hash_to_page: Dict[int, int] = {}
        self._lock = threading.Lock()

    def _forget(self, idx: int) -> None:
        """Remove page ``idx`` from both maps. The caller must hold the lock.

        The hash -> page entry is removed only if it still points at ``idx``.
        If another page was recorded later under the same hash, that page's
        entry must stay reachable.
        """
        h = self._page_to_hash.pop(idx, None)
        if h is not None and self._hash_to_page.get(h) == idx:
            del self._hash_to_page[h]

    def evict(self, idx: int):
        """Forget whatever content hash is recorded for page ``idx``."""
        with self._lock:
            self._forget(idx)

    def has_page(self, idx: int) -> bool:
        """Return whether a content hash is recorded for page ``idx``."""
        with self._lock:
            return idx in self._page_to_hash

    def lookup(self, token_ids: List[int]) -> List[int]:
        """Return physical pages matching the leading full pages of ``token_ids``.

        Full pages are hashed in order; the scan stops at the first page whose
        hash is unknown. A trailing partial page is never looked up.
        """
        with self._lock:
            full_pages = len(token_ids) // self._page_size
            hits: List[int] = []
            for i in range(full_pages):
                h = page_hash(token_ids, i, self._page_size)
                p = self._hash_to_page.get(h)
                if p is None:
                    break
                hits.append(p)
            return hits

    def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
        """Register physical page ``page_idx`` as holding a logical page.

        The hash of logical page ``logical_page_idx`` of ``token_ids`` is
        stored for the page, replacing any hash recorded for it before.
        """
        with self._lock:
            h = page_hash(token_ids, logical_page_idx, self._page_size)
            self._forget(page_idx)
            self._page_to_hash[page_idx] = h
            self._hash_to_page[h] = page_idx
|
|
||||||
|
|
||||||
|
|
||||||
class PagePool:
    """Combines an ``Allocator`` with a ``PrefixCache``.

    On construction the prefix cache's ``evict`` is installed as the
    allocator's ``on_evict`` callback.
    """

    def __init__(self, allocator: Allocator, prefix: PrefixCache):
        self._alloc = allocator
        self._prefix = prefix
        self._alloc.on_evict = prefix.evict

    @property
    def allocator(self) -> Allocator:
        """The underlying page allocator."""
        return self._alloc

    @property
    def prefix(self) -> PrefixCache:
        """The underlying prefix cache."""
        return self._prefix

    def alloc(self) -> int:
        """Allocate a page; see ``Allocator.alloc`` for the return value."""
        return self._alloc.alloc()

    def free(self, idx: int):
        """Release one reference to page ``idx``.

        A page the prefix cache knows about is released with
        ``keep_cached=True``; any other page is released normally and then
        evicted from the prefix cache.
        """
        is_cached = self._prefix.has_page(idx)
        self._alloc.free(idx, keep_cached=is_cached)
        if is_cached:
            return
        self._prefix.evict(idx)

    def inc_ref(self, idx: int):
        """Add a reference to page ``idx``."""
        self._alloc.inc_ref(idx)

    def lookup(self, token_ids: List[int]) -> List[int]:
        """Return cached pages for the prompt prefix and touch each of them."""
        matched = self._prefix.lookup(token_ids)
        for page in matched:
            self._alloc.touch(page)
        return matched

    def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
        """Forward to ``PrefixCache.record``."""
        self._prefix.record(page_idx, token_ids, logical_page_idx)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskTable:
    """Per-task bookkeeping: page table and cached-token count by task id.

    Every method runs under the table's lock.
    """

    def __init__(self, page_size: int):
        self._page_size = page_size
        self._pages: Dict[str, List[int]] = {}
        self._cached: Dict[str, int] = {}
        self._lock = threading.Lock()

    def set(self, task_id: str, page_table: List[int], cached: int):
        """Store the page table and cached-token count for ``task_id``."""
        with self._lock:
            self._pages[task_id] = page_table
            self._cached[task_id] = cached

    def get(self, task_id: str) -> List[int]:
        """Return the stored page list, or a new empty list if unknown.

        The stored list object itself is returned, not a copy.
        """
        with self._lock:
            return self._pages.get(task_id, [])

    def get_cached(self, task_id: str) -> int:
        """Return the cached-token count for ``task_id`` (0 if unknown)."""
        with self._lock:
            return self._cached.get(task_id, 0)

    def pop(self, task_id: str) -> Tuple[List[int], int]:
        """Remove ``task_id`` and return ``(pages, cached)``.

        Unknown tasks yield ``([], 0)``.
        """
        with self._lock:
            return self._pages.pop(task_id, []), self._cached.pop(task_id, 0)

    def get_ref(self, task_id: str) -> List[int]:
        """Return the stored page list, creating an empty one if missing."""
        with self._lock:
            return self._pages.setdefault(task_id, [])

    def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
        """Build a ``long`` tensor with one page-table row per task id.

        Rows are right-padded with ``-1`` to the longest page table; unknown
        task ids produce a row of padding only.
        """
        with self._lock:
            tables = [self._pages.get(task_id, []) for task_id in task_ids]
            width = max(map(len, tables), default=0)
            padded = [table + [-1] * (width - len(table)) for table in tables]
            return torch.tensor(padded, dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
class Storage:
    """Paged key/value tensors with page-table based write and gather.

    Both caches have shape
    ``(n_layers, n_pages, page_size, n_kv_heads, head_dim)`` and are
    allocated uninitialized.
    """

    def __init__(
        self,
        n_layers: int,
        n_pages: int,
        page_size: int,
        n_kv_heads: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        self.page_size = page_size
        shape = (n_layers, n_pages, page_size, n_kv_heads, head_dim)
        self.k_cache = torch.empty(shape, device=device, dtype=dtype)
        self.v_cache = torch.empty(shape, device=device, dtype=dtype)

    def write(
        self,
        layer_id: int,
        page_table: Tensor,
        start_pos: int,
        k: Tensor,
        v: Tensor,
    ):
        """Scatter ``k``/``v`` into the cache pages named by ``page_table``.

        Positions ``start_pos`` to ``start_pos + k.size(1) - 1`` are written
        page by page. Rows whose page-table entry for a page is negative are
        skipped for that page.

        Args:
            layer_id: Layer whose cache is written.
            page_table: One row of physical page indices per batch row; the
                column index is the logical page.
            start_pos: Sequence position of the first written token.
            k: Keys; presumably ``(batch, seq_len, n_kv_heads, head_dim)``.
                TODO confirm against the caller.
            v: Values, indexed the same way as ``k``.
        """
        seq_len = k.size(1)
        if seq_len == 0:
            return

        page_size = self.page_size
        end_pos = start_pos + seq_len
        consumed = 0  # tokens of k/v already copied
        for logical in range(start_pos // page_size, (end_pos - 1) // page_size + 1):
            physical = page_table[:, logical]
            base = logical * page_size
            lo = max(base, start_pos) - base
            hi = min(base + page_size, end_pos) - base
            src = slice(consumed, consumed + (hi - lo))
            consumed += hi - lo

            usable = physical >= 0
            if usable.all():
                self.k_cache[layer_id, physical, lo:hi] = k[:, src]
                self.v_cache[layer_id, physical, lo:hi] = v[:, src]
            elif usable.any():
                # Only rows that own a physical page here are written.
                live = physical[usable]
                self.k_cache[layer_id, live, lo:hi] = k[usable, src]
                self.v_cache[layer_id, live, lo:hi] = v[usable, src]

    def gather(
        self, layer_id: int, page_table: Tensor, total_len: int
    ) -> Tuple[Tensor, Tensor]:
        """Collect keys and values for every row of ``page_table``.

        Pages are read in table order and flattened along the sequence axis.
        Positions belonging to negative table entries are zero-filled. The
        result is cut to the first ``total_len`` positions.

        Returns:
            ``(k, v)``, each with the sequence on dimension 1.
        """
        missing = page_table < 0
        # Negative entries are clamped to page 0 so indexing is valid; the
        # values read that way are zeroed out below.
        pages = page_table.clamp(min=0)
        k = self.k_cache[layer_id, pages].flatten(1, 2)
        v = self.v_cache[layer_id, pages].flatten(1, 2)
        if missing.any():
            hole = missing.unsqueeze(-1).expand(-1, -1, self.page_size).flatten(1, 2)
            hole = hole[:, :, None, None].expand_as(k)
            k = k.masked_fill(hole, 0.0)
            v = v.masked_fill(hole, 0.0)
        return k[:, :total_len], v[:, :total_len]
|
|
||||||
|
|
||||||
|
|
||||||
class KvcacheView:
    """A ``Storage`` bound to one page table and one total sequence length."""

    def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
        self._storage = storage
        self._page_table = page_table
        self._total_len = total_len

    def write(self, layer_id: int, k: Tensor, v: Tensor):
        """Write ``k``/``v`` so that they end at position ``total_len``.

        The start position is ``total_len - k.size(1)``, i.e. the new tokens
        are treated as the tail of the sequence.
        """
        n_new = k.size(1)
        self._storage.write(
            layer_id, self._page_table, self._total_len - n_new, k, v
        )

    def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
        """Return keys and values for the first ``total_len`` positions."""
        storage = self._storage
        return storage.gather(layer_id, self._page_table, self._total_len)
|
|
||||||
|
|
||||||
|
|
||||||
class KVCache:
    """Facade combining page management with KV-cache I/O for continuous batching.

    Three collaborators are wired together here: a ``PagePool`` (allocation
    plus prefix lookup), a ``TaskTable`` (task id -> page list and cached
    token count) and a ``Storage`` (the K/V tensors themselves).
    """

    def __init__(
        self,
        n_layers: int,
        n_pages: int,
        page_size: int,
        n_kv_heads: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        self.page_size = page_size
        allocator = Allocator(n_pages)
        prefix_cache = PrefixCache(page_size)
        self._pool = PagePool(allocator, prefix_cache)
        self._table = TaskTable(page_size)
        self._storage = Storage(
            n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
        )

    def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
        """Reserve pages for a prompt, reusing pages found by the prefix lookup.

        Returns:
            True when the task was registered in the table; False when the
            pool ran out of pages, in which case every page taken for this
            task (shared or new) is handed back to the pool first.
        """
        pool = self._pool
        page_size = self.page_size

        shared = pool.lookup(prompt_ids)
        # Each page returned by the lookup accounts for page_size prompt tokens.
        cached = len(shared) * page_size
        for page in shared:
            pool.inc_ref(page)

        uncovered = len(prompt_ids) - cached
        # Ceiling division; nothing to allocate when the lookup covers the prompt.
        n_new = -(-uncovered // page_size) if uncovered > 0 else 0

        fresh: List[int] = []
        for _ in range(n_new):
            page = pool.alloc()
            if page < 0:
                # A negative page id signals exhaustion: roll back and report.
                for held in shared:
                    pool.free(held)
                for held in fresh:
                    pool.free(held)
                return False
            fresh.append(page)

        self._table.set(task_id, shared + fresh, cached)
        return True

    def task_free(self, task_id: str):
        """Drop the task's table entry and free every page it referenced."""
        pages, _ = self._table.pop(task_id)
        for page in pages:
            self._pool.free(page)

    def task_extend(self, task_id: str, pos: int) -> bool:
        """Grow the task's page list until position ``pos`` is covered.

        Returns:
            False as soon as the pool cannot supply another page (pages added
            before that point stay in the task's list), True otherwise.
        """
        pages = self._table.get(task_id)
        # Pages needed to hold positions 0..pos inclusive.
        needed = pos // self.page_size + 1
        for _ in range(needed - len(pages)):
            page = self._pool.alloc()
            if page < 0:
                return False
            # Appends to the list returned by the table, as the original did.
            pages.append(page)
        return True

    def task_cached(self, task_id: str) -> int:
        """Return the cached-token count the table holds for ``task_id``."""
        return self._table.get_cached(task_id)

    def task_record_hashes(
        self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
    ):
        """Record every completely filled prompt page with the pool.

        Only pages whose ``page_size`` tokens all come from the prompt are
        recorded, starting at logical page ``start_logical_page``.
        """
        pages = self._table.get(task_id)
        n_full = len(prompt_ids) // self.page_size
        for logical in range(start_logical_page, n_full):
            self._pool.record(pages[logical], prompt_ids, logical)

    def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
        """Return the table's page-table tensor for ``task_ids`` on ``device``."""
        return self._table.table_tensor(task_ids, device)

    def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
        """Wrap the storage, ``page_table`` and ``total_len`` in a ``KvcacheView``."""
        return KvcacheView(self._storage, page_table, total_len)
|
|
||||||
|
|
@ -1,94 +0,0 @@
|
||||||
import logging
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from astrai.inference.core.cache import KVCache
|
|
||||||
from astrai.inference.core.task import Task
|
|
||||||
from astrai.inference.sample import sample
|
|
||||||
from astrai.model.automodel import AutoModel
|
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Executor:
    """Model forward passes for prefill and decode phases.

    The executor owns no scheduling logic: it turns a list of tasks into
    input tensors, binds the matching page tables through ``page_cache`` and
    calls the model.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        page_cache: KVCache,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        # Fall back to the device/dtype of the model's first parameter.
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
        """Run the model over prompt positions ``start_pos..prompt_len-1``.

        Args:
            tasks: Tasks to prefill together. Their ``prompt_ids`` slices must
                all have the same length so they stack into one tensor.
            prompt_len: Exclusive end position of the slice; also passed as
                ``total_len`` when binding the cache.
            start_pos: First position to process.

        Nothing is run for an empty batch or when ``start_pos`` already
        reaches ``prompt_len``. The model output is discarded.
        """
        if not tasks or start_pos >= prompt_len:
            return

        # Sorted by task id so rows follow a deterministic order.
        tasks = sorted(tasks, key=lambda t: t.task_id)
        batch_sz = len(tasks)

        input_ids = torch.tensor(
            [t.prompt_ids[start_pos:prompt_len] for t in tasks],
            dtype=torch.long,
            device=self.device,
        )

        task_ids = [t.task_id for t in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)

        with torch.inference_mode():
            position_ids = (
                torch.arange(
                    start_pos, prompt_len, dtype=torch.long, device=self.device
                )
                .unsqueeze(0)
                .expand(batch_sz, -1)
            )
            self.model(
                input_ids,
                position_ids=position_ids,
                paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
            )

    def execute_decode(self, tasks: List[Task]) -> List[int]:
        """Run one decode step and sample one token per task.

        Each task contributes its most recent output token (or the last
        prompt token when it has produced nothing yet) at position
        ``task.next_pos``.

        Returns:
            The sampled token ids, in the same order as ``tasks``; an empty
            list for an empty batch.
        """
        if not tasks:
            return []

        last_tokens = [
            t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks
        ]
        positions = [t.next_pos for t in tasks]
        # Computed from the host-side list: reading the maximum back from the
        # device tensor would force a synchronisation on every decode step.
        total_len = max(positions) + 1

        input_ids = torch.tensor(last_tokens, dtype=torch.long, device=self.device)
        position_ids = torch.tensor(positions, dtype=torch.long, device=self.device)

        task_ids = [t.task_id for t in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)

        # Per-task sampling parameters, one entry per batch row.
        temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
        top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
        top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                position_ids=position_ids.unsqueeze(1),
            )
            # Only the last step's logits are sampled from.
            logits = outputs["logits"][:, -1, :]

        return sample(
            logits,
            temperature=temperatures,
            top_k=top_ks,
            top_p=top_ps,
        ).tolist()
|
|
||||||
|
|
@ -1,212 +0,0 @@
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from astrai.inference.core.cache import KVCache
|
|
||||||
from astrai.inference.core.executor import Executor
|
|
||||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
|
||||||
from astrai.model.automodel import AutoModel
|
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceScheduler:
    """Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode.

    The loop runs on a daemon thread started by ``start``. Each iteration
    releases finished tasks, admits waiting tasks while batch slots and cache
    pages are available, prefills prompts that are not fully cached and then
    decodes one token for the largest group of tasks at the same position.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 2048,
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Build the page cache, task manager and executor.

        Raises:
            ValueError: If neither ``max_seq_len`` nor ``model.config.max_len``
                is set.
        """
        config = model.config

        if max_seq_len is not None:
            self.max_seq_len = max_seq_len
        elif config.max_len is not None:
            self.max_seq_len = config.max_len
        else:
            raise ValueError(
                "max_seq_len must be provided either as argument "
                "or in model config (config.max_len)"
            )
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

        # Enough pages for max_batch_size sequences of max_seq_len tokens,
        # plus one spare page per sequence, rounded up.
        n_pages = (
            max_batch_size * (self.max_seq_len + page_size) + page_size - 1
        ) // page_size

        self._page_cache = KVCache(
            config.n_layers,
            n_pages,
            page_size,
            config.n_kv_heads,
            config.dim // config.n_heads,
            self.device,
            self.dtype,
        )

        self._task_mgr = TaskManager(
            tokenizer=tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=self.max_seq_len,
            max_prompt_len=max_prompt_len,
        )

        self._executor = Executor(
            model=model,
            tokenizer=tokenizer,
            page_cache=self._page_cache,
            device=self.device,
            dtype=self.dtype,
        )

        self._running = False
        self._fatal_error: Optional[Exception] = None
        # Set by start(); kept as an attribute from construction so stop()
        # does not have to probe for its existence.
        self._loop_thread: Optional[threading.Thread] = None

    def add_task(self, prompt: str, **kwargs) -> str:
        """Queue a prompt via the task manager and return its task id."""
        return self._task_mgr.add_task(prompt, **kwargs)

    def remove_task(self, task_id: str):
        """Remove a task and free the cache pages of any active instance of it."""
        for task in self._task_mgr.remove_task(task_id):
            self._page_cache.task_free(task.task_id)

    def get_stats(self) -> Dict[str, Any]:
        """Return the task manager's counters."""
        return self._task_mgr.get_stats()

    def _run_generation_loop(self):
        """Run the batching loop until ``_running`` is cleared or a phase raises.

        Any exception is stored in ``_fatal_error`` and logged; every active
        and waiting task is then told to stop and the queues are emptied.
        """
        stop_ids = self._task_mgr.tokenizer.stop_ids
        try:
            while self._running:
                self._release_finished(stop_ids)
                self._admit_waiting()

                if not self._task_mgr.has_work():
                    self._task_mgr.wait_for_tasks(timeout=1.0)
                    continue

                self._prefill_pending()
                self._decode_step(stop_ids)

        except Exception as e:
            self._fatal_error = e
            self._running = False
            logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
            self._abort_all_tasks()

    def _release_finished(self, stop_ids) -> None:
        """Cleanup phase: free the pages of tasks that finished or were aborted."""
        for task in self._task_mgr.remove_finished_tasks(stop_ids):
            self._page_cache.task_free(task.task_id)

    def _admit_waiting(self) -> None:
        """Refill phase: activate waiting tasks while slots and pages allow."""
        active = self._task_mgr.get_active_tasks()
        available = self._task_mgr.max_batch_size - len(active)
        if available <= 0:
            return

        rejected: List[Task] = []
        for task in self._task_mgr.pull_candidates(available):
            if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
                # Record the prompt length on admission. A prompt that is
                # served entirely from cached pages never enters the prefill
                # phase, and would otherwise start decoding with
                # next_pos == 0.
                task.input_tokens = len(task.prompt_ids)
                self._task_mgr.activate(task)
            else:
                rejected.append(task)
        if rejected:
            # Pages were not available; requeue them in their original order.
            self._task_mgr.return_to_waiting(rejected)

    def _prefill_pending(self) -> None:
        """Prefill phase: run prompts that have no output and are not fully cached."""
        groups: Dict[Tuple[int, int], List[Task]] = {}
        for t in self._task_mgr.get_active_tasks():
            if t.output_tokens != 0:
                continue
            cached = self._page_cache.task_cached(t.task_id)
            if cached >= len(t.prompt_ids):
                continue
            t.input_tokens = len(t.prompt_ids)
            # Tasks are batched only with others sharing both the prompt
            # length and the number of already-cached tokens.
            groups.setdefault((len(t.prompt_ids), cached), []).append(t)

        for (prompt_len, start_pos), group in groups.items():
            self._executor.execute_prefill(group, prompt_len, start_pos)
            start_logical_page = start_pos // self._page_cache.page_size
            for t in group:
                self._page_cache.task_record_hashes(
                    t.task_id,
                    t.prompt_ids,
                    start_logical_page=start_logical_page,
                )

    def _decode_step(self, stop_ids) -> None:
        """Decode phase: sample one token for the largest same-position group."""
        pos_groups: Dict[int, List[Task]] = {}
        for t in self._task_mgr.get_active_tasks():
            pos_groups.setdefault(t.next_pos, []).append(t)
        if not pos_groups:
            return

        best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
        group = sorted(pos_groups[best_key], key=lambda t: t.task_id)

        # Make sure every task has a page for the position about to be written.
        valid: List[Task] = []
        for t in group:
            if self._page_cache.task_extend(t.task_id, t.next_pos):
                valid.append(t)
            else:
                self._abort(t)
        if not valid:
            return

        next_tokens = self._executor.execute_decode(valid)

        for t, ntok in zip(valid, next_tokens):
            t.output_ids.append(ntok)
            t.output_tokens += 1
            pos = t.input_tokens + t.output_tokens
            extend_ok = self._page_cache.task_extend(t.task_id, pos)
            # The token is delivered before a possible abort notification.
            if t.stream_callback:
                t.stream_callback(self._task_mgr.tokenizer.decode([ntok]))
            if not extend_ok:
                self._abort(t)

        for t in valid:
            # Aborted tasks were already sent STOP; do not send it twice.
            if t.status is TaskStatus.ABORTED:
                continue
            if t.is_finished(stop_ids) and t.stream_callback:
                t.stream_callback(STOP)

    @staticmethod
    def _abort(task: Task) -> None:
        """Mark ``task`` as aborted and notify its stream callback, if any."""
        task.status = TaskStatus.ABORTED
        if task.stream_callback:
            task.stream_callback(STOP)

    def _abort_all_tasks(self) -> None:
        """Send STOP to every task, free active tasks' pages and clear the queues."""
        for task in self._task_mgr.get_active_tasks():
            if task.stream_callback:
                task.stream_callback(STOP)
            self._page_cache.task_free(task.task_id)
        for task in self._task_mgr.get_waiting_tasks():
            if task.stream_callback:
                task.stream_callback(STOP)
        self._task_mgr.clear_queues()

    def start(self):
        """Start the generation loop on a daemon thread if it is not running."""
        if not self._running:
            self._running = True
            t = threading.Thread(target=self._run_generation_loop, daemon=True)
            t.start()
            self._loop_thread = t

    def stop(self):
        """Stop the loop, notify and release all tasks, and empty the CUDA cache.

        The loop thread is given up to two seconds to exit before the
        remaining tasks are torn down.
        """
        self._running = False
        self._task_mgr.wake()
        thread = getattr(self, "_loop_thread", None)
        if thread is not None:
            thread.join(timeout=2.0)
        self._abort_all_tasks()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
||||||
|
|
@ -1,209 +0,0 @@
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from collections import deque
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Callable, Deque, Dict, List, Optional
|
|
||||||
|
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Sentinel passed to a task's ``stream_callback`` in place of a token to
# signal that no further tokens will follow. Consumers compare it by
# identity (``token is STOP``).
STOP = object()
|
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(Enum):
    """States a generation task passes through during its lifetime."""

    # Initial state assigned when a Task object is created.
    PENDING = "pending"
    # Set when the task manager activates the task.
    RUNNING = "running"
    # Set once the task reports itself finished (token budget or stop token).
    FINISHED = "finished"
    # Set by the scheduler when the cache cannot be extended for the task.
    ABORTED = "aborted"
|
|
||||||
|
|
||||||
|
|
||||||
class Task:
    """One generation request: its prompt, sampling settings and output so far."""

    def __init__(
        self,
        task_id: str,
        prompt_ids: List[int],
        max_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        stream_callback: Optional[Callable[[str], None]] = None,
    ):
        # Request as submitted.
        self.task_id = task_id
        self.prompt_ids = prompt_ids
        self.stream_callback = stream_callback

        # Sampling configuration.
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p

        # Progress; the counters start at zero and are updated externally.
        self.status = TaskStatus.PENDING
        self.input_tokens: int = 0
        self.output_tokens: int = 0
        self.output_ids: List[int] = []

        # Wall-clock bookkeeping.
        self.arrival_time = time.time()
        self.finish_time: Optional[float] = None

    @property
    def next_pos(self) -> int:
        """Position of the next token: input tokens plus tokens generated so far."""
        return len(self.output_ids) + self.input_tokens

    def is_finished(self, stop_ids: List[int]) -> bool:
        """Tell whether generation should end for this task.

        True when ``max_tokens`` is set and has been reached, or when the
        most recent output token is one of ``stop_ids``.
        """
        budget = self.max_tokens
        if budget is not None and self.output_tokens >= budget:
            return True
        return bool(self.output_ids) and self.output_ids[-1] in stop_ids
|
|
||||||
|
|
||||||
|
|
||||||
class TaskManager:
    """Thread-safe task queues and lifecycle transitions (no page ops).

    Holds a FIFO of waiting tasks and a list of active tasks. All access to
    either container goes through ``_lock``; ``_task_event`` lets the
    scheduler sleep until new work arrives.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 16,
        max_seq_len: int = 8192,
        max_prompt_len: int = 512,
    ):
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.max_prompt_len = max_prompt_len

        self.waiting_queue: Deque[Task] = deque()
        self.active_tasks: List[Task] = []

        self._task_event = threading.Event()
        self._lock = threading.Lock()

        # Running totals reported by get_stats().
        self._total_tasks = 0
        self._total_tokens = 0

    def add_task(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        stream_callback: Optional[Callable[[str], None]] = None,
    ) -> str:
        """Tokenize ``prompt`` and queue a task for it.

        Prompts longer than ``max_prompt_len`` keep only their last
        ``max_prompt_len`` tokens. ``max_tokens`` is capped so that prompt
        plus output never exceeds ``max_seq_len``.

        A prompt that encodes to no tokens, or that leaves no room for
        output, is not queued: the callback (if any) receives ``STOP``
        immediately.

        Returns:
            The generated task id, also when the task was rejected.
        """
        task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        prompt_ids = self.tokenizer.encode(prompt)
        if len(prompt_ids) > self.max_prompt_len:
            prompt_ids = prompt_ids[-self.max_prompt_len :]

        # An empty prompt has no last token to start decoding from; queuing
        # it would make the decode step fail on prompt_ids[-1].
        if len(prompt_ids) == 0 or len(prompt_ids) >= self.max_seq_len:
            if stream_callback:
                stream_callback(STOP)
            return task_id

        budget = self.max_seq_len - len(prompt_ids)
        max_tokens = budget if max_tokens is None else min(max_tokens, budget)

        task = Task(
            task_id=task_id,
            prompt_ids=prompt_ids,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream_callback=stream_callback,
        )

        with self._lock:
            self.waiting_queue.append(task)
            self._total_tasks += 1

        # Wake the scheduler if it is blocked in wait_for_tasks().
        self._task_event.set()
        return task_id

    def remove_task(self, task_id: str) -> List[Task]:
        """Drop ``task_id`` from both containers.

        Returns:
            The tasks that were removed from the active list (waiting tasks
            are dropped but not returned).
        """
        with self._lock:
            removed_active = [t for t in self.active_tasks if t.task_id == task_id]
            self.waiting_queue = deque(
                t for t in self.waiting_queue if t.task_id != task_id
            )
            self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
            return removed_active

    def get_stats(self) -> Dict[str, Any]:
        """Return a consistent snapshot of the counters and queue sizes."""
        # Taken under the lock so the four values belong to the same moment.
        with self._lock:
            return {
                "total_tasks": self._total_tasks,
                "total_tokens": self._total_tokens,
                "active_tasks": len(self.active_tasks),
                "waiting_queue": len(self.waiting_queue),
            }

    def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
        """Remove aborted and finished tasks from the active list.

        Tasks already marked ABORTED are removed as they are. Tasks whose
        ``is_finished`` is true are marked FINISHED and their output tokens
        are added to the running total.

        Returns:
            The removed tasks, each with ``finish_time`` set.
        """
        with self._lock:
            finished = []
            for task in self.active_tasks:
                if task.status == TaskStatus.ABORTED:
                    task.finish_time = time.time()
                    finished.append(task)
                elif task.is_finished(stop_ids):
                    task.status = TaskStatus.FINISHED
                    task.finish_time = time.time()
                    finished.append(task)
                    self._total_tokens += task.output_tokens

            self.active_tasks = [
                t
                for t in self.active_tasks
                if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
            ]
            return finished

    def pull_candidates(self, n: int) -> List[Task]:
        """Pop up to ``n`` tasks from the front of the waiting queue."""
        with self._lock:
            take = min(n, len(self.waiting_queue))
            return [self.waiting_queue.popleft() for _ in range(take)]

    def activate(self, task: Task):
        """Mark ``task`` as RUNNING and add it to the active list."""
        task.status = TaskStatus.RUNNING
        with self._lock:
            self.active_tasks.append(task)

    def return_to_waiting(self, tasks: List[Task]):
        """Put ``tasks`` back at the front of the waiting queue, keeping their order."""
        with self._lock:
            # Reversed so that repeated appendleft restores the given order.
            for task in reversed(tasks):
                self.waiting_queue.appendleft(task)

    def has_work(self) -> bool:
        """Tell whether any task is active or waiting."""
        with self._lock:
            return bool(self.active_tasks or self.waiting_queue)

    def wait_for_tasks(self, timeout: float = 1.0):
        """Block until a task is added, ``wake`` is called or ``timeout`` elapses.

        Returns immediately when there is already work queued.
        """
        with self._lock:
            if self.waiting_queue or self.active_tasks:
                return
            # Cleared under the lock; add_task sets the event only after it
            # has released the lock, so a concurrent add is not lost.
            self._task_event.clear()
        self._task_event.wait(timeout=timeout)

    def get_active_tasks(self) -> List[Task]:
        """Return a copy of the active task list."""
        with self._lock:
            return list(self.active_tasks)

    def get_waiting_tasks(self) -> List[Task]:
        """Return a copy of the waiting queue as a list."""
        with self._lock:
            return list(self.waiting_queue)

    def clear_queues(self):
        """Empty both the waiting queue and the active list."""
        with self._lock:
            self.waiting_queue.clear()
            self.active_tasks.clear()

    def wake(self):
        """Set the task event so a blocked ``wait_for_tasks`` returns."""
        self._task_event.set()
|
|
||||||
|
|
@ -1,288 +0,0 @@
|
||||||
"""Unified inference engine for continuous batching."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import gc
|
|
||||||
import threading
|
|
||||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from astrai.inference.core.scheduler import InferenceScheduler
|
|
||||||
from astrai.inference.core.task import STOP
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class GenerateResult:
    """Collects streamed tokens for ``count`` prompts in a thread-safe way.

    Producers call ``append``; streaming consumers drain with ``pop_all`` and
    sleep in ``wait``; non-streaming consumers block in ``wait_completion``
    and then read ``get_results``.
    """

    def __init__(self, count: int = 1):
        self._total = count
        self._completed = 0
        self._done: List[bool] = [False] * count
        # Concatenated text per prompt index.
        self.results: List[str] = [""] * count
        # Pending (index, token) pairs not yet drained by pop_all().
        self.tokens: List[Tuple[int, str]] = []
        self._event = threading.Event()
        self._cond = threading.Condition()

    def append(self, token: str, idx: int = 0):
        """Queue ``token`` for prompt ``idx``; ``STOP`` marks that prompt complete."""
        with self._cond:
            self.tokens.append((idx, token))
            if token is STOP:
                # Only the first STOP per prompt counts towards completion.
                if not self._done[idx]:
                    self._done[idx] = True
                    self._completed += 1
                    self._cond.notify_all()
            else:
                self.results[idx] += token
            self._event.set()

    def pop_all(self) -> List[Tuple[int, str]]:
        """Return and remove all pending pairs; reset the event when there were none."""
        with self._cond:
            drained = list(self.tokens)
            self.tokens.clear()
            if len(drained) == 0:
                self._event.clear()
            return drained

    def wait(self, timeout: Optional[float] = None) -> bool:
        """Wait on the internal event; returns the event's ``wait`` result."""
        return self._event.wait(timeout=timeout)

    def wait_completion(self, timeout: float = 300.0):
        """Block until every prompt has received ``STOP``.

        Raises:
            TimeoutError: If not all prompts completed within ``timeout`` seconds.
        """
        with self._cond:
            all_done = self._cond.wait_for(
                lambda: self._completed >= self._total, timeout=timeout
            )
            if all_done:
                return
            raise TimeoutError(
                f"Generation timeout after {timeout}s "
                f"({self._completed}/{self._total} completed)"
            )

    def get_results(self) -> List[str]:
        """Return a copy of the accumulated text, one entry per prompt."""
        with self._cond:
            return list(self.results)
|
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
    """Validated bundle of chat messages and sampling parameters."""

    def __init__(
        self,
        messages: List[Dict[str, str]],
        top_k: int = 50,
        top_p: float = 1.0,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        stream: bool = False,
    ):
        """Store the request after checking the sampling parameters.

        Raises:
            ValueError: If ``top_k`` is not a non-negative int, ``top_p`` is
                outside [0.0, 1.0], or ``temperature`` is not a positive number.
        """
        if not isinstance(top_k, int) or top_k < 0:
            raise ValueError("top_k must be a non-negative integer")

        top_p_in_range = 0.0 <= top_p <= 1.0
        if not top_p_in_range:
            raise ValueError("top_p must be a float between 0.0 and 1.0")

        # Written as "not (> 0)" so that NaN is rejected as well.
        if not isinstance(temperature, (int, float)) or not (temperature > 0):
            raise ValueError("temperature must be a positive number")

        self.messages = messages
        self.stream = stream
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceEngine:
    """Unified inference engine backed by continuous-batching scheduler.

    Constructing the engine builds an ``InferenceScheduler`` and starts its
    loop immediately. Use the engine as a context manager, or call
    ``shutdown``, to stop it.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 1,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 2048,
        page_size: int = 128,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.scheduler = InferenceScheduler(
            model=self.model,
            tokenizer=self.tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            max_prompt_len=max_prompt_len,
            page_size=page_size,
        )

        self.scheduler.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()
        # Never suppress an exception raised inside the with-block.
        return False

    def generate(
        self,
        prompt: Union[str, List[str]],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> Union[Generator, str, List[str]]:
        """Generate text for one prompt or a list of prompts.

        Returns:
            With ``stream=True`` a generator that yields tokens for a single
            prompt, or ``(index, token)`` pairs for a list of prompts.
            Otherwise the finished text: a string for a single prompt, a
            list of strings for a list of prompts.
        """
        is_batch = isinstance(prompt, list)
        prompts = prompt if is_batch else [prompt]

        if stream:
            return self._generate_streaming(
                prompts, is_batch, max_tokens, temperature, top_p, top_k
            )
        return self._generate_non_streaming(
            prompts, is_batch, max_tokens, temperature, top_p, top_k
        )

    def generate_async(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> AsyncGenerator[str, None]:
        """Return an async generator streaming tokens for a single prompt.

        The task is submitted when this method is called. Each token is
        fetched from the synchronous stream in the event loop's default
        executor so the loop is not blocked while waiting.
        """
        sync_gen = self._generate_streaming(
            [prompt], False, max_tokens, temperature, top_p, top_k
        )

        async def _agen():
            loop = asyncio.get_event_loop()
            while True:
                token = await loop.run_in_executor(None, self._next_token, sync_gen)
                if token is None:
                    break
                yield token

        return _agen()

    @staticmethod
    def _next_token(gen: Generator) -> Optional[str]:
        """Return the next item of ``gen``, or None once it is exhausted."""
        # StopIteration is translated to None here rather than being raised
        # through the executor call.
        try:
            return next(gen)
        except StopIteration:
            return None

    def generate_with_request(
        self, request: GenerationRequest
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Render ``request.messages`` with the chat template and generate."""
        prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
        return self.generate(
            prompt=prompt,
            stream=request.stream,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )

    def _submit_tasks(
        self,
        prompts: List[str],
        max_tokens: Optional[int],
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Tuple[GenerateResult, List[str]]:
        """Queue one scheduler task per prompt, all feeding a shared result.

        Returns:
            The shared ``GenerateResult`` and the task ids in prompt order.
        """
        n = len(prompts)
        result = GenerateResult(count=n)
        task_ids = []
        for i, p in enumerate(prompts):
            cb = self._make_callback(result, i)
            task_id = self.scheduler.add_task(
                prompt=p,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stream_callback=cb,
            )
            task_ids.append(task_id)
        return result, task_ids

    @staticmethod
    def _make_callback(result: GenerateResult, idx: int):
        """Build a stream callback that files tokens under prompt ``idx``."""

        # A factory is used so each callback captures its own idx.
        def cb(token):
            result.append(token, idx)

        return cb

    def _generate_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: Optional[int],
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Generator:
        """Submit the prompts and return a generator over their tokens.

        Tasks are submitted right away, before the generator is first
        advanced. When the generator ends or is closed after it has started,
        every task is removed from the scheduler.
        """
        result, task_ids = self._submit_tasks(
            prompts, max_tokens, temperature, top_p, top_k
        )
        n = len(prompts)
        remaining = n
        finished = [False] * n

        def gen():
            nonlocal remaining
            try:
                while remaining > 0:
                    items = result.pop_all()
                    for idx, token in items:
                        if token is STOP:
                            # Count each prompt's completion only once.
                            if not finished[idx]:
                                finished[idx] = True
                                remaining -= 1
                        else:
                            yield (idx, token) if is_batch else token
                    if remaining > 0:
                        result.wait(timeout=0.05)
            finally:
                for tid in task_ids:
                    self.scheduler.remove_task(tid)

        return gen()

    def _generate_non_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: Optional[int],
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Union[str, List[str]]:
        """Submit the prompts, wait for all of them and return the text.

        Raises:
            TimeoutError: Propagated from ``GenerateResult.wait_completion``.
        """
        result, task_ids = self._submit_tasks(
            prompts, max_tokens, temperature, top_p, top_k
        )

        # Tasks are removed however the wait ends, so an interrupted call
        # does not leave them running in the scheduler.
        try:
            result.wait_completion()
        finally:
            for tid in task_ids:
                self.scheduler.remove_task(tid)

        res = result.get_results()
        return res if is_batch else res[0]

    def get_stats(self) -> Dict[str, Any]:
        """Return the scheduler's statistics."""
        return self.scheduler.get_stats()

    def shutdown(self):
        """Stop the scheduler and release memory held by the process."""
        self.scheduler.stop()
        # Collect garbage first so that memory it frees can be returned to
        # the device by the cache flush that follows.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
@ -1,190 +0,0 @@
|
||||||
"""Composable sampling strategies for logit transformation.
|
|
||||||
|
|
||||||
Implements the Strategy pattern: each sampling technique
|
|
||||||
(temperature, top-k, top-p) is a pluggable strategy that
|
|
||||||
can be composed into a pipeline.
|
|
||||||
|
|
||||||
All strategies accept both scalar and per-sample tensor
|
|
||||||
parameters, so a single pipeline works for any batch size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSamplingStrategy(ABC):
    """Interface shared by all logit transformation strategies."""

    @abstractmethod
    def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Transform ``logits`` according to this strategy.

        Args:
            logits: Raw logits of shape (batch, vocab_size).
            filter_value: Value written to positions the strategy filters out.

        Returns:
            The transformed logits.
        """
|
|
||||||
|
|
||||||
|
|
||||||
class TemperatureStrategy(BaseSamplingStrategy):
    """Scales logits by the inverse of a temperature.

    Args:
        temperature: Scalar or ``[batch]`` tensor.
    """

    def __init__(self, temperature: Union[float, Tensor] = 1.0):
        self.temperature = temperature

    def apply(self, logits, filter_value=-float("inf")):
        """Return ``logits`` divided by the temperature.

        Temperatures are floored at 1e-8 before dividing. When every
        temperature equals 1.0 the input tensor is returned as is.
        ``filter_value`` is not used by this strategy.
        """
        temperature = self.temperature

        if not isinstance(temperature, Tensor):
            # Scalar temperature: one divisor for the whole batch.
            if temperature != 1.0:
                return logits / max(temperature, 1e-8)
            return logits

        # Per-row temperatures as a column so they broadcast over the vocab axis.
        per_row = temperature.to(logits.device, non_blocking=True).view(-1, 1)
        per_row = torch.clamp(per_row, min=1e-8)
        if (per_row != 1.0).any():
            return logits / per_row
        return logits
|
|
||||||
|
|
||||||
|
|
||||||
class TopKStrategy(BaseSamplingStrategy):
    """Keeps only the top-k logits, setting the rest to ``filter_value``.

    The input tensor is never modified; a filtered copy is returned whenever
    filtering actually happens (the input itself is returned when top-k is
    disabled).

    Args:
        top_k: Scalar or ``[batch]`` tensor (0 disables).
    """

    def __init__(self, top_k: Union[int, Tensor] = 0):
        self.top_k = top_k

    def apply(self, logits, filter_value=-float("inf")):
        """Filter ``logits`` down to the k largest entries of each row.

        Args:
            logits: Logits tensor (batch, vocab_size).
            filter_value: Value assigned to entries below the k-th largest.

        Returns:
            Filtered logits; entries tied with the k-th largest are kept.
        """
        tk = self.top_k
        vocab_size = logits.size(-1)

        if not isinstance(tk, Tensor):
            # Scalar path: one k shared by every row.
            if tk <= 0:
                return logits
            k = min(tk, vocab_size)
            thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
            # Out-of-place so the caller's tensor is left intact.
            return logits.masked_fill(logits < thresholds, filter_value)

        # Per-sample path: negative entries are treated as 0 (disabled).
        tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
        max_k = int(tk.max().item())
        if max_k <= 0:
            return logits
        max_k = min(max_k, vocab_size)

        # One top-k call with the largest k; each row then picks its own cut.
        values, _ = torch.topk(logits, max_k, dim=-1)
        per_row_k = tk.clamp(max=max_k)

        # Rows with k == 0 keep a -inf threshold, i.e. nothing is filtered.
        thresholds = torch.full_like(logits[..., -1:], -float("inf"))
        positive = per_row_k > 0
        if positive.any():
            row_idx = torch.arange(logits.size(0), device=logits.device)[positive]
            thresholds[positive] = values[
                row_idx, per_row_k[positive] - 1
            ].unsqueeze(-1)

        return logits.masked_fill(logits < thresholds, filter_value)
|
|
||||||
|
|
||||||
|
|
||||||
class TopPStrategy(BaseSamplingStrategy):
    """Nucleus (top-p) filtering: keeps the smallest set of tokens whose
    cumulative probability exceeds top_p.

    The input tensor is never modified; a filtered copy is returned whenever
    filtering actually happens.

    Args:
        top_p: Scalar or ``[batch]`` tensor (1.0 disables).
    """

    def __init__(self, top_p: Union[float, Tensor] = 1.0):
        self.top_p = top_p

    def _apply(self, logits, top_p, filter_value):
        """Mask every token outside the nucleus with ``filter_value``.

        ``top_p`` is either a scalar or a tensor that broadcasts against the
        sorted cumulative probabilities (``apply`` passes ``[batch, 1]``).
        """
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        remove = cum_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and always keep the most likely token.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False

        # Map the decision from sorted order back to vocabulary order, along
        # the same (last) dimension that was sorted.
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask.scatter_(-1, sorted_indices, remove)

        # Out-of-place so the caller's tensor is left intact.
        return logits.masked_fill(mask, filter_value)

    def apply(self, logits, filter_value=-float("inf")):
        """Apply nucleus filtering unless ``top_p`` disables it.

        Args:
            logits: Logits tensor (batch, vocab_size).
            filter_value: Value assigned to tokens outside the nucleus.

        Returns:
            Filtered logits, or the input itself when no row has
            ``top_p < 1.0``.
        """
        tp = self.top_p
        if isinstance(tp, Tensor):
            tp = tp.to(logits.device, non_blocking=True)
            if (tp < 1.0).any():
                logits = self._apply(logits, tp.view(-1, 1), filter_value)
        elif tp < 1.0:
            logits = self._apply(logits, tp, filter_value)
        return logits
|
|
||||||
|
|
||||||
|
|
||||||
class SamplingPipeline(BaseSamplingStrategy):
    """Chains several sampling strategies into one transformation.

    The strategies run one after another in the order given, so the usual
    temperature -> top-k -> top-p ordering is obtained by listing them
    that way.

    Usage::

        pipeline = SamplingPipeline([
            TemperatureStrategy(0.8),
            TopKStrategy(50),
            TopPStrategy(0.95),
        ])
        logits = pipeline.apply(logits)
        token = pipeline.sample(logits)  # softmax + multinomial
    """

    def __init__(self, strategies: List[BaseSamplingStrategy]):
        self.strategies = strategies

    def apply(self, logits, filter_value=-float("inf")):
        """Feed ``logits`` through every strategy and return the final result."""
        current = logits
        for step in self.strategies:
            current = step.apply(current, filter_value)
        return current

    @torch.no_grad()
    def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Run the strategies, then draw one token per row.

        Args:
            logits: Raw logits ``[batch, vocab_size]``.
            filter_value: Value the strategies assign to filtered positions.

        Returns:
            Sampled token IDs ``[batch]``.
        """
        filtered = self.apply(logits, filter_value)
        probs = torch.softmax(filtered, dim=-1)
        drawn = torch.multinomial(probs, num_samples=1)
        return drawn.squeeze(-1)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
def sample(
    logits: Tensor,
    temperature: Union[float, Tensor] = 1.0,
    top_k: Union[int, Tensor] = 0,
    top_p: Union[float, Tensor] = 1.0,
    filter_value: float = -float("inf"),
) -> Tensor:
    """Filter logits with temperature, top-k and top-p, then sample.

    Convenience wrapper that builds a one-off ``SamplingPipeline`` with the
    three strategies (in that order) and calls its ``sample`` method.

    Args:
        logits: Raw logits ``[batch, vocab_size]``.
        temperature: Scalar or per-sample temperature.
        top_k: Scalar or per-sample k.
        top_p: Scalar or per-sample nucleus mass.
        filter_value: Value assigned to filtered-out positions.

    Returns:
        Sampled token IDs ``[batch]``.
    """
    steps = [
        TemperatureStrategy(temperature),
        TopKStrategy(top_k),
        TopPStrategy(top_p),
    ]
    pipeline = SamplingPipeline(steps)
    return pipeline.sample(logits, filter_value)
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
from astrai.model.automodel import AutoModel
|
|
||||||
from astrai.model.components.attention import GQA
|
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
from astrai.model.components.lora import (
|
|
||||||
LoRAConfig,
|
|
||||||
inject_lora,
|
|
||||||
load_lora,
|
|
||||||
merge_lora,
|
|
||||||
save_lora,
|
|
||||||
)
|
|
||||||
from astrai.model.components.mlp import MLP
|
|
||||||
from astrai.model.components.norm import RMSNorm
|
|
||||||
from astrai.model.encoder import EmbeddingEncoder
|
|
||||||
from astrai.model.transformer import AutoRegressiveLM
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# Modules
|
|
||||||
"Linear",
|
|
||||||
"RMSNorm",
|
|
||||||
"MLP",
|
|
||||||
"GQA",
|
|
||||||
"DecoderBlock",
|
|
||||||
# Models
|
|
||||||
"AutoRegressiveLM",
|
|
||||||
"EmbeddingEncoder",
|
|
||||||
"AutoModel",
|
|
||||||
# LoRA
|
|
||||||
"LoRAConfig",
|
|
||||||
"inject_lora",
|
|
||||||
"merge_lora",
|
|
||||||
"save_lora",
|
|
||||||
"load_lora",
|
|
||||||
]
|
|
||||||
|
|
@ -1,95 +0,0 @@
|
||||||
"""
|
|
||||||
AutoModel base class for model loading and saving.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Self, Union
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
from astrai.serialization import load_model_config, load_model_weights, save_model
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _disable_random_init(enable: bool = True):
|
|
||||||
if not enable:
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
names = (
|
|
||||||
"xavier_normal_",
|
|
||||||
"xavier_uniform_",
|
|
||||||
"kaiming_normal_",
|
|
||||||
"kaiming_uniform_",
|
|
||||||
"zeros_",
|
|
||||||
"ones_",
|
|
||||||
"constant_",
|
|
||||||
"normal_",
|
|
||||||
"uniform_",
|
|
||||||
)
|
|
||||||
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
|
||||||
for n in orig:
|
|
||||||
setattr(nn.init, n, lambda *a, **kw: None)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
for n, fn in orig.items():
|
|
||||||
setattr(nn.init, n, fn)
|
|
||||||
|
|
||||||
|
|
||||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
    """Base model class with registry lookup plus load/save helpers.

    ``from_pretrained`` reads a config, resolves the concrete class through
    the factory registry and loads weights; ``save_pretrained`` writes the
    config and state dict back out.
    """

    def __init__(self, config: BaseModelConfig):
        super().__init__()
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        path: Union[str, Path],
        disable_random_init: bool = True,
        strict: bool = True,
    ) -> nn.Module:
        """Build a model from a directory containing ``config.json``.

        Args:
            path: Model directory.
            disable_random_init: Construct the model with the common
                ``nn.init`` functions turned into no-ops (see
                ``_disable_random_init``).
            strict: Forwarded to ``load_state_dict``.

        Returns:
            The instantiated model, with weights loaded when
            ``model.safetensors`` exists in ``path``.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
        """
        import warnings

        model_path = Path(path)

        config_path = model_path / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        raw = load_model_config(str(model_path))
        config = ConfigFactory.load(raw)
        model_type = config.model_type or "autoregressive_lm"

        # Always resolve through the AutoModel registry, regardless of which
        # subclass this classmethod was invoked on.
        actual_cls = AutoModel.get_component_class(model_type)

        with _disable_random_init(enable=disable_random_init):
            model = actual_cls(config)

        weights_path = model_path / "model.safetensors"
        if weights_path.exists():
            state_dict = load_model_weights(str(model_path))
            model.load_state_dict(state_dict, strict=strict)
        elif disable_random_init:
            # Initialisers were skipped and nothing was loaded, so parameters
            # may hold whatever their allocation left behind.
            warnings.warn(
                f"No weights found at {weights_path}; the model was built with "
                "random initialisation disabled, so its parameters may be "
                "uninitialized.",
                stacklevel=2,
            )

        return model

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
    ):
        """Write this model's config and state dict to ``save_directory``."""
        save_model(
            config=self.config.to_dict(),
            state_dict=self.state_dict(),
            save_directory=str(save_directory),
        )

    def to(self, *args, **kwargs) -> Self:
        """Move model to device/dtype."""
        return super().to(*args, **kwargs)
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
from astrai.model.components.attention import GQA, MLA, repeat_kv
|
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
|
||||||
from astrai.model.components.embedding import Embedding
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
from astrai.model.components.mlp import MLP
|
|
||||||
from astrai.model.components.norm import RMSNorm
|
|
||||||
from astrai.model.components.rope import (
|
|
||||||
RotaryEmbedding,
|
|
||||||
apply_rotary_emb,
|
|
||||||
get_rotary_emb,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Linear",
|
|
||||||
"RMSNorm",
|
|
||||||
"MLP",
|
|
||||||
"Embedding",
|
|
||||||
"GQA",
|
|
||||||
"MLA",
|
|
||||||
"DecoderBlock",
|
|
||||||
"RotaryEmbedding",
|
|
||||||
"apply_rotary_emb",
|
|
||||||
"get_rotary_emb",
|
|
||||||
"repeat_kv",
|
|
||||||
]
|
|
||||||
|
|
@ -1,212 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
from astrai.inference.core.cache import KvcacheView
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
from astrai.model.components.norm import RMSNorm
|
|
||||||
from astrai.model.components.rope import apply_rotary_emb
|
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
    """Repeat each head of a 4-D tensor ``n_rep`` times along the head axis.

    Args:
        x: Tensor of shape ``(batch, seq_len, n_heads, head_dim)``.
        n_rep: Number of consecutive copies of every head.

    Returns:
        ``x`` itself when ``n_rep == 1``, otherwise a tensor of shape
        ``(batch, seq_len, n_heads * n_rep, head_dim)`` in which the copies
        of one head are adjacent.
    """
    batch, seq_len, kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis after the head axis, broadcast it, then fold it in.
    expanded = x.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, head_dim)
    return expanded.reshape(batch, seq_len, kv_heads * n_rep, head_dim)
|
|
||||||
|
|
||||||
|
|
||||||
class AttnFactory(BaseFactory[nn.Module]):
    """Factory for attention modules, looked up by registered name."""

    @classmethod
    def create(cls, attn_type: str, **kwargs) -> nn.Module:
        """Build the attention module registered under ``attn_type``.

        All keyword arguments are handed to the base factory's ``create``.
        """
        module = super().create(attn_type, **kwargs)
        return module
|
|
||||||
|
|
||||||
|
|
||||||
@AttnFactory.register("gqa")
class GQA(nn.Module):
    """Grouped-query attention with optional QK-norm, output gate and KV cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; K and V are
    repeated ``n_heads // n_kv_heads`` times before attention.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        use_qk_norm: bool,
        norm_eps: float,
        use_gated_attention: bool,
        layer_id: int,
    ):
        super().__init__()
        assert dim % n_heads == 0
        assert n_heads % n_kv_heads == 0

        self.dim = dim
        self.layer_id = layer_id
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        self.n_rep = n_heads // n_kv_heads
        self.use_qk_norm = use_qk_norm
        self.use_gated_attention = use_gated_attention

        kv_width = n_kv_heads * self.head_dim
        self.q_proj = Linear(dim, n_heads * self.head_dim)
        self.k_proj = Linear(dim, kv_width)
        self.v_proj = Linear(dim, kv_width)
        self.o_proj = Linear(dim, dim)

        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim, norm_eps)
            self.k_norm = RMSNorm(self.head_dim, norm_eps)

        if self.use_gated_attention:
            self.gate = Linear(dim, dim)

    def _split_heads(self, x: Tensor, n_heads) -> Tensor:
        """Reshape ``(batch, seq, n_heads * head_dim)`` to 4-D per-head form."""
        batch_size, seq_len, _ = x.shape
        return x.reshape(batch_size, seq_len, n_heads, self.head_dim)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tensor,
        attn_mask: Tensor = None,
        paged_cache: Optional[KvcacheView] = None,
    ) -> Tensor:
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq, dim)``.
            rotary_emb: Rotary factors passed to ``apply_rotary_emb``.
            attn_mask: Optional mask; when omitted, causal attention is used.
            paged_cache: Optional cache view; new K/V are written for this
                layer and the full K/V are read back before attention.

        Returns:
            Output of the final projection.
        """
        # No explicit mask means "use the built-in causal mask".
        causal = attn_mask is None

        queries = self._split_heads(self.q_proj(x), self.n_heads)
        keys = self._split_heads(self.k_proj(x), self.n_kv_heads)
        values = self._split_heads(self.v_proj(x), self.n_kv_heads)

        queries = apply_rotary_emb(queries, rotary_emb)
        keys = apply_rotary_emb(keys, rotary_emb)

        if self.use_qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        if paged_cache is not None:
            # Store this step's K/V, then attend over everything cached.
            paged_cache.write(self.layer_id, keys, values)
            keys, values = paged_cache.gather(self.layer_id)

        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        queries, keys, values = (
            t.permute(0, 2, 1, 3) for t in (queries, keys, values)
        )
        attended = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask, is_causal=causal
        )
        merged = attended.permute(0, 2, 1, 3).contiguous().flatten(2)

        if self.use_gated_attention:
            merged = merged * F.sigmoid(self.gate(x))

        return self.o_proj(merged)
|
|
||||||
|
|
||||||
|
|
||||||
@AttnFactory.register("mla")
class MLA(nn.Module):
    """Attention with a low-rank compressed key/value projection.

    Keys and values are produced by projecting the input down to
    ``kv_lora_rank``, normalising, and projecting back up. Each head is split
    into a non-rotary part (``qk_nope_head_dim``) and a rotary part
    (``qk_rope_head_dim``). K/V heads are repeated ``n_heads // n_kv_heads``
    times before attention, as in ``GQA``.

    NOTE(review): ``o_proj`` and ``gate`` are sized ``dim -> dim``, so this
    presumably requires ``n_heads * (qk_nope_head_dim + qk_rope_head_dim)``
    to equal ``dim`` — confirm against the configs in use.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        norm_eps: float,
        use_qk_norm: bool,
        use_gated_attention: bool,
        layer_id: int,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.layer_id = layer_id
        self.n_rep = n_heads // n_kv_heads
        self.use_qk_norm = use_qk_norm
        self.use_gated_attention = use_gated_attention

        self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)

        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim, norm_eps)
            self.k_norm = RMSNorm(self.head_dim, norm_eps)

        # Down-projection to the compressed KV space, then its norm.
        self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
        self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)

        # Up-projection: per KV head, head_dim for K plus head_dim for V.
        self.kv_b_proj = Linear(
            kv_lora_rank,
            n_kv_heads * (2 * self.head_dim),
        )

        self.o_proj = Linear(dim, dim, bias=False)

        if use_gated_attention:
            self.gate = Linear(dim, dim, bias=False)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tensor,
        attn_mask: Tensor = None,
        paged_cache: Optional[KvcacheView] = None,
    ) -> Tensor:
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq, dim)``.
            rotary_emb: Rotary factors passed to ``apply_rotary_emb``.
            attn_mask: Optional mask; when omitted, causal attention is used.
            paged_cache: Optional cache view; new K/V are written for this
                layer and the full K/V are read back before attention.

        Returns:
            Output of the final projection.
        """
        bsz, seq_len, _ = x.size()
        is_causal = attn_mask is None

        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim)

        kv_compressed = self.kv_norm(self.kv_a_proj(x))
        kv = self.kv_b_proj(kv_compressed).view(bsz, seq_len, self.n_kv_heads, -1)

        k_nope, k_rope, v = torch.split(
            kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
        )

        q_nope, q_rope = (
            q[..., : self.qk_nope_head_dim],
            q[..., self.qk_nope_head_dim :],
        )
        # Only the rotary slice of each head receives positional rotation.
        q_rope = apply_rotary_emb(q_rope, rotary_emb)
        k_rope = apply_rotary_emb(k_rope, rotary_emb)

        q = torch.cat([q_nope, q_rope], dim=-1)
        k = torch.cat([k_nope, k_rope], dim=-1)

        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        if paged_cache is not None:
            paged_cache.write(self.layer_id, k, v)
            k, v = paged_cache.gather(self.layer_id)

        # Expand K/V to one head per query head (no-op when n_rep == 1).
        k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask, is_causal=is_causal
        )
        attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)

        if self.use_gated_attention:
            attn_out = attn_out * F.sigmoid(self.gate(x))

        out = self.o_proj(attn_out)
        return out
|
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.inference.core.cache import KvcacheView
|
|
||||||
from astrai.model.components.attention import AttnFactory
|
|
||||||
from astrai.model.components.mlp import FFNFactory
|
|
||||||
from astrai.model.components.norm import RMSNorm
|
|
||||||
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each residual.

    The attention and feed-forward implementations are chosen by name through
    ``AttnFactory`` and ``FFNFactory``. Extra ``**kwargs`` are forwarded to
    both factories.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        dim_ffn: int,
        n_kv_heads: int,
        norm_eps: float,
        use_qk_norm: bool,
        use_gated_attention: bool,
        layer_id: int,
        attn_type: str = "gqa",
        ffn_type: str = "mlp",
        **kwargs,
    ):
        super().__init__()
        attn_args = dict(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            use_qk_norm=use_qk_norm,
            norm_eps=norm_eps,
            use_gated_attention=use_gated_attention,
            layer_id=layer_id,
        )
        self.attention = AttnFactory.create(attn_type, **attn_args, **kwargs)
        self.input_norm = RMSNorm(dim, norm_eps)
        self.post_attention_norm = RMSNorm(dim, norm_eps)
        self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tensor,
        attention_mask: Optional[Tensor] = None,
        paged_cache: Optional[KvcacheView] = None,
    ) -> Tensor:
        """Apply the attention sub-layer, then the feed-forward sub-layer.

        Each sub-layer sees a normalised copy of its input and its output is
        added back to the un-normalised input.
        """
        normed = self.input_norm(x)
        attn_output = self.attention(normed, rotary_emb, attention_mask, paged_cache)
        hidden = attn_output + x

        ffn_output = self.mlp(self.post_attention_norm(hidden))
        return ffn_output + hidden
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class Embedding(nn.Module):
    """Lookup table mapping token ids to ``embedding_dim``-sized vectors.

    The weight is allocated uninitialised; the constructor does not call
    ``reset_parameters``, so either call it explicitly or load weights.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        table = torch.empty((vocab_size, embedding_dim))
        self.weight = nn.Parameter(table)

    def reset_parameters(self):
        """Fill the table from a normal distribution (mean 0, std 0.02)."""
        nn.init.normal_(self.weight, mean=0.0, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        """Return the rows of the table selected by the indices in ``x``."""
        return F.embedding(x, self.weight)
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
    """Affine map ``x @ weight.T (+ bias)`` with bias disabled by default.

    The weight is allocated uninitialised; the constructor does not call
    ``reset_parameters``, so either call it explicitly or load weights.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            self.bias = None

    def reset_parameters(self):
        """Kaiming-uniform weight init; bias uniform in ``±1/sqrt(fan_in)``."""
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / (fan_in**0.5)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear map to the last dimension of ``x``."""
        return F.linear(x, self.weight, self.bias)
|
|
||||||
|
|
@ -1,194 +0,0 @@
|
||||||
import logging
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Set
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
from astrai.serialization import (
|
|
||||||
load_json,
|
|
||||||
load_safetensors,
|
|
||||||
save_json,
|
|
||||||
save_safetensors,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Module-level logger for LoRA injection/merge/save/load messages.
logger = logging.getLogger(__name__)

# Child-module names of the attention projections (default injection targets).
TARGET_MODULES_ATTN = {
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
}

# Child-module names of the feed-forward projections.
TARGET_MODULES_FFN = {
    "up",
    "gate",
    "down",
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class LoRAConfig:
    """Hyper-parameters describing a LoRA adapter.

    This is what gets serialised to ``adapter_config.json``.
    """

    # Rank of the low-rank update.
    r: int = 16
    # Scaling numerator; the effective scale is ``alpha / r``.
    alpha: int = 32
    # Child-module names that receive LoRA layers.
    target_modules: tuple = ("q_proj", "v_proj")
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALinear(nn.Module):
    """Wraps a ``Linear`` layer with a trainable low-rank update.

    The base weight (and bias, if any) are shared with the wrapped layer and
    frozen. The output is ``base(x) + (x @ A.T @ B.T) * (alpha / r)`` until
    :meth:`merge` folds the update into the base weight.

    Args:
        base: Layer whose parameters are reused.
        r: Rank of the update.
        alpha: Scaling numerator; the update is scaled by ``alpha / r``.
    """

    def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
        super().__init__()
        self.register_parameter("weight", base.weight)
        self.weight.requires_grad_(False)
        self.bias = base.bias
        if self.bias is not None:
            self.bias.requires_grad_(False)

        self.r = r
        self.scaling = alpha / r

        out_dim, in_dim = self.weight.shape
        # Create the adapter on the base weight's device and dtype so the
        # layer works when injected into a model that was already moved/cast.
        placement = {"device": self.weight.device, "dtype": self.weight.dtype}
        self.lora_A = nn.Parameter(torch.randn(r, in_dim, **placement) / r)
        # B starts at zero, so the wrapped layer initially matches the base.
        self.lora_B = nn.Parameter(torch.zeros(out_dim, r, **placement))
        self._merged = False

    def forward(self, x):
        """Apply the base layer plus, unless merged, the scaled LoRA update."""
        out = F.linear(x, self.weight, self.bias)
        if not self._merged:
            out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
        return out

    def merge(self):
        """Fold the LoRA update into the base weight and drop A and B.

        Calling it again is a no-op. After merging, the LoRA parameters no
        longer appear in the state dict.
        """
        if self._merged:
            return
        self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
        self._merged = True
        del self.lora_A
        del self.lora_B
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_lora_info(model: nn.Module) -> dict:
    """Group the qualified names of all ``Linear`` sub-modules by leaf name.

    Returns:
        Mapping from the last path component (e.g. ``"q_proj"``) to the list
        of full module paths ending in it.
    """
    by_leaf = {}
    for path, module in model.named_modules():
        if not isinstance(module, Linear):
            continue
        leaf = path.rpartition(".")[2]
        if leaf not in by_leaf:
            by_leaf[leaf] = []
        by_leaf[leaf].append(path)
    return by_leaf
|
|
||||||
|
|
||||||
|
|
||||||
def _get_lora_count(model: nn.Module) -> int:
    """Count the ``LoRALinear`` modules contained in ``model``."""
    return len([m for m in model.modules() if isinstance(m, LoRALinear)])
|
|
||||||
|
|
||||||
|
|
||||||
def inject_lora(
    model: nn.Module,
    r: int = 16,
    alpha: int = 32,
    target_modules: Optional[Set[str]] = None,
) -> LoRAConfig:
    """Replace targeted ``Linear`` layers in ``model`` with ``LoRALinear``.

    A layer is targeted when its attribute name on the parent module is in
    ``target_modules`` (defaults to ``TARGET_MODULES_ATTN``). The model is
    modified in place.

    Args:
        model: Model to modify.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        target_modules: Leaf names of the layers to wrap.

    Returns:
        The ``LoRAConfig`` describing what was requested.
    """
    targets = TARGET_MODULES_ATTN if target_modules is None else target_modules

    replaced = 0
    # Snapshot the module list: the tree is mutated while walking it.
    for path, layer in list(model.named_modules()):
        if not isinstance(layer, Linear):
            continue
        parent_path, _, attr = path.rpartition(".")
        if attr not in targets:
            continue
        owner = model.get_submodule(parent_path) if parent_path else model
        setattr(owner, attr, LoRALinear(layer, r=r, alpha=alpha))
        replaced += 1

    if replaced:
        logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", replaced, r, alpha)
    else:
        # Nothing was replaced, so the tree still reflects the original model.
        available = _collect_lora_info(model)
        logger.warning(
            "No LoRA layers injected. Available Linear child names: %s. "
            "target_modules: %s. Check model type and target_modules.",
            sorted(available),
            sorted(targets),
        )

    return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(targets))
|
|
||||||
|
|
||||||
|
|
||||||
def merge_lora(model: nn.Module):
    """Merge every ``LoRALinear`` in ``model`` into its base weight.

    Logs how many layers were visited, or a warning when there were none.
    """
    lora_layers = [m for m in model.modules() if isinstance(m, LoRALinear)]
    for layer in lora_layers:
        layer.merge()

    if not lora_layers:
        logger.warning("No LoRA layers to merge.")
        return
    logger.info("Merged %d LoRA layers", len(lora_layers))
|
|
||||||
|
|
||||||
|
|
||||||
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
    """Write the model's LoRA parameters and config to ``save_dir``.

    Only state-dict entries whose key ends in ``.lora_A`` or ``.lora_B`` are
    saved (to ``adapter_model.safetensors``); the config goes to
    ``adapter_config.json``. The directory is created if needed.

    Raises:
        RuntimeError: If the model has no LoRA parameters.
    """
    lora_suffixes = (".lora_A", ".lora_B")
    lora_sd = {}
    for key, tensor in model.state_dict().items():
        if key.endswith(lora_suffixes):
            lora_sd[key] = tensor

    if len(lora_sd) == 0:
        raise RuntimeError(
            "No LoRA parameters found in model. "
            "The model may not have been injected or was already merged."
        )

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    save_safetensors(lora_sd, out_dir / "adapter_model.safetensors")
    save_json(asdict(config), out_dir / "adapter_config.json")
    logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
    """Load a LoRA adapter from ``load_dir`` into ``model``.

    Reads ``adapter_config.json``, injects LoRA layers unless the model
    already has some, then loads ``adapter_model.safetensors`` non-strictly
    (base weights are expected to be absent from the adapter file).

    Args:
        model: Model to receive the adapter; modified in place.
        load_dir: Directory written by ``save_lora``.

    Returns:
        The adapter's ``LoRAConfig``.

    Raises:
        RuntimeError: If no LoRA layers exist after injection, if the adapter
            tensors do not fit the LoRA layers, or if the adapter file lacks
            LoRA tensors the model expects.
    """
    path = Path(load_dir)
    raw = load_json(path / "adapter_config.json")
    config = LoRAConfig(
        r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
    )

    existing = _get_lora_count(model)
    if existing > 0:
        logger.warning(
            "Model already has %d LoRA layers. Skipping injection, "
            "loading weights onto existing layers only.",
            existing,
        )
    else:
        inject_lora(
            model,
            r=config.r,
            alpha=config.alpha,
            target_modules=set(config.target_modules),
        )

    # Fail before touching the weight file: with no LoRA layers there is
    # nothing the adapter tensors could be loaded into.
    if _get_lora_count(model) == 0:
        raise RuntimeError(
            "No LoRA layers found after injection. "
            "Check that the adapter config's target_modules match the model."
        )

    weights = load_safetensors(path / "adapter_model.safetensors")
    try:
        missing, unexpected = model.load_state_dict(weights, strict=False)
    except RuntimeError as e:
        msg = str(e)
        if "size mismatch" in msg:
            raise RuntimeError(
                f"LoRA weight shapes do not match the model. "
                f"The adapter config (r={config.r}) may not match the injected layers. "
                f"Original error: {msg}"
            ) from e
        raise

    # `missing` lists keys the model expects but the adapter file lacks.
    # Base weights are expected there; LoRA tensors are not.
    lora_missing = [k for k in missing if "lora" in k]
    if lora_missing:
        raise RuntimeError(
            f"LoRA weights expected by the model are missing from the adapter: "
            f"{lora_missing}. "
            f"The adapter config (r={config.r}) may not match the model."
        )
    if missing:
        logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
    if unexpected:
        logger.warning("LoRA load: %d unexpected keys", len(unexpected))

    logger.info("LoRA adapter loaded from %s", load_dir)
    return config
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
|
|
||||||
|
|
||||||
class FFNFactory(BaseFactory[nn.Module]):
    """Factory for feed-forward modules, looked up by registered name."""

    @classmethod
    def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
        """Build the feed-forward module registered under ``ffn_type``.

        ``dim`` and ``dim_ffn`` are passed positionally, everything else as
        keyword arguments, to the base factory's ``create``.
        """
        module = super().create(ffn_type, dim, dim_ffn, **kwargs)
        return module
|
|
||||||
|
|
||||||
|
|
||||||
@FFNFactory.register("mlp")
class MLP(nn.Module):
    """Gated feed-forward layer: ``down(up(x) * silu(gate(x)))``."""

    def __init__(self, dim: int, dim_ffn: int):
        super().__init__()
        self.up = Linear(dim, dim_ffn)
        self.gate = Linear(dim, dim_ffn)
        self.down = Linear(dim_ffn, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Project up, gate with SiLU, and project back to ``dim``."""
        value = self.up(x)
        activation = F.silu(self.gate(x))
        return self.down(value * activation)
|
|
||||||
|
|
||||||
|
|
||||||
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
    """Mixture-of-experts feed-forward layer with shared and routed experts.

    Every token goes through all shared experts (their outputs are averaged)
    and through its ``n_activated_experts`` highest-scoring routed experts
    (weighted by renormalised router probabilities). The two results are
    summed.

    Args:
        dim: Model width.
        dim_ffn: Hidden width of each expert MLP.
        n_routed_experts: Number of routed experts.
        n_shared_experts: Number of always-on experts (0 disables them).
        n_activated_experts: Routed experts selected per token.
        topk_method: Stored on the instance; not used by the routing code
            in this class.
    """

    def __init__(
        self,
        dim: int,
        dim_ffn: int,
        n_routed_experts: int,
        n_shared_experts: int = 1,
        n_activated_experts: int = 2,
        topk_method: str = "greedy",
    ):
        super().__init__()
        self.dim = dim
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.n_activated_experts = n_activated_experts
        self.topk_method = topk_method

        self.router = Linear(dim, n_routed_experts, bias=False)

        self.shared_experts = nn.ModuleList(
            [MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
        )
        self.routed_experts = nn.ModuleList(
            [MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer to ``x`` of shape ``(batch, seq, dim)``."""
        bsz, seq_len, dim = x.shape
        # reshape rather than view: works for non-contiguous inputs too and
        # is still copy-free when the input is contiguous.
        x_flat = x.reshape(-1, dim)

        shared_out = self._shared_forward(x_flat)
        routed_out = self._routed_forward(x_flat)

        out = (shared_out + routed_out).view(bsz, seq_len, dim)
        return out

    def _shared_forward(self, x: Tensor) -> Tensor:
        """Average of all shared experts' outputs (zeros if there are none)."""
        if self.n_shared_experts == 0:
            return torch.zeros_like(x)
        return sum(e(x) for e in self.shared_experts) / self.n_shared_experts

    def _routed_forward(self, x: Tensor) -> Tensor:
        """Weighted sum of each token's top-K routed experts.

        ``x`` is the flattened ``(tokens, dim)`` input.
        """
        N, D = x.shape
        K = self.n_activated_experts

        router_logits = self.router(x)
        # Softmax in float32 for stability, then back to the input dtype.
        router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)

        topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
        # Renormalise so the selected experts' weights sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
        for expert_idx in range(self.n_routed_experts):
            # Tokens (and their top-K slot) that selected this expert.
            expert_mask = topk_indices == expert_idx
            token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            expert_input = x[token_idx]
            expert_output = self.routed_experts[expert_idx](expert_input)
            weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, expert_output * weights)

        return output
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim, norm_eps):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
self.normalized_shape = (dim,)
|
|
||||||
self.norm_eps = norm_eps
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
|
||||||
|
|
@ -1,71 +0,0 @@
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Build the complex rotary table ``exp(i * t * theta_k)``.

    Args:
        dim: Rotary dimension; one frequency is produced per pair of channels.
        max_len: Number of positions ``t = 0 .. max_len - 1``.
        base: Base of the geometric frequency progression.
        device: Device on which the table is created.

    Returns:
        Complex tensor of shape ``(max_len, len(range(0, dim, 2)))``.
    """
    # Angles are computed in float64 and only then reduced to float32.
    channel = torch.arange(0, dim, 2, dtype=torch.float64, device=device)
    inv_freq = base ** (-channel / dim)
    positions = torch.arange(0, max_len, dtype=torch.float64, device=device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.complex(torch.cos(angles), torch.sin(angles))
|
|
||||||
|
|
||||||
|
|
||||||
def ntk_base(base: float, dim: int, factor: float) -> float:
    """Return the rescaled rotary base ``base * factor ** (dim / (dim - 2))``."""
    exponent = dim / (dim - 2)
    return base * (factor**exponent)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive channel pairs of ``x`` by the complex factors in ``freqs_cis``.

    The last dimension of ``x`` is read as (real, imag) pairs, multiplied by
    ``freqs_cis`` (which gets a singleton axis inserted at position 2 for
    broadcasting), and flattened back. The result has ``x``'s dtype.
    """
    original_dtype = x.dtype
    # View the last dim as complex numbers: (..., d) -> (..., d // 2).
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs_cis.unsqueeze(2)
    # Back to interleaved real layout (..., d).
    flat = torch.view_as_real(rotated).flatten(-2)
    return flat.to(original_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position table, indexed by position ids."""

    def __init__(
        self,
        dim: int,
        max_len: int,
        base: float = 10000,
        rope_scaling: Optional[Dict] = None,
    ):
        """Store the settings and build the table for ``max_len`` positions.

        Args:
            dim: Rotary dimension handed to ``get_rotary_emb``.
            max_len: Number of positions to precompute.
            base: Frequency base before any scaling.
            rope_scaling: Optional dict with keys ``"type"`` (default
                ``"ntk"``) and ``"factor"`` (default ``1.0``).
        """
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.rope_scaling = rope_scaling

        # Only "ntk" scaling changes the base; any other type keeps it as given.
        use_ntk = rope_scaling is not None and rope_scaling.get("type", "ntk") == "ntk"
        if use_ntk:
            self.base = ntk_base(base, dim, rope_scaling.get("factor", 1.0))
        else:
            self.base = base

        self._set_rotary_buffer(max_len)

    def _set_rotary_buffer(self, max_len: int):
        """Register the table as the non-persistent buffer ``freqs_cis``.

        The complex table is stored in its real ``(..., 2)`` representation;
        ``forward`` converts it back to complex.
        """
        table = get_rotary_emb(self.dim, max_len, self.base)
        self.register_buffer("freqs_cis", torch.view_as_real(table), persistent=False)

    def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
        """Return the complex rotary factors for ``position_ids``.

        When ``position_ids`` is omitted, positions ``0 .. x.size(1) - 1`` are
        used for each of the ``x.size(0)`` rows.
        """
        if position_ids is None:
            batch, seq = x.size(0), x.size(1)
            position_ids = torch.arange(seq, device=x.device).unsqueeze(0).expand(batch, -1)
        selected = self.freqs_cis[position_ids]
        return torch.view_as_complex(selected.float())
|
|
||||||
|
|
@ -1,99 +0,0 @@
|
||||||
from typing import Any, Mapping, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.config.model_config import EncoderConfig
|
|
||||||
from astrai.model.automodel import AutoModel
|
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
|
||||||
from astrai.model.components.embedding import Embedding
|
|
||||||
from astrai.model.components.norm import RMSNorm
|
|
||||||
from astrai.model.components.rope import RotaryEmbedding
|
|
||||||
from astrai.model.transformer import process_attention_mask
|
|
||||||
|
|
||||||
|
|
||||||
@AutoModel.register("embedding")
class EmbeddingEncoder(AutoModel):
    """Encoder that pools final hidden states into one vector per sequence.

    Tokens are embedded, passed through ``config.n_layers`` ``DecoderBlock``
    layers with a non-causal attention mask, normalised with ``RMSNorm`` and
    pooled according to ``config.pooling_type``:

    - ``"cls"``: hidden state of the first token,
    - ``"last"``: hidden state at index ``mask_sum - 1`` (or the final
      position when no mask is given),
    - anything else: mean over positions (mask-weighted when a mask is given).
    """

    def __init__(self, config: EncoderConfig):
        super().__init__(config)
        self.config = config
        rope_dim = config.dim // config.n_heads
        rope_base = config.rope_theta if config.rope_theta is not None else 10000
        self.rotary_embedding = RotaryEmbedding(
            rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
        )
        self.embed_tokens = Embedding(config.vocab_size, config.dim)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    config.dim,
                    config.n_heads,
                    config.dim_ffn,
                    config.n_kv_heads,
                    config.norm_eps,
                    config.use_qk_norm,
                    config.use_gated_attention,
                    layer_id,
                )
                for layer_id in range(config.n_layers)
            ]
        )

        self.norm = RMSNorm(config.dim, config.norm_eps)

        # Falsy config values fall back to mean pooling / no normalisation.
        self.pooling_type = config.pooling_type or "mean"
        self.normalize_embeddings = config.normalize_embeddings or False

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Call ``reset_parameters`` on every submodule that defines it."""
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()

    def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
        """Load weights, dropping ``lm_head.weight`` since this model has no LM head."""
        state_dict = dict(state_dict)  # copy so the caller's mapping is not mutated
        state_dict.pop("lm_head.weight", None)
        return super().load_state_dict(state_dict, strict=strict, assign=assign)

    def forward(
        self,
        input_ids: Tensor,
        input_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode ``input_ids`` and return one pooled vector per row.

        Args:
            input_ids: Token ids of shape ``(B, S)``.
            input_mask: Optional ``(B, S)`` mask; non-zero marks real tokens.
                Any dtype that sums and casts to bool is accepted.
            position_ids: Optional position ids forwarded to the rotary
                embedding and to ``process_attention_mask``.

        Returns:
            Pooled embeddings, L2-normalised when ``normalize_embeddings`` is set.
        """
        assert input_ids.ndim == 2
        B, S = input_ids.shape

        x = self.embed_tokens(input_ids)

        rotary_emb = self.rotary_embedding(x, position_ids)
        # NOTE(review): process_attention_mask returns None when position_ids
        # is None, so in that case input_mask is not turned into an attention
        # mask and only influences pooling. Confirm this is intended.
        attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

        for layer in self.layers:
            x = layer(x, rotary_emb, attn_mask, paged_cache=None)

        hidden_states = self.norm(x)

        if self.pooling_type == "cls":
            pooled = hidden_states[:, 0]
        elif self.pooling_type == "last":
            if input_mask is not None:
                # Summing a float mask gives a float tensor, which is not a
                # valid index; cast to long and move it next to the hidden
                # states so bool, int and float masks on any device all work.
                lengths = (
                    input_mask.sum(dim=1).to(device=x.device, dtype=torch.long) - 1
                )
                pooled = hidden_states[torch.arange(B, device=x.device), lengths]
            else:
                pooled = hidden_states[:, -1]
        else:
            if input_mask is not None:
                mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
                # clamp avoids a division by zero for rows with an all-zero mask.
                pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
                    min=1.0
                )
            else:
                pooled = hidden_states.mean(dim=1)

        if self.normalize_embeddings:
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)

        return pooled
|
|
||||||
|
|
@ -1,152 +0,0 @@
|
||||||
from typing import Any, Dict, Mapping, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
|
||||||
from astrai.inference.core.cache import KvcacheView
|
|
||||||
from astrai.model.automodel import AutoModel
|
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
|
||||||
from astrai.model.components.embedding import Embedding
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
from astrai.model.components.norm import RMSNorm
|
|
||||||
from astrai.model.components.rope import RotaryEmbedding
|
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
    input_tensor: Tensor,
    position_ids: Optional[Tensor],
    input_mask: Optional[Tensor] = None,
    is_causal: bool = False,
) -> Optional[Tensor]:
    """Build an additive ``(B, 1, S, T)`` attention bias, or return ``None``.

    ``B`` and ``S`` are the first two sizes of ``input_tensor`` and
    ``T = position_ids.max() + 1``.

    Returns ``None`` when ``position_ids`` is ``None``, or when no mask is
    given, attention is causal and the smallest position id is 0. A mask with
    more than two dimensions is returned unchanged. Otherwise allowed
    positions hold ``0.0`` and blocked ones ``-finfo(dtype).max / 2``.
    """
    if position_ids is None:
        return None

    has_mask = input_mask is not None
    if has_mask and input_mask.dim() > 2:
        # Already an expanded mask: hand it back untouched.
        return input_mask

    dev, float_dtype = input_tensor.device, input_tensor.dtype
    batch, q_len = input_tensor.size()[:2]
    kv_len = position_ids.max().item() + 1

    if has_mask:
        keep = input_mask[:, :kv_len].to(device=dev, dtype=torch.bool)
    elif is_causal and position_ids.min().item() == 0:
        return None
    else:
        keep = torch.ones(batch, kv_len, dtype=torch.bool, device=dev)

    # Materialise a writable (B, S, T) copy of the per-key padding mask.
    visible = keep.view(batch, 1, kv_len).expand(batch, q_len, kv_len).clone()
    if is_causal:
        key_positions = torch.arange(kv_len, device=dev)
        # A query may only see keys at or before its own position.
        visible &= position_ids.unsqueeze(-1) >= key_positions

    blocked = -torch.finfo(float_dtype).max / 2
    bias = torch.full(
        (batch, 1, q_len, kv_len), blocked, dtype=float_dtype, device=dev
    )
    bias.masked_fill_(visible.unsqueeze(1), 0.0)
    return bias
|
|
||||||
|
|
||||||
|
|
||||||
@AutoModel.register("autoregressive_lm")
class AutoRegressiveLM(AutoModel):
    """Autoregressive language model with paged KV cache.

    Embeds tokens, runs ``config.n_layers`` ``DecoderBlock`` layers with a
    causal attention mask, applies ``RMSNorm`` and projects to vocabulary
    logits. With ``config.tie_weight`` the LM head shares the embedding weight.
    """

    def __init__(self, config: AutoRegressiveLMConfig):
        super().__init__(config)
        self.config = config
        # MLA attention uses a dedicated rotary head size; otherwise the
        # rotary dimension is the per-head width.
        rope_dim = (
            config.qk_rope_head_dim
            if config.attn_type == "mla"
            else config.dim // config.n_heads
        )
        rope_base = config.rope_theta if config.rope_theta is not None else 10000
        self.rotary_embedding = RotaryEmbedding(
            rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
        )
        self.embed_tokens = Embedding(config.vocab_size, config.dim)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    config.dim,
                    config.n_heads,
                    config.dim_ffn,
                    config.n_kv_heads,
                    config.norm_eps,
                    config.use_qk_norm,
                    config.use_gated_attention,
                    layer_id,
                    attn_type=config.attn_type,
                    ffn_type=config.ffn_type,
                    n_routed_experts=config.n_routed_experts,
                    n_shared_experts=config.n_shared_experts,
                    n_activated_experts=config.n_activated_experts,
                    topk_method=config.topk_method,
                    kv_lora_rank=config.kv_lora_rank,
                    qk_nope_head_dim=config.qk_nope_head_dim,
                    qk_rope_head_dim=config.qk_rope_head_dim,
                )
                for layer_id in range(config.n_layers)
            ]
        )

        self.norm = RMSNorm(config.dim, config.norm_eps)
        self.lm_head = Linear(config.dim, config.vocab_size)

        # Weight tying: the LM head reuses the embedding parameter object.
        if self.config.tie_weight is True:
            self.lm_head.weight = self.embed_tokens.weight

        # NOTE(review): apply() runs after tying, so any reset_parameters()
        # on lm_head also rewrites the shared embedding weight. Confirm that
        # the resulting initialisation is the intended one.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Call ``reset_parameters`` on every submodule that defines it."""
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()

    def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
        """Load weights, filling ``lm_head.weight`` from the embedding when needed.

        With tied weights the embedding tensor is always used for the head.
        Without tying, a missing head weight is replaced by a clone of the
        embedding weight. The caller's mapping is copied, never mutated.
        """
        lm_head_key = "lm_head.weight"
        embed_key = "embed_tokens.weight"

        state_dict = dict(state_dict)

        if self.config.tie_weight is True:
            # same tensor for embed and lm_head
            if embed_key in state_dict:
                state_dict[lm_head_key] = state_dict[embed_key]
        else:
            if lm_head_key not in state_dict and embed_key in state_dict:
                # clone to avoid sharing gradients
                state_dict[lm_head_key] = torch.clone(state_dict[embed_key])

        return super().load_state_dict(state_dict, strict, assign)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the state dict, omitting ``lm_head.weight`` when weights are tied."""
        state_dict = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )

        if self.config.tie_weight is True:
            # The tied head duplicates the embedding entry, so drop it.
            lm_head_key = prefix + "lm_head.weight"
            if lm_head_key in state_dict:
                del state_dict[lm_head_key]

        return state_dict

    def forward(
        self,
        input_ids: Tensor,
        input_mask: Optional[Tensor] = None,
        paged_cache: Optional[KvcacheView] = None,
        position_ids: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Compute logits and final hidden states for ``input_ids``.

        Args:
            input_ids: Token ids; must be 2-dimensional.
            input_mask: Optional mask forwarded to ``process_attention_mask``.
            paged_cache: Optional KV-cache view passed to every layer.
            position_ids: Optional position ids for the rotary embedding and
                the attention mask.

        Returns:
            Dict with ``"logits"`` and ``"hidden_states"`` (post-norm).
        """
        assert input_ids.ndim == 2

        x = self.embed_tokens(input_ids)
        rotary_emb = self.rotary_embedding(x, position_ids)
        attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)

        for layer in self.layers:
            x = layer(x, rotary_emb, attn_mask, paged_cache)

        hidden_states = self.norm(x)
        logits = self.lm_head(hidden_states)

        return {"logits": logits, "hidden_states": hidden_states}
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
from astrai.parallel.executor import (
|
|
||||||
AccumOptimizer,
|
|
||||||
AccumScheduler,
|
|
||||||
BaseExecutor,
|
|
||||||
DDPExecutor,
|
|
||||||
ExecutorFactory,
|
|
||||||
FSDPExecutor,
|
|
||||||
GradientState,
|
|
||||||
NoneExecutor,
|
|
||||||
)
|
|
||||||
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
|
||||||
from astrai.parallel.setup import (
|
|
||||||
get_current_device,
|
|
||||||
get_rank,
|
|
||||||
get_world_size,
|
|
||||||
only_on_rank,
|
|
||||||
setup_parallel,
|
|
||||||
spawn_parallel_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Public API of the parallel package: process-setup helpers, tensor-parallel
# linear layers, and the training executors with their accumulation wrappers.
__all__ = [
    "get_world_size",
    "get_rank",
    "get_current_device",
    "only_on_rank",
    "setup_parallel",
    "spawn_parallel_fn",
    "RowParallelLinear",
    "ColumnParallelLinear",
    "ExecutorFactory",
    "BaseExecutor",
    "GradientState",
    "AccumOptimizer",
    "AccumScheduler",
    "NoneExecutor",
    "DDPExecutor",
    "FSDPExecutor",
]
|
|
||||||
|
|
@ -1,271 +0,0 @@
|
||||||
"""Unified training executor — parallel strategy + gradient accumulation."""
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import logging
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
from astrai.parallel.setup import get_rank, get_world_size
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class GradientState:
    """Counter that decides on which micro-steps gradients are synchronised.

    ``sync_gradients`` starts as ``True`` and, after each ``_do_sync`` call,
    is ``True`` exactly when the step count is a multiple of ``num_steps``.
    """

    def __init__(self, grad_accum_steps: int = 1):
        # Values below one are treated as "no accumulation".
        self.num_steps = max(grad_accum_steps, 1)
        self._sync_gradients: bool = True
        self._step: int = 0

    @property
    def sync_gradients(self) -> bool:
        """Whether the current micro-step should synchronise and step."""
        return self._sync_gradients

    def _do_sync(self):
        """Advance the micro-step counter and refresh ``sync_gradients``."""
        self._step = self._step + 1
        remainder = self._step % self.num_steps
        self._sync_gradients = remainder == 0
|
|
||||||
|
|
||||||
|
|
||||||
class AccumOptimizer:
    """Optimizer proxy that only acts on gradient-synchronisation steps.

    ``step`` and ``zero_grad`` are forwarded to the wrapped optimizer when
    ``gradient_state.sync_gradients`` is true and are no-ops otherwise, so
    gradients keep accumulating between synchronisation points.
    """

    def __init__(self, optimizer: Optimizer, gradient_state: "GradientState"):
        self.optimizer = optimizer
        self.gradient_state = gradient_state

    def step(self, closure=None):
        """Step the wrapped optimizer on sync steps.

        Returns:
            Whatever the wrapped ``step`` returns (the closure loss for torch
            optimizers), or ``None`` when the step is skipped.
        """
        if self.gradient_state.sync_gradients:
            return self.optimizer.step(closure)
        return None

    def zero_grad(self, set_to_none: bool = True):
        """Clear gradients on sync steps.

        Args:
            set_to_none: Forwarded to the wrapped optimizer, mirroring the
                signature of ``torch.optim.Optimizer.zero_grad``.
        """
        if self.gradient_state.sync_gradients:
            self.optimizer.zero_grad(set_to_none=set_to_none)

    @property
    def param_groups(self):
        """Parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, d):
        """Load ``d`` into the wrapped optimizer."""
        self.optimizer.load_state_dict(d)
|
|
||||||
|
|
||||||
|
|
||||||
class AccumScheduler:
    """LR-scheduler proxy that only advances on gradient-synchronisation steps."""

    def __init__(self, scheduler: LRScheduler, gradient_state: "GradientState"):
        self.gradient_state = gradient_state
        self.scheduler = scheduler

    def step(self):
        """Advance the wrapped scheduler unless gradients are still accumulating."""
        if not self.gradient_state.sync_gradients:
            return
        self.scheduler.step()

    def state_dict(self):
        """Return the wrapped scheduler's state dict."""
        inner = self.scheduler
        return inner.state_dict()

    def load_state_dict(self, d):
        """Load ``d`` into the wrapped scheduler."""
        inner = self.scheduler
        inner.load_state_dict(d)

    def get_last_lr(self):
        """Return the wrapped scheduler's most recent learning rates."""
        inner = self.scheduler
        return inner.get_last_lr()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseExecutor:
    """Single-process executor: gradient accumulation without model wrapping.

    Subclasses override ``_prepare_model``, ``_no_sync`` and ``unwrap_model``
    to add a parallel strategy.
    """

    def __init__(self, grad_accum_steps: int = 1):
        self.gradient_state = GradientState(grad_accum_steps)

    def prepare(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        dataloader: Optional[DataLoader] = None,
        scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[
        nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
    ]:
        """Wrap the training objects for accumulation-aware execution.

        The model goes through ``_prepare_model``; optimizer and scheduler,
        when given, are wrapped in ``AccumOptimizer`` / ``AccumScheduler``;
        the dataloader is returned unchanged.
        """
        prepared = self._prepare_model(model)
        wrapped_optimizer = (
            None
            if optimizer is None
            else AccumOptimizer(optimizer, self.gradient_state)
        )
        wrapped_scheduler = (
            None
            if scheduler is None
            else AccumScheduler(scheduler, self.gradient_state)
        )
        return prepared, wrapped_optimizer, dataloader, wrapped_scheduler

    def _prepare_model(self, model: nn.Module) -> nn.Module:
        """Hook for subclasses; the base implementation returns ``model`` as is."""
        return model

    def _no_sync(self, model: nn.Module):
        """Context manager used on non-sync micro-steps; a no-op here."""
        return contextlib.nullcontext()

    @contextmanager
    def accumulate(self, model: nn.Module):
        """Advance the accumulation counter and run the body in the right context.

        On micro-steps that do not synchronise, the body runs inside
        ``_no_sync(model)``; on synchronising steps it runs directly.
        """
        state = self.gradient_state
        state._do_sync()
        guard = contextlib.nullcontext() if state.sync_gradients else self._no_sync(model)
        with guard:
            yield

    def backward(self, loss: torch.Tensor):
        """Back-propagate ``loss``."""
        loss.backward()

    def unwrap_model(self, model: nn.Module):
        """Return ``model.state_dict()`` (note: a state dict, not a module)."""
        return model.state_dict()

    @property
    def use_distributed(self) -> bool:
        """True when ``get_world_size()`` reports more than one process."""
        return get_world_size() > 1

    @property
    def sync_gradients(self) -> bool:
        """Whether the current micro-step synchronises gradients."""
        return self.gradient_state.sync_gradients

    @property
    def grad_accum_steps(self) -> int:
        """Number of micro-steps per optimizer step."""
        return self.gradient_state.num_steps
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutorFactory(BaseFactory[BaseExecutor]):
    """Registry of executor classes keyed by backend name."""
|
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("none")
class NoneExecutor(BaseExecutor):
    """Executor registered as ``"none"``; inherits ``BaseExecutor`` unchanged."""
|
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("ddp")
class DDPExecutor(BaseExecutor):
    """Executor that wraps the model in ``DistributedDataParallel``.

    All constructor arguments other than ``grad_accum_steps`` are stored and
    forwarded verbatim to the ``DDP`` constructor.
    """

    def __init__(
        self,
        grad_accum_steps: int = 1,
        dim: int = 0,
        broadcast_buffers: bool = True,
        init_sync: bool = True,
        process_group=None,
        bucket_cap_mb: int = 25,
        find_unused_parameters: bool = False,
        check_reduction: bool = False,
        gradient_as_bucket_view: bool = False,
        static_graph: bool = False,
        delay_all_reduce_named_params=None,
        param_to_hook_all_reduce=None,
        mixed_precision=None,
        device_mesh=None,
    ):
        super().__init__(grad_accum_steps=grad_accum_steps)
        # Kept as a plain dict so _prepare_model can splat it into DDP(...).
        self._ddp_kwargs = {
            "dim": dim,
            "broadcast_buffers": broadcast_buffers,
            "init_sync": init_sync,
            "process_group": process_group,
            "bucket_cap_mb": bucket_cap_mb,
            "find_unused_parameters": find_unused_parameters,
            "check_reduction": check_reduction,
            "gradient_as_bucket_view": gradient_as_bucket_view,
            "static_graph": static_graph,
            "delay_all_reduce_named_params": delay_all_reduce_named_params,
            "param_to_hook_all_reduce": param_to_hook_all_reduce,
            "mixed_precision": mixed_precision,
            "device_mesh": device_mesh,
        }

    def _prepare_model(self, model: nn.Module) -> nn.Module:
        """Wrap ``model`` in DDP, or return it untouched when world size is 1."""
        if not self.use_distributed:
            logger.warning("DDP backend selected but world_size=1, model not wrapped")
            return model
        # NOTE(review): get_rank() is used as the device index; confirm it is
        # the node-local rank when running on more than one node.
        device_index = get_rank()
        wrapped = DDP(
            model,
            device_ids=[device_index],
            output_device=device_index,
            **self._ddp_kwargs,
        )
        logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
        return wrapped

    def _no_sync(self, model: nn.Module):
        """Return ``model.no_sync()`` for DDP models, otherwise a null context."""
        return model.no_sync() if isinstance(model, DDP) else contextlib.nullcontext()

    def unwrap_model(self, model: nn.Module):
        """Return the state dict of the underlying (non-DDP) module."""
        target = model.module if isinstance(model, DDP) else model
        return target.state_dict()
|
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
    """Executor that wraps the model in ``FullyShardedDataParallel``.

    Constructor arguments other than ``grad_accum_steps`` are forwarded to
    ``FSDP`` when they are not ``None``; ``use_orig_params=True`` is always set.
    """

    def __init__(
        self,
        grad_accum_steps: int = 1,
        process_group=None,
        sharding_strategy=None,
        cpu_offload=None,
        auto_wrap_policy=None,
        backward_prefetch=None,
        mixed_precision=None,
        ignored_modules=None,
        param_init_fn=None,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = True,
        ignored_states=None,
        device_mesh=None,
    ):
        super().__init__(grad_accum_steps=grad_accum_steps)
        candidates = {
            "process_group": process_group,
            "sharding_strategy": sharding_strategy,
            "cpu_offload": cpu_offload,
            "auto_wrap_policy": auto_wrap_policy,
            "backward_prefetch": backward_prefetch,
            "mixed_precision": mixed_precision,
            "ignored_modules": ignored_modules,
            "param_init_fn": param_init_fn,
            "sync_module_states": sync_module_states,
            "forward_prefetch": forward_prefetch,
            "limit_all_gathers": limit_all_gathers,
            "use_orig_params": True,
            "ignored_states": ignored_states,
            "device_mesh": device_mesh,
        }
        # Unset (None) options are dropped so FSDP applies its own defaults.
        self._fsdp_kwargs = {
            name: value for name, value in candidates.items() if value is not None
        }
        self._original_model: Optional[nn.Module] = None

    def _prepare_model(self, model: nn.Module) -> nn.Module:
        """Wrap ``model`` in FSDP, or return it untouched when world size is 1."""
        if not self.use_distributed:
            logger.warning("FSDP backend selected but world_size=1, model not wrapped")
            return model
        self._original_model = model
        # NOTE(review): get_rank() is used as the CUDA device index; confirm
        # it is the node-local rank when running on more than one node.
        target_device = torch.device("cuda", get_rank())
        sharded = FSDP(model, device_id=target_device, **self._fsdp_kwargs)
        logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
        return sharded

    def _no_sync(self, model: nn.Module):
        """Return ``model.no_sync()`` for FSDP models, otherwise a null context."""
        return model.no_sync() if isinstance(model, FSDP) else contextlib.nullcontext()

    def unwrap_model(self, model: nn.Module):
        """Return a full state dict, offloaded to CPU and gathered on every rank."""
        if not (isinstance(model, FSDP) and self.use_distributed):
            return model.state_dict()

        full_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_config):
            return model.state_dict()
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
from astrai.preprocessing.builder import (
|
|
||||||
BaseMaskBuilder,
|
|
||||||
MaskBuilderFactory,
|
|
||||||
SectionedMaskBuilder,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
|
||||||
|
|
||||||
# Public API of the preprocessing package: mask builders and the pipeline.
__all__ = [
    "BaseMaskBuilder",
    "MaskBuilderFactory",
    "SectionedMaskBuilder",
    "Pipeline",
    "filter_by_length",
]
|
|
||||||
|
|
@ -1,159 +0,0 @@
|
||||||
"""Mask building strategies for preprocessing pipeline.
|
|
||||||
|
|
||||||
The single :class:`SectionedMaskBuilder` handles all input formats
|
|
||||||
via declarative ``input.sections`` config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMaskBuilder(ABC):
    """Interface for turning one JSONL record into token ids plus an optional loss mask."""

    @abstractmethod
    def build(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Produce the tokenised form of ``item``.

        Args:
            item: One decoded JSONL record.
            config: Pipeline configuration consulted by the builder.
            tokenizer: Tokenizer used to encode text.

        Returns:
            A dict describing the sample (token ids, optional ``loss_mask``
            and domain), or ``None`` when the record should be skipped.
        """
|
|
||||||
|
|
||||||
|
|
||||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
    """Registry of mask-builder classes; only ``BaseMaskBuilder`` subclasses are accepted."""

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Raise ``TypeError`` unless ``component_cls`` derives from ``BaseMaskBuilder``."""
        if issubclass(component_cls, BaseMaskBuilder):
            return
        raise TypeError(
            f"{component_cls.__name__} must inherit from BaseMaskBuilder"
        )
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
|
||||||
if not domain_key:
|
|
||||||
return "__default__"
|
|
||||||
val = item.get(domain_key, "__default__")
|
|
||||||
return val if isinstance(val, str) else "__default__"
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_action(action: str, role: str, config) -> str:
|
|
||||||
"""Resolve action to "train" or "mask".
|
|
||||||
|
|
||||||
- ``"train"`` / ``"mask"`` → literal
|
|
||||||
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
|
||||||
"""
|
|
||||||
if action == "$role":
|
|
||||||
return config.mask.get(role, config.mask_default)
|
|
||||||
return action
|
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
    """Config-driven builder: iterates over ``input.sections`` in order.

    Each section specifies a JSONL field + mask action.

    Section spec::

        {
            "field": "messages",        # JSONL key
            "action": "$role",          # "train" | "mask" | "$role"
            "template": true,           # apply chat_template per message (optional)
            "add_special_tokens": false # override encode flag (optional)
        }

    Example configs::

        # Chat
        {"input": {"sections": [
            {"field": "messages", "action": "$role", "template": true}
        ]}}

        # Instruction
        {"input": {"sections": [
            {"field": "prompt", "action": "mask", "add_special_tokens": true},
            {"field": "response", "action": "train"}
        ]}}

        # Text
        {"input": {"sections": [
            {"field": "text", "action": "train"}
        ]}}
    """

    def build(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Tokenise ``item`` section by section and build its loss mask.

        Args:
            item: One decoded JSONL record.
            config: Pipeline config; reads ``input.sections``,
                ``preprocessing`` (``min_chars``, ``max_chars``,
                ``max_seq_len``), ``output.domain_key`` and, through
                ``_resolve_action``, ``mask`` / ``mask_default``.
            tokenizer: Provides ``bos_token_id``, ``encode`` and
                ``apply_chat_template``.

        Returns:
            ``{"sequence": ids, "domain": str}`` plus ``"loss_mask"`` when at
            least one token is masked out, or ``None`` when nothing usable
            was produced.
        """
        sections = config.input.sections
        if not sections:
            return None

        all_ids: list[int] = []
        loss_mask: list[int] = []

        has_template = any(s.get("template") for s in sections)
        # Character-length filtering only applies to pure-text configs.
        is_text_config = not has_template and all(
            s["action"] == "train" for s in sections
        )

        # Templated samples start with an untrained BOS token.
        if has_template and tokenizer.bos_token_id is not None:
            all_ids.append(tokenizer.bos_token_id)
            loss_mask.append(0)

        first_section = True
        for sec in sections:
            field = sec["field"]
            action = sec["action"]
            use_template = sec.get("template", False)
            # By default only the first processed plain-text section lets the
            # tokenizer add its special tokens.
            add_special = sec.get(
                "add_special_tokens", not use_template and first_section
            )

            if use_template:
                messages = item.get(field)
                if not isinstance(messages, list) or not messages:
                    continue
                for msg in messages:
                    role = msg.get("role", "")
                    act = _resolve_action(action, role, config)
                    # Each message is rendered alone so its tokens can be
                    # masked independently of its neighbours.
                    rendered = tokenizer.apply_chat_template(
                        [msg], tokenize=False, add_generation_prompt=False
                    )
                    ids = tokenizer.encode(rendered, add_special_tokens=False)
                    all_ids.extend(ids)
                    val = 1 if act == "train" else 0
                    loss_mask.extend([val] * len(ids))
            else:
                raw = item.get(field)
                # A JSON null must count as missing: str(None) would inject
                # the literal text "None" into the training sample.
                text = "" if raw is None else str(raw)
                if not text.strip():
                    continue
                if is_text_config:
                    pp = config.preprocessing
                    if pp.min_chars > 0 and len(text) < pp.min_chars:
                        continue
                    if len(text) > pp.max_chars:
                        continue
                ids = tokenizer.encode(text, add_special_tokens=add_special)
                all_ids.extend(ids)
                val = 1 if action == "train" else 0
                loss_mask.extend([val] * len(ids))

            first_section = False

        # Truncate ids and mask to the same length.
        max_len = config.preprocessing.max_seq_len
        all_ids = all_ids[:max_len]
        loss_mask = loss_mask[: len(all_ids)]

        if not all_ids:
            return None

        # A templated sample holding at most one token carries no content.
        if has_template and len(all_ids) <= 1:
            return None

        result: dict = {
            "sequence": all_ids,
            "domain": _extract_domain(item, config.output.domain_key),
        }
        # The mask is omitted when every token is trained on.
        if not all(m == 1 for m in loss_mask):
            result["loss_mask"] = loss_mask
        return result
|
|
||||||
|
|
@ -1,141 +0,0 @@
|
||||||
"""Config-driven JSONL preprocessing pipeline.
|
|
||||||
|
|
||||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
|
||||||
sharding and flush to ``.h5`` / ``.bin`` storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from collections import defaultdict
|
|
||||||
from itertools import chain
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
|
||||||
from astrai.dataset.storage import save_bin, save_h5
|
|
||||||
from astrai.preprocessing.builder import SectionedMaskBuilder
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
# Lookup table from dtype names (as plain strings) to torch dtypes.
_STR_TO_DTYPE: dict[str, torch.dtype] = {
    "bool": torch.bool,
    "uint8": torch.uint8,
    "int8": torch.int8,
    "int16": torch.int16,
    "int32": torch.int32,
    "int64": torch.int64,
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
}
|
|
||||||
|
|
||||||
|
|
||||||
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
    """Return True when ``len(text)`` lies within ``[min_len, max_len]`` inclusive."""
    size = len(text)
    return not (size < min_len or size > max_len)
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
    """Tokenization pipeline driven by a declarative :class:`PipelineConfig`.

    Reads JSONL records, turns each one into token ids (plus optional
    per-token side channels such as ``loss_mask``) through the mask builder,
    groups them per domain and writes shards once the accumulated token count
    reaches ``config.output.max_tokens_per_shard``.

    Usage::

        config = PipelineConfig.from_json("sft_pipeline.json")
        Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
    """

    def __init__(
        self,
        config: "PipelineConfig",
        input_paths: list[str],
        output_dir: str,
        tokenizer_path: str,
    ):
        """Store the configuration and create ``output_dir`` if it is missing.

        The tokenizer is loaded in :meth:`run`, not here.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.config = config
        self.paths = input_paths
        self.output_dir = output_dir
        self.tokenizer_path = tokenizer_path

        self.mask_builder = SectionedMaskBuilder()

    def transform(self, item: dict) -> Optional[dict]:
        """Convert one raw record into a result dict, or ``None`` to drop it.

        Uses ``self._tokenizer``, which is only set once :meth:`run` has started.
        """
        return self.mask_builder.build(item, self.config, self._tokenizer)

    @staticmethod
    def _append_record(bucket, record: dict) -> None:
        """Append one record to a per-domain bucket, keeping all keys aligned.

        ``bucket`` maps each key to a list with one entry per record. Optional
        keys may be absent from a record, so missing values are filled with
        all-ones lists of the matching sequence length. That covers both a key
        missing from *this* record and a key appearing for the first time after
        earlier records were stored; without the latter, the flattened side
        channel would be shorter than (and misaligned with) ``sequence``.
        """
        n_tokens = len(record["sequence"])
        earlier_sequences = bucket.get("sequence", [])

        # Backfill keys seen for the first time so they line up with the
        # records already stored in this bucket.
        for key in record:
            if key not in bucket and earlier_sequences:
                bucket[key] = [[1] * len(seq) for seq in earlier_sequences]

        # Fill keys this record does not provide.
        for key in list(bucket.keys()):
            if key not in record:
                bucket[key].append([1] * n_tokens)

        for key, val in record.items():
            bucket.setdefault(key, []).append(val)

    def run(self):
        """Tokenize every input record and write the resulting shards."""
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
        domains: dict = defaultdict(lambda: defaultdict(list))
        total_tokens = 0
        shard_idx: dict[str, int] = defaultdict(int)
        count = 0

        pp = self.config.preprocessing

        for item in tqdm.tqdm(
            self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
        ):
            # A falsy max_items (None / 0) means "no limit".
            if pp.max_items and count >= pp.max_items:
                break

            result = self.transform(item)
            if result is None:
                continue

            ids = result.pop("sequence")
            if not ids:
                continue

            domain = result.pop("domain", "__default__")
            result["sequence"] = ids

            self._append_record(domains[domain], result)

            count += 1
            total_tokens += len(ids)

            # The token budget is shared by all domains; reaching it flushes
            # every domain at once.
            if total_tokens >= self.config.output.max_tokens_per_shard:
                self._flush(domains, shard_idx)
                domains.clear()
                total_tokens = 0

        if total_tokens > 0:
            self._flush(domains, shard_idx)

        print(f"Done. {count} documents tokenized.")

    def _iter_items(self):
        """Yield one parsed JSON object per non-blank line of every input file."""
        for path in self.paths:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    yield json.loads(line)

    def _flush(self, domains, shard_idx):
        """Write one shard per domain and advance that domain's shard counter."""
        for domain, keys in domains.items():
            idx = shard_idx[domain]
            tensors = {}
            for key, ids_list in keys.items():
                # Unknown or unspecified dtype names fall back to int32.
                dt = _STR_TO_DTYPE.get(
                    self.config.output.dtype.get(key, "int32"), torch.int32
                )
                # Each key becomes a single flat tensor wrapped in a list.
                tensors[key] = [
                    torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
                ]
            chunk_dir = os.path.join(self.output_dir, domain)
            fmt = self.config.output.storage_format
            if fmt == "bin":
                save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
            else:
                save_h5(chunk_dir, f"data_{idx:04d}", tensors)
            shard_idx[domain] = idx + 1
            tqdm.tqdm.write(
                f" saved {domain}/shard_{idx:04d} "
                f"({tensors['sequence'][0].numel():,} tokens)"
            )
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
|
||||||
|
|
||||||
from typing import Any, Protocol, runtime_checkable
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
class OptimizerProtocol(Protocol):
    """Structural interface an optimizer (or optimizer wrapper) must expose.

    Being ``runtime_checkable``, ``isinstance`` only verifies that the members
    exist; signatures are not checked.
    """

    def step(self, closure=None):
        """Apply one optimization step, optionally re-evaluating ``closure``."""

    def zero_grad(self):
        """Reset the gradients of the managed parameters."""

    @property
    def param_groups(self) -> Any:
        """Parameter groups managed by the optimizer."""

    def state_dict(self) -> dict:
        """Return the optimizer state as a dict."""

    def load_state_dict(self, d: dict):
        """Restore the optimizer state from ``d``."""
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
class SchedulerProtocol(Protocol):
    """Structural interface a learning-rate scheduler (or wrapper) must expose.

    Being ``runtime_checkable``, ``isinstance`` only verifies that the members
    exist; signatures are not checked.
    """

    def step(self):
        """Advance the schedule by one step."""

    def state_dict(self) -> dict:
        """Return the scheduler state as a dict."""

    def load_state_dict(self, d: dict):
        """Restore the scheduler state from ``d``."""

    def get_last_lr(self):
        """Return the most recently computed learning rate(s)."""
|
|
||||||
|
|
@ -1,182 +0,0 @@
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Union
|
|
||||||
|
|
||||||
import safetensors.torch as st
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
from astrai.parallel.setup import get_rank
|
|
||||||
|
|
||||||
_META_FILE = "meta.json"
|
|
||||||
_CONFIG_FILE = "config.json"
|
|
||||||
_WEIGHTS_FILE = "model.safetensors"
|
|
||||||
|
|
||||||
|
|
||||||
def save_safetensors(state_dict: dict, path: Union[str, Path]):
    """Write ``state_dict`` to ``path`` in safetensors format."""
    destination = str(path)
    st.save_file(state_dict, destination)
|
|
||||||
|
|
||||||
|
|
||||||
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
    """Load a safetensors file, optionally reading on rank 0 only.

    Without ``broadcast`` (or when no process group is initialised) every
    caller reads the file itself. Otherwise rank 0 reads it and the resulting
    dict is sent to all ranks as a single broadcast object.
    """
    distributed_read = broadcast and dist.is_initialized()
    if not distributed_read:
        return st.load_file(str(path))

    # Only the source rank touches the file; the others contribute a placeholder.
    payload = [st.load_file(str(path)) if get_rank() == 0 else {}]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
|
|
||||||
|
|
||||||
|
|
||||||
def save_json(data: dict, path: Union[str, Path]):
    """Serialize ``data`` as JSON with two-space indentation into ``path``."""
    target = str(path)
    with open(target, "w") as handle:
        json.dump(data, handle, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
    """Load a JSON file, optionally reading on rank 0 only.

    Args:
        path: Location of the JSON file.
        broadcast: When True and a process group is initialised, only rank 0
            reads the file and the parsed object is broadcast to every rank.

    Returns:
        The parsed JSON content.
    """

    def _read() -> dict:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(str(path), "r", encoding="utf-8") as f:
            return json.load(f)

    if not broadcast or not dist.is_initialized():
        return _read()

    # Non-source ranks pass a placeholder that the broadcast overwrites.
    payload = [_read() if get_rank() == 0 else {}]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
|
|
||||||
|
|
||||||
|
|
||||||
def save_torch(obj: Any, path: Union[str, Path]):
    """Serialize ``obj`` to ``path`` with ``torch.save``."""
    destination = str(path)
    torch.save(obj, destination)
|
|
||||||
|
|
||||||
|
|
||||||
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
    """Load an object written by ``torch.save``, optionally via rank 0 only.

    Without ``broadcast`` (or when no process group is initialised) the file is
    read directly. Otherwise rank 0 reads the raw bytes, broadcasts first their
    length and then the bytes as a uint8 tensor, and every rank deserialises
    its own copy.

    NOTE(review): ``weights_only=False`` unpickles arbitrary objects; only use
    this on checkpoints from a trusted source.
    """
    if not (broadcast and dist.is_initialized()):
        return torch.load(str(path), map_location="cpu", weights_only=False)

    path = Path(path)
    is_source = get_rank() == 0

    if is_source:
        with open(path, "rb") as handle:
            raw = handle.read()
        # frombuffer needs a writable buffer, hence the bytearray copy.
        data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
        num_bytes = torch.tensor([len(raw)], dtype=torch.long)
    else:
        num_bytes = torch.tensor([0], dtype=torch.long)

    # Receivers must know the payload size before they can allocate a buffer.
    dist.broadcast(num_bytes, src=0)

    if not is_source:
        data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)

    dist.broadcast(data_tensor, src=0)

    payload = data_tensor.numpy().tobytes()
    return torch.load(io.BytesIO(payload), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
|
|
||||||
def save_model(config: dict, state_dict: dict, save_directory: str):
    """Write the model config and weights into ``save_directory``.

    The directory (including parents) is created when missing.
    """
    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    config_path = target_dir / _CONFIG_FILE
    weights_path = target_dir / _WEIGHTS_FILE
    save_json(config, config_path)
    save_safetensors(state_dict, weights_path)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(save_directory: str) -> dict:
    """Read the model config JSON stored in ``save_directory``."""
    config_path = Path(save_directory) / _CONFIG_FILE
    return load_json(config_path)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(save_directory: str) -> dict:
    """Read the model weights stored in ``save_directory`` (no broadcast)."""
    weights_path = Path(save_directory) / _WEIGHTS_FILE
    return load_state_dict(weights_path)
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
    """Load a safetensors state dict, optionally broadcasting tensor by tensor.

    Without ``broadcast`` (or when no process group is initialised) the file is
    read directly. Otherwise rank 0 reads it, broadcasts a list of
    ``(key, shape, dtype name)`` specs, and then every tensor is broadcast
    individually in sorted key order into buffers allocated by the receivers.
    """
    path = Path(path)
    if not (broadcast and dist.is_initialized()):
        return load_safetensors(path)

    is_source = get_rank() == 0
    state_dict = load_safetensors(path) if is_source else {}

    # Receivers start with no keys, so they contribute an empty spec list.
    specs = [
        (name, list(state_dict[name].shape), str(state_dict[name].dtype).split(".")[-1])
        for name in sorted(state_dict)
    ]

    holder = [specs]
    dist.broadcast_object_list(holder, src=0)
    specs = holder[0]

    for key, shape, dtype_name in specs:
        dtype = getattr(torch, dtype_name)
        if is_source:
            tensor = state_dict[key].contiguous().cpu()
        else:
            tensor = torch.empty(shape, dtype=dtype, device="cpu")
        dist.broadcast(tensor, src=0)
        if not is_source:
            state_dict[key] = tensor
    return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class Checkpoint:
    """Training checkpoint: weights, progress counters, metadata and extras.

    On disk a checkpoint is a directory holding ``meta.json``, ``config.json``,
    ``model.safetensors`` and one ``<key>.pt`` file per entry of ``extra``.
    """

    # Model weights, written with safetensors.
    state_dict: Dict[str, Any] = field(default_factory=dict)
    # Progress counters, stored in the meta file.
    epoch: int = 0
    iteration: int = 0
    # Arbitrary objects, each written with torch.save as "<key>.pt".
    extra: Dict[str, Any] = field(default_factory=dict)
    # Additional user metadata merged into the meta file.
    meta: Dict[str, Any] = field(default_factory=dict)
    # Model configuration, written as JSON.
    config: Dict[str, Any] = field(default_factory=dict)

    def save(self, save_dir: str):
        """Write the checkpoint to ``save_dir``; only rank 0 writes files.

        Every rank creates the directory, then all ranks but 0 return.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        if get_rank() != 0:
            return

        # User metadata is merged last, so it overrides the generated fields.
        meta = {
            "epoch": self.epoch,
            "iteration": self.iteration,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            **self.meta,
        }
        save_json(meta, save_path / _META_FILE)
        save_json(self.config, save_path / _CONFIG_FILE)
        save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
        for key, value in self.extra.items():
            save_torch(value, save_path / f"{key}.pt")

    @classmethod
    def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
        """Read a checkpoint from ``save_dir``.

        Args:
            save_dir: Directory previously written by :meth:`save`.
            broadcast: When True and a process group is initialised, only
                rank 0 touches the filesystem and everything is broadcast.

        Returns:
            The restored checkpoint. User metadata from the meta file is
            restored into ``meta``; the generated ``epoch`` / ``iteration`` /
            ``timestamp`` fields are excluded so a later :meth:`save` writes a
            fresh timestamp.
        """
        save_path = Path(save_dir)

        meta = load_json(save_path / _META_FILE, broadcast)
        config = load_json(save_path / _CONFIG_FILE, broadcast)
        state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)

        if broadcast and dist.is_initialized():
            # All ranks must agree on which files to load, otherwise the
            # collective calls in load_torch diverge. Rank 0 lists the
            # directory and shares the names.
            if get_rank() == 0:
                names = [sorted(f.name for f in save_path.iterdir() if f.suffix == ".pt")]
            else:
                names = [None]
            dist.broadcast_object_list(names, src=0)
            extra_files = [save_path / name for name in names[0]]
        else:
            extra_files = sorted(f for f in save_path.iterdir() if f.suffix == ".pt")

        extra = {f.stem: load_torch(f, broadcast=broadcast) for f in extra_files}

        generated = ("epoch", "iteration", "timestamp")
        user_meta = {k: v for k, v in meta.items() if k not in generated}

        return cls(
            state_dict=state_dict,
            epoch=meta.get("epoch", 0),
            iteration=meta.get("iteration", 0),
            extra=extra,
            meta=user_meta,
            config=config,
        )
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
from astrai.tokenize.chat_template import ChatTemplate, MessageType
|
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
|
||||||
|
|
||||||
# Names re-exported as the public API of this package.
__all__ = [
    "AutoTokenizer",
    "ChatTemplate",
    "MessageType",
]
|
|
||||||
|
|
@ -1,74 +0,0 @@
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from jinja2 import Template
|
|
||||||
|
|
||||||
# A single chat message represented as a plain dict.
type MessageType = Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate:
    """Chat prompt template rendered with Jinja2.

    Attributes:
        name: Identifier for the template.
        template_str: Jinja2 template source.
        description: Optional free-text description.
        default_variables: Variables supplied to every render unless overridden.
        special_tokens: Mapping of token names to their string values, also
            exposed to the template as variables.
    """

    def __init__(
        self,
        name: str = "",
        template_str: str = "",
        description: str = "",
        default_variables: Optional[Dict[str, Any]] = None,
        special_tokens: Optional[Dict[str, str]] = None,
    ):
        self.name = name
        self.template_str = template_str
        self.description = description
        # Falsy inputs (None or an empty dict) are replaced by a fresh dict.
        self.default_variables = default_variables if default_variables else {}
        self.special_tokens = special_tokens if special_tokens else {}
        # Compile once; render() reuses the compiled template.
        self._compiled: Template = Template(template_str)

    @classmethod
    def from_string(
        cls,
        template_str: str,
        description: str = "",
        default_variables: Optional[Dict[str, Any]] = None,
        special_tokens: Optional[Dict[str, str]] = None,
    ) -> "ChatTemplate":
        """Build an unnamed ChatTemplate straight from a template string."""
        init_kwargs = {
            "name": "",
            "template_str": template_str,
            "description": description,
            "default_variables": default_variables,
            "special_tokens": special_tokens,
        }
        return cls(**init_kwargs)

    def render(
        self,
        messages: List[MessageType],
        system_prompt: Optional[str] = None,
        **extra_variables: Any,
    ) -> str:
        """Render the template for the given messages.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            system_prompt: Optional system prompt; passed to the template as
                ``system_prompt`` only when it is not None.
            **extra_variables: Additional template variables. They take
                precedence over ``special_tokens``, which in turn take
                precedence over ``default_variables``.

        Returns:
            Rendered prompt string.
        """
        # Later updates win, which yields the precedence described above.
        context: Dict[str, Any] = {}
        context.update(self.default_variables)
        context.update(self.special_tokens)
        context.update(extra_variables)

        context["messages"] = messages
        if system_prompt is not None:
            context["system_prompt"] = system_prompt

        return self._compiled.render(**context)
|
|
||||||
|
|
@ -1,264 +0,0 @@
|
||||||
"""
|
|
||||||
Tokenizer module with implementation and auto-loading support.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from tokenizers import Tokenizer
|
|
||||||
|
|
||||||
from astrai.tokenize.chat_template import ChatTemplate
|
|
||||||
|
|
||||||
|
|
||||||
class AutoTokenizer:
    """Tokenizer wrapper with directory loading/saving, special-token
    attribute access and chat-template support."""

    TOKENIZER_CLASSES = {}  # Registry for auto-loading

    def __init__(
        self,
        path: Optional[Union[str, Path]] = None,
        special_token_map: Optional[Dict[str, str]] = None,
        chat_template: Optional[str] = None,
    ):
        """Create a tokenizer, optionally loading it from ``path``.

        Args:
            path: Directory containing ``tokenizer.json`` (and optionally
                ``tokenizer_config.json``). Nothing is loaded when falsy.
            special_token_map: Mapping such as ``{"bos_token": "<s>"}``.
            chat_template: Jinja2 template string.
        """
        self._tokenizer: Optional[Tokenizer] = None
        self._chat_template: Optional[ChatTemplate] = None
        # Copy the mapping: load() updates it in place and must not modify
        # the dict owned by the caller.
        self._special_token_map: Dict[str, str] = dict(special_token_map or {})

        if chat_template:
            self.set_chat_template(chat_template)

        if path:
            self.load(path)

    def load(self, path: Union[str, Path]):
        """Load tokenizer from directory.

        Values found in ``tokenizer_config.json`` are merged over the current
        special tokens and replace the current chat template.
        """
        path = Path(path)
        tokenizer_file = path / "tokenizer.json"
        config_file = path / "tokenizer_config.json"
        self._tokenizer = Tokenizer.from_file(str(tokenizer_file))

        if config_file.exists():
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)

            if "special_tokens" in config:
                self._special_token_map.update(config["special_tokens"])

            # Load chat template from config
            if "chat_template" in config:
                self.set_chat_template(config["chat_template"])

    @classmethod
    def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
        """Load tokenizer from pretrained directory.

        Raises:
            FileNotFoundError: If tokenizer.json is missing.
            RuntimeError: If tokenizer failed to initialize.
        """
        path = Path(path)
        tokenizer_file = path / "tokenizer.json"
        if not tokenizer_file.exists():
            raise FileNotFoundError(
                f"Tokenizer file not found: {tokenizer_file}. "
                "A valid tokenizer.json is required."
            )
        instance = cls(path)
        if instance._tokenizer is None:
            raise RuntimeError(
                f"Failed to load tokenizer from {path}. "
                "The tokenizer.json may be corrupted or incompatible."
            )
        return instance

    def save_pretrained(self, save_path: str):
        """
        Save tokenizer to pretrained directory.

        Args:
            save_path: Path to save the tokenizer

        Raises:
            RuntimeError: If no tokenizer has been loaded.
        """

        if self._tokenizer is None:
            raise RuntimeError(
                "Tokenizer not initialized. Load or create a tokenizer first."
            )

        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save tokenizer
        self._tokenizer.save(str(save_path / "tokenizer.json"))

        # Save tokenizer config
        config = {}
        if self._special_token_map is not None:
            config["special_tokens"] = self._special_token_map
        if self._chat_template is not None:
            config["chat_template"] = self._chat_template.template_str

        with open(save_path / "tokenizer_config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

    @classmethod
    def register_tokenizer(cls, name: str, tokenizer_class: type):
        """
        Register a new tokenizer class.

        Args:
            name: Name to register the tokenizer class under
            tokenizer_class: The tokenizer class to register
        """
        cls.TOKENIZER_CLASSES[name] = tokenizer_class

    def encode(
        self,
        tokens: Union[str, List[str]],
        out_ids: bool = True,
        is_pretokenized: bool = False,
        add_special_tokens: bool = True,
    ) -> List:
        """Encode text to tokens or token IDs.

        A single string yields one flat list; any other input is encoded as a
        batch and yields one list per element.

        Raises:
            RuntimeError: If no tokenizer has been loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError(
                "Tokenizer not initialized. Load or create a tokenizer first."
            )

        if isinstance(tokens, str):
            encoded = self._tokenizer.encode(
                tokens,
                is_pretokenized=is_pretokenized,
                add_special_tokens=add_special_tokens,
            )
            return encoded.ids if out_ids else encoded.tokens
        else:
            encoded_list = self._tokenizer.encode_batch(
                tokens,
                is_pretokenized=is_pretokenized,
                add_special_tokens=add_special_tokens,
            )
            return [
                encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
            ]

    def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs to text.

        Raises:
            RuntimeError: If no tokenizer has been loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError(
                "Tokenizer not initialized. Load or create a tokenizer first."
            )

        return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def __len__(self) -> int:
        """Return the vocabulary size, or 0 when no tokenizer is loaded."""
        if self._tokenizer is None:
            return 0
        return self._tokenizer.get_vocab_size()

    def __getattr__(self, key: str):
        """
        Dynamically intercept special token attribute access.
        Supports three forms:
        - tokenizer.bos_token → returns string
        - tokenizer.bos_token_id → returns corresponding integer ID
          (None when the base name is not a known special token)
        - tokenizer.stop_ids → returns list of corresponding integer IDs for all special tokens
        """
        # __getattr__ only runs after normal lookup failed. The branches below
        # read self._special_token_map / self._tokenizer, so if those are not
        # set yet (bare instance, copy, unpickling) resolving them here would
        # recurse without end. Private names are never special tokens.
        if key.startswith("_"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{key}'"
            )

        # Handle stop_ids - return IDs for all special tokens
        if key == "stop_ids":
            stop_ids = []

            if self._tokenizer is None:
                return stop_ids

            for val in self._special_token_map.values():
                token_id = self._tokenizer.token_to_id(val)
                if token_id is not None:
                    stop_ids.append(token_id)

            return stop_ids

        # Handle _id suffix (e.g., bos_token_id -> bos_token)
        if key.endswith("_id"):
            base_attr = key[:-3]  # Remove "_id"
            token_str = self._special_token_map.get(base_attr)
            if token_str is None:
                return None
            if self._tokenizer is None:
                raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
            return self._tokenizer.token_to_id(token_str)

        # Handle regular string attributes
        if key in self._special_token_map:
            return self._special_token_map.get(key)

        # Other attributes trigger default AttributeError
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")

    @property
    def vocab_size(self) -> int:
        """Vocabulary size; same value as ``len(self)``."""
        return len(self)

    def set_chat_template(self, template: Union[str, "ChatTemplate"]):
        """
        Set the chat template for the tokenizer.

        Args:
            template: Either a Jinja2 template string, which is compiled into
                a new ChatTemplate, or an existing ChatTemplate instance.

        Raises:
            ValueError: If template is neither a str nor a ChatTemplate.
        """
        if isinstance(template, str):
            self._chat_template = ChatTemplate.from_string(template)
        elif isinstance(template, ChatTemplate):
            self._chat_template = template
        else:
            raise ValueError("Invalid template type, must be str or ChatTemplate.")

    def apply_chat_template(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        tokenize: bool = True,
        add_generation_prompt: bool = True,
        **kwargs,
    ) -> Union[str, List[int]]:
        """
        Apply the chat template to messages and optionally tokenize the result.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            system_prompt: Optional system prompt string (auto-converted to first message).
            tokenize: Whether to return token IDs (True) or raw string (False).
            add_generation_prompt: Whether to add the generation prompt (default: True).
            **kwargs: Additional variables to pass to the template.

        Returns:
            Either the rendered string or list of token IDs.

        Raises:
            RuntimeError: If chat template is not set.
        """
        if self._chat_template is None:
            raise RuntimeError(
                "Chat template not set. Use set_chat_template() to set a template first."
            )

        # Auto-convert system_prompt to first message if provided
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + list(messages)

        # Render the template
        rendered = self._chat_template.render(
            messages=messages,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )

        if tokenize:
            return self.encode(rendered)

        return rendered
|
|
||||||
|
|
@ -1,24 +0,0 @@
|
||||||
from astrai.trainer.optim import Muon
|
|
||||||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
|
||||||
from astrai.trainer.train_callback import (
|
|
||||||
CallbackFactory,
|
|
||||||
TrainCallback,
|
|
||||||
)
|
|
||||||
from astrai.trainer.trainer import Trainer
|
|
||||||
|
|
||||||
# Names re-exported as the public API of this package, grouped by component.
__all__ = [
    # Main trainer
    "Trainer",
    # Optimizer
    "Muon",
    # Strategy factory
    "StrategyFactory",
    "BaseStrategy",
    # Scheduler factory
    "SchedulerFactory",
    "BaseScheduler",
    # Callback factory
    "TrainCallback",
    "CallbackFactory",
]
|
|
||||||
|
|
@ -1,75 +0,0 @@
|
||||||
from typing import Any, Callable, Dict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def _grad_stat(
|
|
||||||
model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any
|
|
||||||
) -> dict:
|
|
||||||
results = {}
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
results[name] = default
|
|
||||||
if param.grad is not None:
|
|
||||||
results[name] = fn(param.grad.data)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
|
||||||
return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
def grad_std(model: nn.Module) -> Dict[str, float]:
|
|
||||||
return _grad_stat(model, lambda g: g.std().item(), 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
def grad_max(model: nn.Module) -> Dict[str, float]:
|
|
||||||
return _grad_stat(model, lambda g: g.max().item(), -float("inf"))
|
|
||||||
|
|
||||||
|
|
||||||
def grad_min(model: nn.Module) -> Dict[str, float]:
|
|
||||||
return _grad_stat(model, lambda g: g.min().item(), float("inf"))
|
|
||||||
|
|
||||||
|
|
||||||
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
|
||||||
return _grad_stat(model, lambda g: g.mean().item(), 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
|
||||||
return _grad_stat(model, lambda g: g.isnan().sum().item(), 0)
|
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_loss(ctx):
    """Return the ``loss`` attribute of the training context."""
    loss = ctx.loss
    return loss


def ctx_get_lr(ctx):
    """Return the learning rate of the optimizer's last parameter group."""
    last_group = ctx.optimizer.param_groups[-1]
    return last_group["lr"]


def ctx_get_val_loss(ctx):
    """Return the ``val_loss`` attribute of the training context."""
    val_loss = ctx.val_loss
    return val_loss


def ctx_get_grad_norm(ctx):
    """Return per-parameter gradient norms for ``ctx.model``."""
    model = ctx.model
    return grad_norm(model)


def ctx_get_grad_std(ctx):
    """Return per-parameter gradient standard deviations for ``ctx.model``."""
    model = ctx.model
    return grad_std(model)


def ctx_get_grad_max(ctx):
    """Return per-parameter gradient maxima for ``ctx.model``."""
    model = ctx.model
    return grad_max(model)


def ctx_get_grad_min(ctx):
    """Return per-parameter gradient minima for ``ctx.model``."""
    model = ctx.model
    return grad_min(model)


def ctx_get_grad_mean(ctx):
    """Return per-parameter gradient means for ``ctx.model``."""
    model = ctx.model
    return grad_mean(model)


def ctx_get_grad_nan_num(ctx):
    """Return per-parameter NaN gradient counts for ``ctx.model``."""
    model = ctx.model
    return grad_nan_num(model)
|
|
||||||
|
|
@ -1,143 +0,0 @@
|
||||||
import torch
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
|
||||||
assert G.ndim == 2
|
|
||||||
X = G
|
|
||||||
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
|
||||||
X = X / (X.norm() + 1e-7) * scale
|
|
||||||
if steps == 0:
|
|
||||||
return X
|
|
||||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
|
||||||
for _ in range(steps):
|
|
||||||
A = X @ X.T
|
|
||||||
B = A @ X
|
|
||||||
X = a * X + b * B + c * (A @ B)
|
|
||||||
return X
|
|
||||||
|
|
||||||
|
|
||||||
class Muon(Optimizer):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params,
|
|
||||||
lr: float = 2e-3,
|
|
||||||
momentum: float = 0.95,
|
|
||||||
weight_decay: float = 0.0,
|
|
||||||
nesterov: bool = True,
|
|
||||||
ns_steps: int = 5,
|
|
||||||
adamw_lr: float = None,
|
|
||||||
adamw_betas: tuple = (0.9, 0.95),
|
|
||||||
adamw_eps: float = 1e-8,
|
|
||||||
adamw_wd: float = 0.0,
|
|
||||||
):
|
|
||||||
defaults = dict(
|
|
||||||
lr=lr,
|
|
||||||
momentum=momentum,
|
|
||||||
weight_decay=weight_decay,
|
|
||||||
nesterov=nesterov,
|
|
||||||
ns_steps=ns_steps,
|
|
||||||
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
|
|
||||||
adamw_betas=adamw_betas,
|
|
||||||
adamw_eps=adamw_eps,
|
|
||||||
adamw_wd=adamw_wd,
|
|
||||||
)
|
|
||||||
super().__init__(params, defaults)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def step(self, closure=None):
|
|
||||||
loss = None
|
|
||||||
if closure is not None:
|
|
||||||
with torch.enable_grad():
|
|
||||||
loss = closure()
|
|
||||||
|
|
||||||
for group in self.param_groups:
|
|
||||||
params_2d, params_1d = [], []
|
|
||||||
grads_2d, grads_1d = [], []
|
|
||||||
|
|
||||||
for p in group["params"]:
|
|
||||||
if p.grad is None:
|
|
||||||
continue
|
|
||||||
if p.grad.is_sparse:
|
|
||||||
raise RuntimeError("Muon does not support sparse gradients")
|
|
||||||
if p.ndim >= 2:
|
|
||||||
params_2d.append(p)
|
|
||||||
grads_2d.append(p.grad)
|
|
||||||
else:
|
|
||||||
params_1d.append(p)
|
|
||||||
grads_1d.append(p.grad)
|
|
||||||
|
|
||||||
if params_2d:
|
|
||||||
self._muon_update_foreach(params_2d, grads_2d, group)
|
|
||||||
if params_1d:
|
|
||||||
self._adamw_update_foreach(params_1d, grads_1d, group)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def _muon_update_foreach(self, params_2d, grads_2d, group):
|
|
||||||
lr = group["lr"]
|
|
||||||
momentum = group["momentum"]
|
|
||||||
wd = group["weight_decay"]
|
|
||||||
nesterov = group["nesterov"]
|
|
||||||
ns_steps = group["ns_steps"]
|
|
||||||
|
|
||||||
if wd != 0:
|
|
||||||
torch._foreach_mul_(params_2d, 1 - lr * wd)
|
|
||||||
|
|
||||||
if nesterov:
|
|
||||||
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
|
|
||||||
|
|
||||||
bufs = []
|
|
||||||
for p, grad in zip(params_2d, grads_2d):
|
|
||||||
state = self.state[p]
|
|
||||||
if "momentum_buffer" not in state:
|
|
||||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
|
||||||
bufs.append(state["momentum_buffer"])
|
|
||||||
|
|
||||||
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
|
|
||||||
|
|
||||||
for p, buf in zip(params_2d, bufs):
|
|
||||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
|
||||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
|
||||||
p.add_(update, alpha=-lr * scale)
|
|
||||||
|
|
||||||
def _adamw_update_foreach(self, params_1d, grads_1d, group):
|
|
||||||
lr = group["adamw_lr"]
|
|
||||||
betas = group["adamw_betas"]
|
|
||||||
eps = group["adamw_eps"]
|
|
||||||
wd = group["adamw_wd"]
|
|
||||||
|
|
||||||
steps: list[int] = []
|
|
||||||
exp_avgs, exp_avg_sqs = [], []
|
|
||||||
has_state = []
|
|
||||||
for p in params_1d:
|
|
||||||
state = self.state[p]
|
|
||||||
if not state:
|
|
||||||
state["step"] = 0
|
|
||||||
state["exp_avg"] = torch.zeros_like(p)
|
|
||||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
|
||||||
has_state.append(False)
|
|
||||||
else:
|
|
||||||
has_state.append(True)
|
|
||||||
state["step"] += 1
|
|
||||||
steps.append(state["step"])
|
|
||||||
exp_avgs.append(state["exp_avg"])
|
|
||||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
|
||||||
|
|
||||||
beta1, beta2 = betas
|
|
||||||
|
|
||||||
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
|
|
||||||
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
|
|
||||||
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
|
|
||||||
|
|
||||||
bias_correction1 = [1 - beta1**s for s in steps]
|
|
||||||
bias_correction2 = [1 - beta2**s for s in steps]
|
|
||||||
|
|
||||||
if wd != 0:
|
|
||||||
torch._foreach_mul_(params_1d, 1 - lr * wd)
|
|
||||||
|
|
||||||
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
|
|
||||||
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
|
|
||||||
denom = torch._foreach_sqrt(denom)
|
|
||||||
torch._foreach_add_(denom, eps)
|
|
||||||
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
|
|
||||||
|
|
@ -1,194 +0,0 @@
|
||||||
"""Learning rate scheduler implementations with factory pattern."""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Dict, List, Type
|
|
||||||
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
class BaseScheduler(LRScheduler, ABC):
    """Common abstract parent for the schedulers defined in this module.

    Subclasses must implement :meth:`get_lr`; construction and state
    (de)serialisation are delegated to ``LRScheduler``.
    """

    def __init__(self, optimizer, last_epoch: int = -1):
        super().__init__(optimizer, last_epoch)

    @abstractmethod
    def get_lr(self) -> List[float]:
        """Return the learning rate for each param group at the current step."""
        raise NotImplementedError

    def state_dict(self) -> Dict[str, Any]:
        """Return the scheduler state as produced by the parent class."""
        snapshot = super().state_dict()
        return snapshot

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore scheduler state through the parent class."""
        super().load_state_dict(state_dict)
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerFactory(BaseFactory["BaseScheduler"]):
    """Registry-backed factory for learning rate schedulers.

    Scheduler classes register themselves under a string name with the
    ``register`` decorator and are instantiated later through ``create``.

    Example usage:
        @SchedulerFactory.register("custom")
        class CustomScheduler(BaseScheduler):
            ...

        scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
    """

    @classmethod
    def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
        """Reject any class that does not derive from ``BaseScheduler``."""
        if issubclass(scheduler_cls, BaseScheduler):
            return
        raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

    @classmethod
    def create(
        cls, optimizer, schedule_type: str = "none", **kwargs
    ) -> "BaseScheduler":
        """Instantiate the scheduler registered under ``schedule_type``.

        Args:
            optimizer: PyTorch optimizer handed to the scheduler constructor.
            schedule_type: Registered scheduler name; this module registers
                "cosine" and "sgdr".
                NOTE(review): the default "none" is not registered in this
                module -- confirm it is registered elsewhere.
            **kwargs: Extra arguments forwarded to the scheduler constructor.

        Returns:
            The constructed scheduler instance.
        """
        # The base factory takes the registered name first, then constructor args.
        return super().create(schedule_type, optimizer, **kwargs)

    @classmethod
    def available_types(cls) -> list:
        """Return the names of all registered scheduler types."""
        return cls.list_registered()
|
|
||||||
|
|
||||||
|
|
||||||
# ----------- Scheduler implementations -----------
|
|
||||||
|
|
||||||
|
|
||||||
@SchedulerFactory.register("cosine")
class CosineScheduler(BaseScheduler):
    """Linear warmup followed by a single cosine decay.

    During warmup the multiplier grows linearly with ``last_epoch`` but never
    drops below ``min_rate``. Afterwards it follows half a cosine period over
    ``lr_decay_steps`` steps, again floored at ``min_rate``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of warmup steps.
        lr_decay_steps: Number of steps over which the cosine decays.
        min_rate: Lower bound of the multiplier applied to each base LR.
        last_epoch: Index of the last step (``-1`` starts fresh).
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        lr_decay_steps: int,
        min_rate: float = 0.05,
        last_epoch: int = -1,
    ):
        # Attributes must exist before the parent constructor runs, because it
        # performs an initial step that calls get_lr().
        self.warmup_steps = warmup_steps
        self.lr_decay_steps = lr_decay_steps
        self.min_rate = min_rate
        self.total_steps = warmup_steps + lr_decay_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the scaled learning rate of every param group."""
        # warmup
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]

        # cosine decay
        if self.lr_decay_steps > 0:
            decay_progress = (self.last_epoch - self.warmup_steps) / self.lr_decay_steps
            decay_progress = min(decay_progress, 1.0)
        else:
            # No decay window configured: treat the decay as already finished
            # instead of dividing by zero.
            decay_progress = 1.0
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
        decay_factor = max(self.min_rate, cosine_decay)
        return [base_lr * decay_factor for base_lr in self.base_lrs]

    def state_dict(self):
        """Return the parent state plus this scheduler's hyper-parameters."""
        state = super().state_dict()
        state.update(
            {
                "warmup_steps": self.warmup_steps,
                "lr_decay_steps": self.lr_decay_steps,
                "min_rate": self.min_rate,
                "total_steps": self.total_steps,
            }
        )
        return state

    def load_state_dict(self, state_dict):
        """Restore hyper-parameters and parent state from ``state_dict``.

        Raises:
            KeyError: If one of the scheduler-specific keys is missing.
        """
        # Work on a shallow copy so the caller's dict is left intact and can be
        # loaded again (e.g. a checkpoint dict that is reused).
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop("warmup_steps")
        self.lr_decay_steps = state_dict.pop("lr_decay_steps")
        self.min_rate = state_dict.pop("min_rate")
        self.total_steps = state_dict.pop("total_steps")
        super().load_state_dict(state_dict)
|
|
||||||
|
|
||||||
|
|
||||||
@SchedulerFactory.register("sgdr")
class SGDRScheduler(BaseScheduler):
    """SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler.

    After a linear warmup (floored at ``min_rate``) the multiplier follows a
    cosine from 1 down to ``min_rate`` inside each cycle. The first cycle
    lasts ``cycle_length`` steps and every later cycle is ``t_mult`` times
    longer than the previous one.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of warmup steps.
        cycle_length: Length of the first cycle; must be positive.
        min_rate: Lower bound of the multiplier applied to each base LR.
        t_mult: Cycle length growth factor; must be at least 1.
        last_epoch: Index of the last step (``-1`` starts fresh).

    Raises:
        ValueError: If ``cycle_length`` or ``t_mult`` would keep the cycle
            search in ``get_lr`` from terminating.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        cycle_length: int,
        min_rate: float = 0.05,
        t_mult: int = 2,
        last_epoch: int = -1,
    ):
        # Cycles that do not advance (length <= 0) or that shrink (t_mult < 1)
        # make the while-loop in get_lr spin forever, so fail fast instead.
        if cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {cycle_length}")
        if t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {t_mult}")

        self.warmup_steps = warmup_steps
        self.cycle_length = cycle_length
        self.min_rate = min_rate
        self.t_mult = t_mult

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the scaled learning rate of every param group."""
        # warmup
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]

        # SGDR
        steps_since_warmup = self.last_epoch - self.warmup_steps

        # 1. Find the cycle containing the current step and its start offset.
        current_cycle_length = self.cycle_length
        total_cycles_length = 0
        while total_cycles_length + current_cycle_length <= steps_since_warmup:
            total_cycles_length += current_cycle_length
            current_cycle_length *= self.t_mult

        steps_in_cycle = steps_since_warmup - total_cycles_length

        # 2. Cosine annealing within the current cycle.
        cosine_factor = 0.5 * (
            1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
        )
        learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor

        return [base_lr * learning_rate_factor for base_lr in self.base_lrs]

    def state_dict(self):
        """Returns the state of the scheduler as a dict."""
        state = super().state_dict()
        state.update(
            {
                "warmup_steps": self.warmup_steps,
                "cycle_length": self.cycle_length,
                "min_rate": self.min_rate,
                "t_mult": self.t_mult,
            }
        )
        return state

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state without modifying ``state_dict``.

        Raises:
            KeyError: If one of the scheduler-specific keys is missing.
        """
        # Shallow copy: popping from the caller's dict would make a second
        # load of the same dict fail with KeyError.
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop("warmup_steps")
        self.cycle_length = state_dict.pop("cycle_length")
        self.min_rate = state_dict.pop("min_rate")
        self.t_mult = state_dict.pop("t_mult")
        super().load_state_dict(state_dict)
|
|
||||||
|
|
@ -1,334 +0,0 @@
|
||||||
"""Training strategy implementations with factory pattern."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Callable, Dict, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
    """Build a frozen, eval-mode copy of a model.

    Args:
        model_fn: Zero-argument callable returning a fresh model instance.
        state_dict: Full state dict loaded into that instance.

    Returns:
        The new model with gradients disabled and eval mode enabled.
    """
    frozen = model_fn()
    frozen.load_state_dict(state_dict)
    # Both calls return the module itself, so they can be chained.
    return frozen.requires_grad_(False).eval()
|
|
||||||
|
|
||||||
|
|
||||||
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
    """Return a new dict with every batch value moved to ``device``.

    Transfers are requested with ``non_blocking=True``; the input dict itself
    is not modified.
    """
    moved = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device, non_blocking=True)
    return moved
|
|
||||||
|
|
||||||
|
|
||||||
def get_logprobs(
    model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
    input_ids: Tensor,
    mask: Tensor,
    reduction: str,
):
    """Score each token of ``input_ids`` under ``model``.

    The model is fed every token but the last and its logits are compared
    against the sequence shifted left by one, so position ``t`` scores token
    ``t + 1``. Masked-out positions contribute zero.

    Args:
        model: Callable invoked as ``model(ids, mask)`` and returning a
            mapping with a ``"logits"`` entry.
        input_ids: Token IDs of shape [batch_size, seq_len].
        mask: Mask of shape [batch_size, seq_len].
        reduction: "mean" (masked average per sequence), "sum" (masked sum
            per sequence) or "none" (masked per-token values).

    Returns:
        Log probabilities reduced over the sequence dimension as requested.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    allowed_reductions = ["mean", "sum", "none"]
    if reduction not in allowed_reductions:
        raise ValueError(
            f"reduction must be one of {allowed_reductions}, got '{reduction}'"
        )

    # Targets are the inputs shifted by one position.
    targets = input_ids[:, 1:]
    target_mask = mask[:, 1:]

    outputs = model(input_ids[:, :-1], mask[:, :-1])
    log_probs = torch.log_softmax(outputs["logits"].float(), dim=-1)

    # Pick the log-probability assigned to each actual next token.
    picked = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    masked = picked * target_mask

    if reduction == "none":
        return masked

    per_sequence = masked.sum(dim=-1)
    if reduction == "sum":
        return per_sequence

    # "mean": clamp the token count so fully-masked rows do not divide by zero.
    return per_sequence / target_mask.sum(dim=-1).clamp(min=1.0)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
    """Abstract parent of all training strategies.

    A strategy wraps a model and a target device and turns a batch into a
    loss. The optional ``executor`` and ``model_fn`` keyword arguments are
    pulled out of ``kwargs``; whatever remains is kept in ``extra_kwargs``.
    """

    def __init__(
        self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
    ):
        self.model, self.device = model, device
        self.executor = kwargs.pop("executor", None)
        self.model_fn = kwargs.pop("model_fn", None)
        # Only the unrecognised options are left at this point.
        self.extra_kwargs = kwargs

    @abstractmethod
    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the loss for ``batch``.

        Args:
            batch: Dictionary of batch tensors.

        Returns:
            The computed loss tensor.
        """
        raise NotImplementedError

    def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
        """Shorthand for ``compute_loss(batch)``."""
        return self.compute_loss(batch)
|
|
||||||
|
|
||||||
|
|
||||||
class StrategyFactory(BaseFactory["BaseStrategy"]):
    """Registry-backed factory for training strategies.

    Strategy classes register themselves under a string name with the
    ``register`` decorator; this module registers "seq", "sft", "dpo" and
    "grpo" at class definition time.

    Example usage:
        @StrategyFactory.register("custom")
        class CustomStrategy(BaseStrategy):
            ...

        strategy = StrategyFactory.create("custom", model, device)
    """

    @classmethod
    def _validate_component(cls, strategy_cls: type):
        """Reject any class that does not derive from ``BaseStrategy``."""
        if issubclass(strategy_cls, BaseStrategy):
            return
        raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")

    @classmethod
    def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
        """Instantiate the strategy registered under ``train_type``.

        Args:
            train_type: Registered strategy name ("seq", "sft", "dpo", "grpo").
            model: Model instance handed to the strategy.
            device: Device the strategy runs on.
            **kwargs: Extra arguments forwarded to the strategy constructor.

        Returns:
            The constructed strategy instance.
        """
        return super().create(train_type, model, device, **kwargs)

    @classmethod
    def available_strategies(cls) -> list:
        """Return the names of all registered strategies."""
        return cls.list_registered()
|
|
||||||
|
|
||||||
|
|
||||||
# ============== Strategy Classes ==============
|
|
||||||
# All strategies are registered at class definition time using the decorator
|
|
||||||
|
|
||||||
|
|
||||||
@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
    """Plain next-token prediction strategy.

    The loss is the cross-entropy between the model logits and
    ``target_ids``, with optional label smoothing.
    """

    def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
        super().__init__(model, device, **kwargs)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the cross-entropy loss over every position of the batch."""
        on_device = move_to_device(batch, self.device)
        input_ids = on_device["input_ids"]
        target_ids = on_device["target_ids"]

        logits = self.model(input_ids=input_ids)["logits"]

        # Merge the first two logit dimensions so each position is one sample;
        # the loss is computed in float32.
        flat_logits = logits.flatten(0, 1).float()
        flat_targets = target_ids.flatten()
        return F.cross_entropy(
            flat_logits,
            flat_targets,
            label_smoothing=self.label_smoothing,
        )
|
|
||||||
|
|
||||||
|
|
||||||
@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
    """Supervised fine-tuning strategy with a per-token loss mask.

    Positions whose ``loss_mask`` entry equals zero are excluded from the
    cross-entropy loss.
    """

    def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
        super().__init__(model, device, **kwargs)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the masked cross-entropy loss for ``batch``."""
        on_device = move_to_device(batch, self.device)
        input_ids = on_device["input_ids"]
        target_ids = on_device["target_ids"]
        loss_mask = on_device["loss_mask"]

        ignore_index = -100
        logits = self.model(input_ids=input_ids)["logits"]

        # Overwrite masked targets with the ignore index so cross_entropy
        # skips them.
        masked_targets = target_ids.masked_fill(loss_mask == 0, ignore_index)

        return F.cross_entropy(
            logits.flatten(0, 1).float(),
            masked_targets.flatten(),
            ignore_index=ignore_index,
            label_smoothing=self.label_smoothing,
        )
|
|
||||||
|
|
||||||
|
|
||||||
@StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy):
    """Direct Preference Optimization strategy.

    The loss is ``-logsigmoid(beta * ((log pi_c - log pi_r) -
    (log ref_c - log ref_r)))`` averaged over the batch, where ``c``/``r``
    are the chosen/rejected sequences and ``ref`` is a frozen reference model
    built at construction time.

    The ``executor`` and ``model_fn`` keyword arguments are used to build the
    reference model, so both have to be supplied.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        beta: float = 0.1,
        reduction: str = "mean",
        **kwargs,
    ):
        super().__init__(model, device, **kwargs)
        # Snapshot the current weights into a frozen reference model.
        initial_weights = self.executor.unwrap_model(model)
        self.ref_model = create_ref_model(self.model_fn, initial_weights).to(
            device=self.device
        )
        self.beta = beta
        self.reduction = reduction

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the DPO loss for a batch of chosen/rejected pairs."""
        on_device = move_to_device(batch, self.device)
        chosen_ids, rejected_ids = on_device["chosen"], on_device["rejected"]
        chosen_mask = on_device["chosen_mask"]
        rejected_mask = on_device["rejected_mask"]

        # Stack chosen and rejected along the batch axis so each model needs
        # a single forward pass.
        concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
        concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)

        log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
        with torch.no_grad():
            log_ref = get_logprobs(
                self.ref_model, concat_ids, concat_mask, self.reduction
            )

        # The first `num_chosen` rows belong to the chosen sequences.
        num_chosen = chosen_ids.shape[0]
        pi_log_ratio = log_pi[:num_chosen] - log_pi[num_chosen:]
        ref_log_ratio = log_ref[:num_chosen] - log_ref[num_chosen:]

        ratio_diff = pi_log_ratio - ref_log_ratio
        return -F.logsigmoid(self.beta * ratio_diff).mean()
|
|
||||||
|
|
||||||
|
|
||||||
@StrategyFactory.register("grpo")
class GRPOStrategy(BaseStrategy):
    """Group Relative Policy Optimization strategy.

    On-policy GRPO following DeepSeek-R1: the policy model is updated while
    a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref),
    clipped PPO objective. Call ``sync_ref_model()`` after each data-generation round.

    The reference model is also refreshed automatically every
    ``sync_interval`` calls to ``compute_loss``; a ``sync_interval`` of 0
    disables the automatic refresh.

    ``reduction`` must yield one value per sequence ("mean" or "sum"),
    because the log-probs are reshaped to ``(batch_size, group_size)``.
    The ``executor`` and ``model_fn`` keyword arguments are used to build the
    reference model, so both have to be supplied.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        clip_eps: float = 0.2,
        kl_coef: float = 0.01,
        group_size: int = 4,
        reduction: str = "mean",
        sync_interval: int = 200,
        **kwargs,
    ):
        super().__init__(model, device, **kwargs)
        self.ref_model = create_ref_model(
            self.model_fn, self.executor.unwrap_model(model)
        ).to(device=self.device)
        self.clip_eps = clip_eps
        self.kl_coef = kl_coef
        self.group_size = group_size
        self.reduction = reduction
        self.sync_interval = sync_interval
        self._step = 0

    def sync_ref_model(self):
        """Copy current model weights to ref model."""
        self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the clipped policy loss plus the KL-style penalty.

        ``batch`` supplies "prompts", "responses", "masks" and "rewards";
        "responses" is unpacked as (batch_size, group_size, response_len).
        """
        self._step += 1
        # A falsy interval (0) means "never auto-sync"; the previous
        # unconditional modulo raised ZeroDivisionError in that case.
        if self.sync_interval and self._step % self.sync_interval == 0:
            self.sync_ref_model()

        batch = move_to_device(batch, self.device)
        prompts = batch["prompts"]
        responses = batch["responses"]
        masks = batch["masks"]
        rewards = batch["rewards"]

        # The group size comes from the data, not from self.group_size.
        batch_size, group_size, response_len = responses.shape
        responses_flat = responses.view(-1, response_len)
        masks_flat = masks.view(-1, response_len)
        # Repeat each prompt once per response in its group.
        prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)

        full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
        # Prompt positions are marked with ones, so they also count towards
        # the reduced log-probs returned by get_logprobs.
        full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)

        log_probs_policy = get_logprobs(
            self.model, full_sequences, full_masks, self.reduction
        )
        log_probs_policy = log_probs_policy.view(batch_size, group_size)

        with torch.no_grad():
            log_probs_ref = get_logprobs(
                self.ref_model, full_sequences, full_masks, self.reduction
            )
            log_probs_ref = log_probs_ref.view(batch_size, group_size)

        # Group-relative advantage: standardise rewards within each group.
        eps = torch.finfo(log_probs_policy.dtype).eps
        mean = rewards.mean(dim=-1, keepdim=True)
        std = rewards.std(dim=-1, keepdim=True)
        advantages = (rewards - mean) / (std + eps)

        ratio = torch.exp(log_probs_policy - log_probs_ref)

        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages

        policy_loss = -torch.min(surr1, surr2).mean()
        kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
        total_loss = policy_loss + kl_penalty

        return total_loss
|
|
||||||
|
|
@ -1,333 +0,0 @@
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn.utils import clip_grad_norm_
|
|
||||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
from astrai.parallel import only_on_rank
|
|
||||||
from astrai.parallel.setup import get_current_device, get_rank
|
|
||||||
from astrai.serialization import Checkpoint
|
|
||||||
from astrai.trainer.metric_util import (
|
|
||||||
ctx_get_grad_max,
|
|
||||||
ctx_get_grad_mean,
|
|
||||||
ctx_get_grad_min,
|
|
||||||
ctx_get_grad_nan_num,
|
|
||||||
ctx_get_grad_norm,
|
|
||||||
ctx_get_grad_std,
|
|
||||||
ctx_get_loss,
|
|
||||||
ctx_get_lr,
|
|
||||||
ctx_get_val_loss,
|
|
||||||
)
|
|
||||||
from astrai.trainer.train_context import TrainContext
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
class TrainCallback(Protocol):
    """Hook interface the trainer invokes at fixed points of a run.

    Every hook receives the shared ``TrainContext`` and is a no-op by
    default, so implementations override only the hooks they need.
    """

    def on_train_begin(self, context: TrainContext):
        """Hook invoked once, before training starts."""

    def on_train_end(self, context: TrainContext):
        """Hook invoked once, after training finishes."""

    def on_epoch_begin(self, context: TrainContext):
        """Hook invoked at the start of every epoch."""

    def on_epoch_end(self, context: TrainContext):
        """Hook invoked at the end of every epoch."""

    def on_batch_begin(self, context: TrainContext):
        """Hook invoked before every batch."""

    def on_batch_end(self, context: TrainContext):
        """Hook invoked after every batch."""

    def on_optimizer_step(self, context: TrainContext):
        """Hook invoked on every optimizer step (sync step only)."""

    def on_error(self, context: TrainContext):
        """Hook invoked when an error occurs during training."""
|
|
||||||
|
|
||||||
|
|
||||||
class CallbackFactory(BaseFactory[TrainCallback]):
    """Registry-backed factory for training callbacks.

    Callback classes register themselves under a string name with the
    ``register`` decorator and are instantiated through ``create``.

    Example:
        @CallbackFactory.register("my_callback")
        class MyCallback(TrainCallback):
            ...

        callback = CallbackFactory.create("my_callback", **kwargs)
    """
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback):
    """Clip the gradient norm of the model on every optimizer step.

    Args:
        max_grad_norm: Maximum allowed norm, passed to ``clip_grad_norm_``.
    """

    def __init__(self, max_grad_norm: float):
        self.max_grad_norm = max_grad_norm

    def on_optimizer_step(self, context: TrainContext):
        """Clip the gradients of ``context.model`` in place."""
        parameters = context.model.parameters()
        clip_grad_norm_(parameters, self.max_grad_norm)
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("gradient_checkpointing")
class GradientCheckpointingCallback(TrainCallback):
    """
    Activation checkpointing callback — trades compute for memory
    by recomputing specified module activations during the backward pass.

    Matching modules get their ``forward`` wrapped in
    ``torch.utils.checkpoint.checkpoint`` when training begins and restored
    when training ends.

    Args:
        modules: Module types to apply checkpointing to. With ``None`` or an
            empty list no module is wrapped.
    """

    def __init__(self, modules: Optional[List[type]] = None):
        # isinstance() needs a tuple of types.
        self.modules = tuple(modules) if modules else ()

    def _enable(self, module: nn.Module):
        """Wrap ``module.forward`` with checkpointing if its type matches."""
        if not self.modules or not isinstance(module, self.modules):
            return
        if hasattr(module, "_original_forward"):
            # Already wrapped. Wrapping again would store the wrapper as the
            # "original" and make the real forward unrecoverable in _disable.
            return

        fn = module.forward
        module._original_forward = fn
        module.forward = lambda *a, **kw: torch_checkpoint(
            fn, *a, use_reentrant=False, **kw
        )

    @staticmethod
    def _disable(module: nn.Module):
        """Restore the forward saved by ``_enable``, if any."""
        if hasattr(module, "_original_forward"):
            module.forward = module._original_forward
            del module._original_forward

    def on_train_begin(self, context: TrainContext):
        """Enable checkpointing on every matching submodule of the model."""
        context.model.apply(self._enable)
        logger.info("Gradient checkpointing enabled")

    def on_train_end(self, context: TrainContext):
        """Undo the wrapping on every submodule of the model."""
        context.model.apply(self._disable)
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("checkpoint")
class CheckpointCallback(TrainCallback):
    """
    Checkpoint callback for trainer.

    Saves a checkpoint whenever ``interval`` iterations have passed since the
    last one, at the end of training (if anything is unsaved) and when an
    error is reported.

    Args:
        save_dir: Directory under which checkpoints are written.
        interval: Minimum number of iterations between periodic checkpoints.
        weight_only: When true, only the model weights are saved and the
            extra state (optimizer / scheduler) is left out.
        save_extra_fn: Callable producing the extra state dict from the
            context; defaults to :meth:`save_extra`.
    """

    # Context attributes whose state_dict() goes into the checkpoint extras.
    extra_keys = ("optimizer", "scheduler")

    def __init__(
        self,
        save_dir: str,
        interval: int,
        weight_only: bool = False,
        save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
    ):
        self.save_dir = save_dir
        self.interval = interval
        self.weight_only = weight_only
        self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
        self.last_ckpt_iter = 0

    def _save_checkpoint(self, context: TrainContext):
        """Collect the model state and write a checkpoint from rank 0."""
        # unwrap_model runs on every rank; only rank 0 writes to disk.
        state_dict = context.executor.unwrap_model(context.model)
        self.last_ckpt_iter = context.iteration

        if get_rank() == 0:
            save_path = os.path.join(
                self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
            )
            # Honour weight_only: it was stored but never consulted before, so
            # the extra state was written even when weights only were requested.
            extra = {} if self.weight_only else self.save_extra_fn(context)
            context.checkpoint = Checkpoint(
                state_dict=state_dict,
                epoch=context.epoch,
                iteration=context.iteration,
                extra=extra,
                config=context.model_config,
            )
            context.checkpoint.save(save_path)

    def on_batch_end(self, context: TrainContext):
        """Save once ``interval`` iterations have passed since the last save."""
        if context.iteration - self.last_ckpt_iter >= self.interval:
            self._save_checkpoint(context)

    def on_train_end(self, context: TrainContext):
        """Save a final checkpoint unless the last iteration is already saved."""
        if context.iteration != self.last_ckpt_iter:
            self._save_checkpoint(context)

    def on_error(self, context: TrainContext):
        """Save a checkpoint when the trainer reports an error."""
        self._save_checkpoint(context)

    @staticmethod
    def save_extra(context: TrainContext) -> dict:
        """Return the state dicts of the context attributes in ``extra_keys``.

        Attributes that are missing or ``None`` are skipped.
        """
        extra = {}
        for name in CheckpointCallback.extra_keys:
            obj = getattr(context, name, None)
            # Test identity, not truthiness: an object that defines
            # __len__/__bool__ must not be dropped just because it is "empty".
            if obj is not None:
                extra[name] = obj.state_dict()
        return extra
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):
    """
    Progress bar callback for trainer.

    Opens one tqdm bar per epoch and advances it by one per batch, showing
    the loss, the learning rate of the last param group and, when positive,
    the validation loss.

    Args:
        num_epoch: Total number of epochs, shown in the bar description.
        log_interval: Stored on the instance; not used by this class.
        file: Output stream for the bar; defaults to ``sys.stdout``.
    """

    def __init__(
        self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
    ):
        self.num_epoch = num_epoch
        self.log_interval = log_interval
        self.file = file
        self.progress_bar: Optional[tqdm] = None

    @only_on_rank(0)
    def on_epoch_begin(self, context: TrainContext):
        """Open a fresh bar over the epoch's dataloader."""
        self.progress_bar = tqdm(
            context.dataloader,
            desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
            dynamic_ncols=True,
            file=self.file or sys.stdout,
        )

    @only_on_rank(0)
    def on_batch_end(self, context: TrainContext):
        """Refresh the postfix and advance the bar by one batch."""
        if self.progress_bar is None:
            # No bar is open (on_epoch_begin has not run); nothing to update.
            return
        postfix = {
            "loss": f"{context.loss:.4f}",
            "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
        }
        if context.val_loss > 0:
            postfix["val_loss"] = f"{context.val_loss:.4f}"
        self.progress_bar.set_postfix(postfix)
        self.progress_bar.update(1)

    @only_on_rank(0)
    def on_epoch_end(self, context: TrainContext):
        """Close the bar of the finished epoch."""
        _ = context
        # Compare with None explicitly: bool(tqdm) depends on the bar's
        # total / iterable, so an empty dataloader made the old truthiness
        # check skip close().
        if self.progress_bar is not None:
            self.progress_bar.close()
            self.progress_bar = None
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("metric_logger")
class MetricLoggerCallback(TrainCallback):
    """Collect training metrics and dump them as JSON lines.

    Every ``log_interval`` iterations a record is appended to an in-memory
    cache; the whole cache is written to a new ``.jsonl`` file once
    ``save_interval`` iterations have passed since the last write, at the end
    of training and on error.

    Args:
        log_dir: Output directory; falls back to ``<cwd>/logs`` when falsy.
        save_interval: Minimum number of iterations between file writes.
        log_interval: A record is taken when ``iteration % log_interval == 0``.
        metrics: Names of the metrics to record; defaults to loss and lr.
    """

    def __init__(
        self,
        log_dir: str,
        save_interval: int,
        log_interval: int = 10,
        metrics: List[str] = None,
    ):
        self.last_log_iter = 0
        self.save_interval = save_interval
        self.log_interval = log_interval
        self.metrics = metrics or ["loss", "lr"]

        self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.log_cache = []

        # Metric name -> function extracting the value from the context.
        self._metric_funcs = {
            "loss": ctx_get_loss,
            "lr": ctx_get_lr,
            "val_loss": ctx_get_val_loss,
            "grad_norm": ctx_get_grad_norm,
            "grad_std": ctx_get_grad_std,
            "grad_max": ctx_get_grad_max,
            "grad_min": ctx_get_grad_min,
            "grad_mean": ctx_get_grad_mean,
            "grad_nan_num": ctx_get_grad_nan_num,
        }

    def _get_log_data(self, context: TrainContext):
        """Build one log record for the current iteration."""
        record = {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "epoch": context.epoch,
            "iter": context.iteration,
        }
        for metric in self.metrics:
            record[metric] = self._metric_funcs[metric](context)
        return record

    @only_on_rank(0)
    def _add_log(self, log_data):
        """Append a record to the in-memory cache."""
        self.log_cache.append(log_data)

    @only_on_rank(0)
    def _save_log(self, epoch, iter):
        """Write every cached record to a file named after epoch and iteration."""
        log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"

        with open(log_file, "w") as f:
            for entry in self.log_cache:
                f.write(json.dumps(entry) + "\n")

    def on_batch_end(self, context):
        """Record metrics and flush them to disk when their intervals are due."""
        iteration = context.iteration

        if iteration % self.log_interval == 0:
            self._add_log(self._get_log_data(context))

        if iteration - self.last_log_iter >= self.save_interval:
            self._save_log(context.epoch, iteration)
            self.last_log_iter = iteration

    def on_train_end(self, context):
        """Flush the cache unless it was already written at this iteration."""
        if context.iteration != self.last_log_iter:
            self._save_log(context.epoch, context.iteration)

    def on_error(self, context):
        """Flush the cache when the trainer reports an error."""
        self._save_log(context.epoch, context.iteration)
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("validation")
class ValidationCallback(TrainCallback):
    """Run periodic validation and publish the result as ``context.val_loss``.

    Validation runs on optimizer steps whose index is a multiple of
    ``config.val_step``; it is skipped when no validation dataloader is set
    or ``val_step`` is not positive.
    """

    def _run_validation(self, context: TrainContext):
        """Average the strategy loss over the validation dataloader."""
        context.model.eval()

        total_loss = 0.0
        num_batches = 0

        try:
            with torch.no_grad():
                for batch in context.val_dataloader:
                    loss = context.strategy(batch)
                    total_loss += loss.item()
                    num_batches += 1
        finally:
            # Always return to train mode, even if a validation batch raised.
            context.model.train()

        avg_loss = total_loss / max(num_batches, 1)

        if context.world_size > 1 and dist.is_initialized():
            # ReduceOp.AVG is only available with the NCCL backend, so sum
            # across ranks and divide by the world size instead.
            loss_tensor = torch.tensor([avg_loss], device=get_current_device())
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
            avg_loss = loss_tensor.item() / dist.get_world_size()

        context.val_loss = avg_loss

        step_count = context.iteration // context.config.grad_accum_steps
        logger.info(
            f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
        )

    def on_optimizer_step(self, context: TrainContext):
        """Trigger validation every ``val_step`` optimizer steps."""
        if context.val_dataloader is None:
            return
        cfg = context.config
        if cfg.val_step <= 0:
            return
        # Iterations count micro-batches; convert them to optimizer steps.
        step_count = context.iteration // cfg.grad_accum_steps
        if step_count % cfg.val_step == 0:
            self._run_validation(context)
|
|
||||||
|
|
@ -1,170 +0,0 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Self
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
|
||||||
from astrai.dataset import ResumableDistributedSampler
|
|
||||||
from astrai.model.components.lora import inject_lora
|
|
||||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
|
||||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
|
||||||
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class TrainContext:
    """Mutable state handed from the trainer loop to every callback hook."""

    # Core training objects; they start unset and are filled in by the builder.
    model: nn.Module = None
    strategy: BaseStrategy = None
    dataloader: DataLoader = None
    optimizer: OptimizerProtocol = None
    scheduler: SchedulerProtocol = None
    checkpoint: Checkpoint = None
    config: TrainConfig = None
    model_config: dict = field(default_factory=dict)
    executor: BaseExecutor = None

    # Progress counters and the most recent loss values.
    epoch: int = 0
    iteration: int = 0
    loss: float = 0.0
    val_dataloader: DataLoader = None
    val_loss: float = 0.0

    # Process topology plus a free-form extension slot.
    world_size: int = 1
    rank: int = 0
    kwargs: dict = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainContextBuilder:
    """Assemble a fully wired ``TrainContext`` from a ``TrainConfig``.

    Usage is fluent: ``TrainContextBuilder(cfg).with_resume_dir(path).build()``.
    """

    def __init__(
        self,
        config: TrainConfig,
    ):
        self.config = config
        # Optional directory to resume from; set through with_resume_dir().
        self._resume_dir: Optional[str] = None

    def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
        """Set the directory to resume from and return the builder for chaining."""
        self._resume_dir = resume_dir
        return self

    def build(self) -> TrainContext:
        """Create the model, optimizer, scheduler, loaders and strategy.

        The steps are order-sensitive: weights are restored before LoRA is
        injected, the optimizer is created after LoRA injection, and the
        optimizer/scheduler state is restored only after ``executor.prepare``.

        Returns:
            The populated ``TrainContext``.
        """
        cfg = self.config
        device = get_current_device()

        executor = ExecutorFactory.create(
            cfg.parallel_mode,
            grad_accum_steps=cfg.grad_accum_steps,
            **cfg.executor_kwargs,
        )

        model = cfg.model_fn()
        model = model.to(device=device)

        # Prefer a config.json stored in the resume directory ...
        model_config = {}
        if self._resume_dir:
            config_path = Path(self._resume_dir) / "config.json"
            if config_path.exists():
                model_config = load_json(config_path)

        # ... and fall back to the model's own config object when it has one.
        if not model_config and hasattr(model, "config"):
            model_config = model.config.to_dict()

        context = TrainContext(
            model=model,
            world_size=get_world_size(),
            rank=get_rank(),
            config=cfg,
            model_config=model_config,
            executor=executor,
        )

        if self._resume_dir is not None:
            resume_path = Path(self._resume_dir)
            # A meta.json marks a full checkpoint; otherwise only weights are loaded.
            if (resume_path / "meta.json").exists():
                checkpoint = Checkpoint.load(self._resume_dir)
                state_dict = checkpoint.state_dict
                if checkpoint.config:
                    # The checkpoint's own config overrides the one chosen above.
                    context.model_config = checkpoint.config
            else:
                checkpoint = None
                state_dict = load_model_weights(self._resume_dir)
            # strict=False: missing or unexpected keys do not raise here.
            model.load_state_dict(state_dict, strict=False)
            if checkpoint is not None:
                # Progress counters come from the config, not from the checkpoint.
                context.epoch = cfg.start_epoch
                context.iteration = cfg.start_batch
                context.checkpoint = checkpoint

        # LoRA is injected after the base weights are loaded and before the
        # optimizer is built from the model.
        if cfg.lora is not None:
            inject_lora(
                model,
                r=cfg.lora.r,
                alpha=cfg.lora.alpha,
                target_modules=set(cfg.lora.target_modules),
            )

        context.optimizer = cfg.optimizer_fn(model)
        context.scheduler = cfg.scheduler_fn(context.optimizer)

        # Convert the resumed iteration count into a sample offset for the sampler.
        sampler_offset = context.iteration * cfg.batch_per_device
        sampler = ResumableDistributedSampler(
            data_source=cfg.dataset,
            start_epoch=context.epoch,
            start_iter=sampler_offset,
            seed=cfg.random_seed,
        )
        context.dataloader = DataLoader(
            cfg.dataset,
            batch_size=cfg.batch_per_device,
            sampler=sampler,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            prefetch_factor=cfg.prefetch_factor,
        )

        if cfg.val_dataset is not None:
            # Validation always starts from the beginning and is not shuffled.
            val_sampler = ResumableDistributedSampler(
                data_source=cfg.val_dataset,
                start_epoch=0,
                start_iter=0,
                seed=cfg.random_seed,
                shuffle=False,
            )
            context.val_dataloader = DataLoader(
                cfg.val_dataset,
                batch_size=cfg.batch_per_device,
                sampler=val_sampler,
                num_workers=cfg.num_workers,
                pin_memory=cfg.pin_memory,
                prefetch_factor=cfg.prefetch_factor,
            )

        # Only the training loader goes through prepare(); the validation
        # loader is used as constructed above.
        context.model, context.optimizer, context.dataloader, context.scheduler = (
            executor.prepare(
                model,
                context.optimizer,
                context.dataloader,
                context.scheduler,
            )
        )

        # Restore optimizer/scheduler state from the checkpoint, if it carries any.
        if context.checkpoint and context.checkpoint.extra:
            extra = context.checkpoint.extra
            for name in ("optimizer", "scheduler"):
                if name in extra:
                    obj = getattr(context, name, None)
                    if obj is not None:
                        obj.load_state_dict(extra[name])

        context.strategy = StrategyFactory.create(
            model=context.model,
            train_type=cfg.strategy,
            device=device,
            executor=executor,
            model_fn=cfg.model_fn,
            **cfg.extra_kwargs,
        )

        return context
|
|
||||||
|
|
@ -1,109 +0,0 @@
|
||||||
import logging
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from astrai.config import TrainConfig
|
|
||||||
from astrai.parallel.setup import spawn_parallel_fn
|
|
||||||
from astrai.trainer.train_callback import (
|
|
||||||
CallbackFactory,
|
|
||||||
TrainCallback,
|
|
||||||
)
|
|
||||||
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
    """Run the training loop for a ``TrainConfig`` and dispatch callback hooks."""

    def __init__(
        self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
    ):
        self.train_config = train_config
        default_callbacks = self._get_default_callbacks()
        # User callbacks are appended after the defaults, so they run after them.
        self.callbacks = (
            default_callbacks + callbacks if callbacks else default_callbacks
        )

    def _get_default_callbacks(self) -> List[TrainCallback]:
        """Build the built-in callbacks from the training config."""
        cfg = self.train_config
        callbacks = [
            CallbackFactory.create(
                "gradient_checkpointing",
                modules=cfg.gradient_checkpointing_modules,
            ),
            CallbackFactory.create(
                "checkpoint",
                cfg.ckpt_dir,
                cfg.ckpt_interval,
            ),
            CallbackFactory.create(
                "metric_logger",
                log_dir=cfg.log_dir,
                # Metric logs are flushed at the checkpoint interval.
                save_interval=cfg.ckpt_interval,
                log_interval=cfg.log_interval,
                metrics=cfg.metrics,
            ),
            CallbackFactory.create("progress_bar", cfg.n_epoch),
            CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
            CallbackFactory.create("validation"),
        ]
        return callbacks

    def _call_callbacks(self, method_name: str, context: TrainContext):
        """Invoke ``method_name`` on every callback that defines it, in list order."""
        for callback in self.callbacks:
            method = getattr(callback, method_name, None)
            if method:
                method(context)

    def _trainer_loop(self, resume_dir: Optional[str] = None):
        """Build the context and run all epochs, firing hooks around each stage.

        On any exception ``on_error`` is fired and the exception is re-raised;
        ``on_train_end`` is fired in every case.
        """
        context = (
            TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
        )
        executor = context.executor
        self._call_callbacks("on_train_begin", context)

        try:
            context.model.train()

            # Start from context.epoch so a resumed run skips finished epochs.
            for epoch in range(context.epoch, context.config.n_epoch):
                context.epoch = epoch
                self._call_callbacks("on_epoch_begin", context)

                for batch in context.dataloader:
                    self._call_callbacks("on_batch_begin", context)

                    # NOTE(review): confirm that everything down to the scheduler
                    # step is meant to sit inside the accumulate() scope.
                    with executor.accumulate(context.model):
                        loss = context.strategy(batch)
                        context.loss = loss.item()
                        # Scale the loss so accumulated gradients average over
                        # the accumulation window.
                        stand_loss = loss / executor.grad_accum_steps
                        executor.backward(stand_loss)
                        context.iteration += 1
                        self._call_callbacks("on_batch_end", context)

                        if executor.sync_gradients:
                            # The hook fires before the optimizer update.
                            self._call_callbacks("on_optimizer_step", context)
                            context.optimizer.step()
                            context.optimizer.zero_grad()

                            if context.scheduler:
                                context.scheduler.step()

                self._call_callbacks("on_epoch_end", context)

        except Exception as e:
            logger.error("Training failed: %s", str(e), exc_info=True)
            self._call_callbacks("on_error", context)
            raise
        finally:
            self._call_callbacks("on_train_end", context)

    def train(self, resume_dir: Optional[str] = None):
        """Launch ``_trainer_loop`` through ``spawn_parallel_fn``.

        Args:
            resume_dir: Optional directory to resume from.
        """
        cfg = self.train_config
        # Presumably spawn_parallel_fn forwards resume_dir to _trainer_loop — verify.
        spawn_parallel_fn(
            self._trainer_loop,
            backend=cfg.backend,
            world_size=cfg.nprocs,
            master_addr=cfg.master_addr,
            master_port=cfg.master_port,
            device_type=cfg.device_type,
            start_method=cfg.start_method,
            resume_dir=resume_dir,
        )
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
import os
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_ROOT = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Download the ViperEk/KHAOSZ snapshot into <PROJECT_ROOT>/params.
    # force_download=True is passed, so the download is requested even when
    # files are already present locally.
    params_dir = os.path.join(PROJECT_ROOT, "params")
    snapshot_download(
        repo_id="ViperEk/KHAOSZ", local_dir=params_dir, force_download=True
    )
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_ROOT = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
def generate_text():
    """Read one prompt from stdin and print the model's text completion.

    The model is loaded from ``<PROJECT_ROOT>/params`` and moved to CUDA in
    bfloat16 before generation.
    """
    params_dir = os.path.join(PROJECT_ROOT, "params")
    llm = Khaosz(params_dir).to(device='cuda', dtype=torch.bfloat16)

    prompt = input(">> ")

    # Sampling settings, in the same keyword order as before.
    sampling = dict(temperature=0.8, top_p=0.95, top_k=50)
    completion = llm.text_generate(query=prompt, **sampling)

    print(completion)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
generate_text()
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_ROOT = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
def batch_generate():
    """Generate replies for a fixed list of prompts and print each pair.

    The model is loaded from ``<PROJECT_ROOT>/params`` and moved to CUDA in
    bfloat16 before generation.
    """
    params_dir = os.path.join(PROJECT_ROOT, "params")
    llm = Khaosz(params_dir).to(device='cuda', dtype=torch.bfloat16)
    prompts = [
        "你好",
        "请问什么是人工智能",
        "今天天气如何",
        "我感到焦虑, 请问我应该怎么办",
        "请问什么是显卡",
    ]

    answers = llm.batch_generate(
        queries=prompts, temperature=0.8, top_p=0.95, top_k=50
    )

    # zip() already yields (prompt, answer) tuples, printed as tuples.
    for pair in zip(prompts, answers):
        print(pair)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
batch_generate()
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz, SemanticTextSplitter, Retriever
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_ROOT = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Retrieval-augmented generation demo: split README.md into chunks, index
    # their embeddings, retrieve the best matches for a query, then generate.
    model_dir = os.path.join(PROJECT_ROOT, "params")
    context_path = os.path.join(PROJECT_ROOT, "README.md")

    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
    splitter = SemanticTextSplitter(model.encode)
    retriever = Retriever()

    # Use a context manager so the file handle is closed deterministically.
    with open(context_path, "r", encoding="utf-8") as f:
        text = f.read()

    chunks = splitter.split(text, threshold=0.8, window_size=1)

    # Index every chunk together with its embedding.
    chunk_embs = model.encode(chunks)
    for sentence, emb in zip(chunks, chunk_embs):
        retriever.add_vector(sentence, emb)

    top_k_chunks = 5
    query = "作者设计了一个怎样的模型"
    emb_query = model.encode(query)
    retrieved = retriever.retrieve(emb_query, top_k_chunks)

    response = model.retrieve_generate(
        retrieved=retrieved,
        query=query,
        temperature=0.8,
        top_p=0.95,
        top_k=50
    )

    print("retrieve content:")
    # Each retrieved item is unpacked as (passage, second element); only the
    # passage is shown, numbered from 1.
    print("\n".join(
        f"{rank}. " + passage
        for rank, (passage, _) in enumerate(retrieved, start=1)
    ))

    print("\n\nretrive generate:")
    print(response)
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_ROOT = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
def chat():
    """Run an interactive streaming chat loop on stdin/stdout.

    Typing ``!exit`` ends the session.  The conversation history yielded by
    the generator is fed back into the next turn.
    """
    model_dir = os.path.join(PROJECT_ROOT, "params")
    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)

    history = []
    while True:
        query = input(">> ")
        if query == "!exit":
            break

        # Print only the part of the reply that has not been shown yet.
        response_size = 0
        for response, history in model.stream_generate(
            query=query,
            history=history,
            temperature=0.8,
            top_p=0.95,
            top_k=50
        ):
            print(response[response_size:], end="", flush=True)
            response_size = len(response)

        # Terminate the streamed reply so the next prompt starts on its own line.
        print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
chat()
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
# Compose definitions for the inference HTTP server.
services:
  # GPU-backed server; it declares no profile.
  server:
    build:
      context: .
      dockerfile: Dockerfile
    # Run as the caller's UID/GID when those variables are set, else 1000:1000.
    user: "${UID:-1000}:${GID:-1000}"
    ports:
      - "8000:8000"
    volumes:
      # Model parameters are mounted read-only from the host.
      - ./params:/app/params:ro
    command: python -m scripts.tools.server --port 8000 --device cuda
    deploy:
      resources:
        reservations:
          devices:
            # Reserve one NVIDIA GPU for this container.
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    healthcheck:
      # NOTE(review): assumes curl exists in the image and the server exposes /health — confirm.
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s
    restart: unless-stopped

  # CPU-only variant, gated behind the `cpu` profile.
  # It publishes the same host port 8000 as `server`.
  server-cpu:
    profiles: [cpu]
    build:
      context: .
      dockerfile: Dockerfile
    user: "${UID:-1000}:${GID:-1000}"
    ports:
      - "8000:8000"
    volumes:
      - ./params:/app/params:ro
    command: python -m scripts.tools.server --port 8000 --device cpu
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      # Longer grace period than the GPU service (120s vs 60s).
      start_period: 120s
    restart: unless-stopped
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
__version__ = "1.3.2"
|
||||||
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
|
from khaosz.api import Khaosz
|
||||||
|
from khaosz.config import (
|
||||||
|
ModelConfig,
|
||||||
|
TrainConfig,
|
||||||
|
)
|
||||||
|
from khaosz.model.transformer import Transformer
|
||||||
|
from khaosz.utils.retriever import Retriever
|
||||||
|
from khaosz.utils.splitter import (
|
||||||
|
SemanticTextSplitter,
|
||||||
|
PriorityTextSplitter
|
||||||
|
)
|
||||||
|
from khaosz.data import (
|
||||||
|
DatasetLoader,
|
||||||
|
BpeTokenizer
|
||||||
|
)
|
||||||
|
from khaosz.inference.generator import (
|
||||||
|
TextGenerator,
|
||||||
|
ChatGenerator,
|
||||||
|
StreamGenerator,
|
||||||
|
BatchGenerator,
|
||||||
|
RetrievalGenerator,
|
||||||
|
EmbeddingEncoder
|
||||||
|
)
|
||||||
|
|
||||||
|
from khaosz.trainer import (
|
||||||
|
Trainer,
|
||||||
|
StrategyFactory,
|
||||||
|
SchedulerFactory
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Khaosz",
|
||||||
|
|
||||||
|
"Transformer",
|
||||||
|
|
||||||
|
"Retriever",
|
||||||
|
"SemanticTextSplitter",
|
||||||
|
"PriorityTextSplitter",
|
||||||
|
|
||||||
|
"ModelConfig",
|
||||||
|
"TrainConfig",
|
||||||
|
|
||||||
|
"DatasetLoader",
|
||||||
|
"BpeTokenizer",
|
||||||
|
|
||||||
|
"TextGenerator",
|
||||||
|
"ChatGenerator",
|
||||||
|
"StreamGenerator",
|
||||||
|
"BatchGenerator",
|
||||||
|
"RetrievalGenerator",
|
||||||
|
"EmbeddingEncoder",
|
||||||
|
|
||||||
|
"Trainer",
|
||||||
|
"StrategyFactory",
|
||||||
|
"SchedulerFactory"
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import List, Tuple, Generator, Union
|
||||||
|
|
||||||
|
from khaosz.inference.generator import (
|
||||||
|
TextGenerator,
|
||||||
|
ChatGenerator,
|
||||||
|
StreamGenerator,
|
||||||
|
BatchGenerator,
|
||||||
|
RetrievalGenerator,
|
||||||
|
EmbeddingEncoder
|
||||||
|
)
|
||||||
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
|
||||||
|
|
||||||
|
class Khaosz:
    """Facade bundling loaded model parameters with the generator classes.

    Every generation method builds a fresh generator around
    ``self.parameter`` and forwards the call to it.
    """

    def __init__(self, model_dir: str):
        """Load the model components found in ``model_dir``."""
        params = ModelParameter()
        params.load(model_dir)
        self.parameter = params

    def to(self, *args, **kwargs):
        """Forward ``to(...)`` to the parameter container and return ``self``."""
        self.parameter.to(*args, **kwargs)
        return self

    def generate(
        self,
        query: str,
        history: List[Tuple[str, str]]=None,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> str:
        """Return the ``ChatGenerator`` reply for ``query`` given ``history``."""
        chat = ChatGenerator(self.parameter)
        reply = chat.generate(
            query, history=history, temperature=temperature, top_k=top_k, top_p=top_p
        )
        return reply

    def batch_generate(
        self,
        queries: List[str],
        histories: List[Tuple[str, str]]=None,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> List[str]:
        """Return the ``BatchGenerator`` replies for ``queries``."""
        batch = BatchGenerator(self.parameter)
        replies = batch.generate(
            queries, histories=histories, temperature=temperature, top_k=top_k, top_p=top_p
        )
        return replies

    def stream_generate(
        self,
        query: str,
        history: List[Tuple[str, str]]=None,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
        """Return the generator produced by ``StreamGenerator.generate``."""
        streamer = StreamGenerator(self.parameter)
        stream = streamer.generate(
            query, history=history, temperature=temperature, top_k=top_k, top_p=top_p
        )
        return stream

    def retrieve_generate(
        self,
        retrieved,
        query: str,
        history: List[Tuple[str, str]] = None,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> str:
        """Return the ``RetrievalGenerator`` reply for ``query`` using ``retrieved``."""
        rag = RetrievalGenerator(self.parameter)
        reply = rag.generate(
            retrieved,
            query,
            history=history,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        return reply

    def text_generate(
        self,
        query: str,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> str:
        """Return the ``TextGenerator`` output for ``query``."""
        writer = TextGenerator(self.parameter)
        completion = writer.generate(
            query, temperature=temperature, top_k=top_k, top_p=top_p
        )
        return completion

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        """Return the ``EmbeddingEncoder`` output for ``sentence``."""
        embedder = EmbeddingEncoder(self.parameter)
        return embedder.encode(sentence)
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
from khaosz.config.model_config import ModelConfig
|
||||||
|
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
||||||
|
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
||||||
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseModelIO",
|
||||||
|
"ModelParameter",
|
||||||
|
"ModelConfig",
|
||||||
|
"TrainConfig",
|
||||||
|
|
||||||
|
"ScheduleConfig",
|
||||||
|
"CosineScheduleConfig",
|
||||||
|
"SGDRScheduleConfig",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig:
|
||||||
|
# basic config
|
||||||
|
vocab_size: Optional[int] = None
|
||||||
|
dim: Optional[int] = None
|
||||||
|
|
||||||
|
n_layers: Optional[int] = None
|
||||||
|
norm_eps: Optional[float] = None
|
||||||
|
dim_ffn: Optional[int] = None
|
||||||
|
tie_weight: Optional[bool] = None
|
||||||
|
|
||||||
|
# RoPE
|
||||||
|
max_len: Optional[int] = None
|
||||||
|
rope_theta: Optional[float] = None
|
||||||
|
|
||||||
|
# GQA
|
||||||
|
n_heads: Optional[int] = None
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
use_qk_norm: Optional[bool] = None
|
||||||
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, config_path: str) -> Self:
|
||||||
|
config = {}
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config.update(json.load(f))
|
||||||
|
|
||||||
|
for key, value in config.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def save(self, config_path: str):
|
||||||
|
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
import safetensors.torch as st
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Self, Union
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
|
from khaosz.config.model_config import ModelConfig
|
||||||
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
|
@dataclass
class BaseModelIO:
    """Base class for model I/O operations."""

    model: Optional[nn.Module] = field(
        default=None, metadata={"help": "Transformer model."}
    )
    tokenizer: BpeTokenizer = field(
        default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
    )
    config: ModelConfig = field(
        default_factory=ModelConfig,
        metadata={"help": "Transformer model configuration."},
    )

    def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
        """Map each component name to its file inside ``directory``."""
        root = Path(directory)
        filenames = {
            "model": "model.safetensors",
            "config": "config.json",
            "tokenizer": "tokenizer.json",
        }
        return {component: root / filename for component, filename in filenames.items()}

    def save_components(self, save_dir: Union[str, Path]):
        """Write weights (when a model is set), config and tokenizer to ``save_dir``."""
        targets = self._get_file_paths(save_dir)
        # Create the destination directory up front.
        targets["model"].parent.mkdir(parents=True, exist_ok=True)

        model = self.model
        if model is not None:
            st.save_file(model.state_dict(), str(targets["model"]))
        self.config.save(str(targets["config"]))
        self.tokenizer.save(str(targets["tokenizer"]))

    def load_components(self, load_dir: Union[str, Path]) -> Self:
        """Read config and tokenizer from ``load_dir`` and restore weights.

        A ``Transformer`` is built from the loaded config when no model is
        set.  Weights are loaded only when the weights file exists.
        """
        sources = self._get_file_paths(load_dir)

        self.config.load(str(sources["config"]))
        self.tokenizer.load(str(sources["tokenizer"]))

        if self.model is None:
            self.model = Transformer(self.config)

        weights_file = sources["model"]
        if weights_file.exists():
            self.model.load_state_dict(st.load_file(str(weights_file)))

        return self

    def to(self, *args, **kwargs) -> "BaseModelIO":
        """Forward ``to(...)`` to the model when one is set, then return ``self``."""
        if self.model is None:
            return self
        self.model.to(*args, **kwargs)
        return self
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelParameter(BaseModelIO):
    """Container for model parameters with serialization capabilities."""

    def save(self, save_dir: Union[str, Path]):
        """Persist all components to ``save_dir`` via ``save_components``."""
        self.save_components(save_dir)

    def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
        """Restore all components from ``load_dir`` via ``load_components``."""
        restored = self.load_components(load_dir)
        return restored
|
||||||
|
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
from typing import Any, Dict
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ScheduleConfig(ABC):
    """Abstract base for learning-rate schedule settings.

    Subclasses provide ``get_kwargs`` and may extend ``validate``.
    """

    schedule_type: str = field(
        default="cosine",
        metadata={"help": "Type of learning rate schedule.", "choices": ["cosine", "sgdr"]},
    )
    warmup_steps: int = field(default=1000, metadata={"help": "Number of warmup steps."})
    min_rate: float = field(
        default=0.05, metadata={"help": "Minimum learning rate multiplier."}
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments describing this schedule."""
        raise NotImplementedError

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: if ``warmup_steps`` is negative or ``min_rate`` lies
                outside ``[0, 1]``.
        """
        warmup = self.warmup_steps
        if warmup < 0:
            raise ValueError(f"warmup_steps must be non-negative, got {warmup}")

        rate = self.min_rate
        in_range = 0 <= rate <= 1
        if not in_range:
            raise ValueError(f"min_rate must be between 0 and 1, got {rate}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Settings for the cosine schedule; ``total_steps`` is needed by ``get_kwargs``."""

    total_steps: int = field(
        default=None, metadata={"help": "Total training steps for cosine schedule."}
    )

    def __post_init__(self) -> None:
        # The schedule type is fixed for this subclass, whatever was passed in.
        self.schedule_type = "cosine"
        self.validate()

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the scheduler keyword arguments.

        Raises:
            ValueError: if ``total_steps`` has not been set.
        """
        total = self.total_steps
        if total is None:
            raise ValueError("total_steps must be specified for cosine schedule")

        # The decay phase covers whatever remains after warmup.
        decay_steps = total - self.warmup_steps
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "lr_decay_steps": decay_steps,
            "min_rate": self.min_rate,
        }

    def validate(self) -> None:
        """Run the base checks, then require ``total_steps > warmup_steps`` when set."""
        super().validate()
        if self.total_steps is None:
            return
        if self.total_steps <= self.warmup_steps:
            raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SGDRScheduleConfig(ScheduleConfig):
    """Settings for the SGDR schedule."""

    cycle_length: int = field(
        default=1000, metadata={"help": "Length of the first cycle in steps."}
    )
    t_mult: int = field(
        default=2, metadata={"help": "Multiplier for cycle length growth."}
    )

    def __post_init__(self) -> None:
        # The schedule type is fixed for this subclass, whatever was passed in.
        self.schedule_type = "sgdr"
        self.validate()

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the scheduler keyword arguments, keyed by field name."""
        exported = ("schedule_type", "warmup_steps", "cycle_length", "min_rate", "t_mult")
        return {name: getattr(self, name) for name in exported}

    def validate(self) -> None:
        """Run the base checks, then require a positive cycle and ``t_mult >= 1``."""
        super().validate()

        cycle = self.cycle_length
        if cycle <= 0:
            raise ValueError(f"cycle_length must be positive, got {cycle}")

        growth = self.t_mult
        if growth < 1:
            raise ValueError(f"t_mult must be >= 1, got {growth}")
|
||||||
|
|
@ -0,0 +1,136 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrainConfig:
    """Container for all training settings.

    ``model``, ``strategy``, ``dataset``, ``optimizer_fn`` and
    ``scheduler_fn`` default to ``None`` but are mandatory: construction
    fails with ``ValueError`` if any of them is left unset.
    """

    # basic setting
    model: nn.Module = field(default=None, metadata={"help": "Model for training."})
    strategy: str = field(default=None, metadata={"help": "Training strategy."})
    dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
    optimizer_fn: Callable[[nn.Module], Optimizer] = field(default=None, metadata={"help": "Optimizer factory for training."})
    scheduler_fn: Callable[[Optimizer], LRScheduler] = field(default=None, metadata={"help": "Scheduler factory for training."})
    n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
    batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
    accumulation_steps: int = field(default=1, metadata={"help": "Number of iterations between steps."})
    max_grad_norm: float = field(default=1.0, metadata={"help": "Maximum gradient norm."})

    # checkpoint setting
    start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
    start_batch: int = field(default=0, metadata={"help": "Start batch iteration for training."})
    checkpoint_dir: str = field(default="./checkpoint", metadata={"help": "Checkpoint directory."})
    checkpoint_interval: int = field(default=5000, metadata={"help": "Number of iterations between checkpoints."})

    # dataloader setting
    random_seed: int = field(default=3407, metadata={"help": "Random seed."})
    num_workers: int = field(default=0, metadata={"help": "Number of workers for dataloader."})
    prefetch_factor: Optional[int] = field(default=None, metadata={"help": "Prefetch factor for dataloader."})
    pin_memory: bool = field(default=False, metadata={"help": "Pin memory for dataloader."})

    # distributed training
    nprocs: int = field(default=1, metadata={"help": "Number of processes for distributed training."})
    backend: str = field(default="nccl", metadata={"help": "Distributed training backend."})
    master_addr: str = field(default="localhost", metadata={"help": "Master address for distributed training."})
    master_port: str = field(default="29500", metadata={"help": "Master port for distributed training."})
    parallel_wrapper: Optional[Callable] = field(default=None, metadata={"help": "Parallel function for training."})
    state_dict_fn: Optional[Callable] = field(default=None, metadata={"help": "Parallel function for state dict saving."})

    # others
    device_ids: Optional[List[int]] = field(default=None, metadata={"help": "Device ids for distributed training."})
    device_type: str = field(default="cuda", metadata={"help": "Device type for distributed training."})
    extra_kwargs: dict = field(default_factory=dict, metadata={"help": "Other arguments."})

    def __post_init__(self):
        self.validate()

    def validate(self):
        """Ensure every mandatory field was supplied.

        Raises:
            ValueError: Naming the first mandatory field that is ``None``.
        """
        mandatory = ("model", "strategy", "dataset", "optimizer_fn", "scheduler_fn")
        # Fields are checked in declaration order; the first gap is reported.
        missing = next((name for name in mandatory if getattr(self, name) is None), None)
        if missing is not None:
            raise ValueError(f"{missing} is required.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
from khaosz.data.dataset import (
|
||||||
|
BaseDataset,
|
||||||
|
SeqDataset,
|
||||||
|
DpoDataset,
|
||||||
|
SftDataset,
|
||||||
|
PpoDataset,
|
||||||
|
MultiSegmentFetcher,
|
||||||
|
DatasetLoader
|
||||||
|
)
|
||||||
|
|
||||||
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
|
from khaosz.data.sampler import ResumableDistributedSampler
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseDataset",
|
||||||
|
"SeqDataset",
|
||||||
|
"DpoDataset",
|
||||||
|
"SftDataset",
|
||||||
|
"PpoDataset",
|
||||||
|
"MultiSegmentFetcher",
|
||||||
|
"DatasetLoader",
|
||||||
|
"BpeTokenizer",
|
||||||
|
"ResumableDistributedSampler"
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any
|
||||||
|
from khaosz.parallel.setup import get_rank
|
||||||
|
|
||||||
|
|
||||||
|
class Checkpoint:
    """A state dict bundled with the epoch and iteration it was taken at.

    On disk a checkpoint is a directory holding ``meta.json`` (epoch and
    iteration) and ``state_dict.pt`` (the state dict written by
    ``torch.save``).
    """

    def __init__(
        self,
        state_dict: Dict[str, Any],
        epoch: int = 0,
        iteration: int = 0,
    ):
        self.state_dict = state_dict
        self.epoch = epoch
        self.iteration = iteration

    def save(self, save_dir: str) -> None:
        """Write the checkpoint into ``save_dir``.

        Every rank creates the directory, but only rank 0 writes the files.
        """
        target_dir = Path(save_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        if get_rank() != 0:
            return

        with open(target_dir / "meta.json", "w") as meta_file:
            json.dump({"epoch": self.epoch, "iteration": self.iteration}, meta_file, indent=2)

        with open(target_dir / "state_dict.pt", "wb") as weights_file:
            torch.save(self.state_dict, weights_file)

    @classmethod
    def load(cls, save_dir: str) -> "Checkpoint":
        """Rebuild a checkpoint from the files stored in ``save_dir``.

        Rank 0 reads ``meta.json``; when ``torch.distributed`` is
        initialized the metadata is then broadcast from rank 0 to the other
        ranks. Every rank reads ``state_dict.pt`` itself.
        """
        rank = get_rank()
        source_dir = Path(save_dir)

        meta = {}
        if rank == 0:
            with open(source_dir / "meta.json", "r") as meta_file:
                meta = json.load(meta_file)

        if dist.is_initialized():
            # broadcast_object_list fills the list in place on receivers.
            shared = [meta]
            dist.broadcast_object_list(shared, src=0)
            meta = shared[0]

        with open(source_dir / "state_dict.pt", "rb") as weights_file:
            state_dict = torch.load(weights_file)

        return cls(state_dict=state_dict, epoch=meta["epoch"], iteration=meta["iteration"])
|
||||||
|
|
@ -0,0 +1,201 @@
|
||||||
|
import torch
|
||||||
|
import bisect
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from khaosz.data.file import load_h5
|
||||||
|
from typing import Callable, List, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSegmentFetcher:
    """Read index ranges from a list of tensors as one concatenated sequence."""

    def __init__(self, segments: List[Tensor]):
        self.segments = segments
        self.cum_lengths = []

        # cum_lengths[i] is the number of elements in segments[0..i].
        running = 0
        for segment in segments:
            running = running + torch.numel(segment)
            self.cum_lengths.append(running)

        self.total_length = running

    def __len__(self) -> int:
        return self.total_length

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return the elements in ``[begin_idx, end_idx)`` of the joined data.

        An empty ``torch.long`` tensor is returned when the range is empty.

        Raises:
            ValueError: If ``begin_idx`` is outside ``[0, total_length)`` or
                ``end_idx`` is outside ``[0, total_length]``.
        """
        limit = self.total_length
        if begin_idx < 0 or begin_idx >= limit or end_idx < 0 or end_idx > limit:
            raise ValueError("begin_idx or end_idx out of bounds")
        if end_idx <= begin_idx:
            return torch.tensor([], dtype=torch.long)

        # First segment containing begin_idx, last segment containing
        # end_idx - 1 (bisect_left keeps an end on a boundary in the
        # earlier segment).
        first_seg = bisect.bisect_right(self.cum_lengths, begin_idx)
        last_seg = bisect.bisect_left(self.cum_lengths, end_idx)

        pieces = []
        for seg_no in range(first_seg, last_seg + 1):
            segment = self.segments[seg_no]
            # Global offset at which this segment starts.
            offset = self.cum_lengths[seg_no - 1] if seg_no > 0 else 0
            # NOTE(review): lengths above use numel() while slicing uses
            # len(); these agree only for 1-D segments - confirm.
            local_begin = max(begin_idx - offset, 0)
            local_end = min(end_idx - offset, len(segment))
            pieces.append(segment[local_begin:local_end])

        return torch.cat(pieces, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiSegmentFetcher:
    """Hold one ``BaseSegmentFetcher`` per key and read ranges from them."""

    def __init__(self, muti_segments: Dict):
        self.muti_keys = list(muti_segments.keys())
        self.muti_fetchers = {}
        for name, segments in muti_segments.items():
            self.muti_fetchers[name] = BaseSegmentFetcher(segments)

    def __len__(self) -> int:
        # The shortest keyed sequence bounds the usable length.
        return min(len(fetcher) for fetcher in self.muti_fetchers.values())

    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
        """Fetch ``[begin_idx, end_idx)`` for one key or for a list of keys.

        Returns a dict mapping key to tensor when more than one key is
        requested; otherwise returns the single tensor directly.
        """
        wanted = [keys] if isinstance(keys, str) else keys

        fetched = {
            name: self.muti_fetchers[name].fetch_data(begin_idx, end_idx)
            for name in wanted
        }

        if len(wanted) > 1:
            return fetched
        return fetched[wanted[0]]

    def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
        """Fetch ``[begin_idx, end_idx)`` for every key held by this object."""
        return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDataset(Dataset, ABC):
    """Abstract sliding-window dataset over data loaded from HDF5 files.

    Subclasses implement ``__getitem__``. ``total_samples`` stays ``None``
    until ``load`` has been called.
    """

    def __init__(self, window_size: int, stride: int):
        super().__init__()
        self.segments = {}
        self.window_size = window_size
        self.stride = stride
        self.total_samples = None

    def load(self, load_path: str):
        """Load segments from ``load_path`` and rebuild the fetcher."""
        self.segments = load_h5(load_path)
        self.fetcher = MultiSegmentFetcher(self.segments)
        self.total_samples = len(self.fetcher)

    def _require_loaded(self) -> int:
        """Return ``total_samples`` or raise if no data has been loaded."""
        if self.total_samples is None:
            raise RuntimeError("dataset is empty: call load() before using it")
        return self.total_samples

    def get_index(self, index: int) -> tuple:
        """Return the ``(begin_idx, end_idx)`` window for sample ``index``.

        The window starts at ``index * stride`` and is clamped so that one
        more element always remains after ``end_idx``.

        Raises:
            RuntimeError: If no data has been loaded yet.
            ValueError: If the loaded data is not longer than the window.
        """
        total = self._require_loaded()
        # Explicit raise instead of assert so the check survives ``python -O``.
        if total <= self.window_size:
            raise ValueError(
                f"total_samples ({total}) must be greater than window_size ({self.window_size})"
            )

        begin_idx = min(index * self.stride, total - 1 - self.window_size)
        end_idx = min(begin_idx + self.window_size, total - 1)

        return begin_idx, end_idx

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of windows; 0 when the data fits in one window.

        Raises:
            RuntimeError: If no data has been loaded yet.
        """
        total = self._require_loaded()
        if total <= self.window_size:
            return 0
        return (total - 1 - self.window_size) // self.stride + 1
|
||||||
|
|
||||||
|
|
||||||
|
class SeqDataset(BaseDataset):
    """Dataset yielding a window of ``"sequence"`` data and its shifted copy."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)
        self.fetcher = MultiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Read ``[begin_idx, end_idx)`` from the ``"sequence"`` key."""
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        """Return ``input_ids`` and ``target_ids`` (the window shifted by one)."""
        window_begin, window_end = self.get_index(index)

        inputs = self._fetch_data(window_begin, window_end)
        targets = self._fetch_data(window_begin + 1, window_end + 1)

        return {
            "input_ids": inputs.to(dtype=torch.long),
            "target_ids": targets.to(dtype=torch.long),
        }
|
||||||
|
|
||||||
|
|
||||||
|
class SftDataset(BaseDataset):
    """Dataset yielding a sequence window, its shifted copy and a loss mask."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)
        self.fetcher = MultiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` from the data stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index):
        """Return ``input_ids``, ``target_ids`` and ``loss_mask``.

        Targets and the loss mask are both read from the window shifted by
        one position.
        """
        window_begin, window_end = self.get_index(index)
        shifted_begin = window_begin + 1
        shifted_end = window_end + 1

        inputs = self._fetch_data(window_begin, window_end, "sequence")
        targets = self._fetch_data(shifted_begin, shifted_end, "sequence")
        mask = self._fetch_data(shifted_begin, shifted_end, "loss_mask")

        return {
            "input_ids": inputs.to(dtype=torch.long),
            "target_ids": targets.to(dtype=torch.long),
            "loss_mask": mask.to(dtype=torch.bool),
        }
|
||||||
|
|
||||||
|
|
||||||
|
class DpoDataset(BaseDataset):
    """Dataset yielding chosen/rejected windows together with their masks."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)
        self.fetcher = MultiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` from the data stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        """Return the same window read from all four keys.

        Token keys are converted to ``torch.long``, mask keys to
        ``torch.bool``.
        """
        begin_idx, end_idx = self.get_index(index)

        # (key, dtype) pairs, in the order the result dict is built.
        layout = (
            ("chosen", torch.long),
            ("rejected", torch.long),
            ("chosen_mask", torch.bool),
            ("rejected_mask", torch.bool),
        )
        return {
            key: self._fetch_data(begin_idx, end_idx, key).to(dtype=dtype)
            for key, dtype in layout
        }
|
||||||
|
|
||||||
|
|
||||||
|
class PpoDataset(BaseDataset):
    """Dataset yielding windows of the input_ids/actions/logprobs/rewards keys."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)
        self.fetcher = MultiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` from the data stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the same window read from every PPO key.

        Each value is the tensor returned by the fetcher; no dtype
        conversion is applied.
        """
        begin_idx, end_idx = self.get_index(index)

        # No trailing commas here: a trailing comma would wrap each tensor
        # in a 1-tuple and break the declared Dict[str, Tensor] contract.
        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids")
        actions = self._fetch_data(begin_idx, end_idx, "actions")
        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs")
        rewards = self._fetch_data(begin_idx, end_idx, "rewards")

        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetLoader:
    """Factory that builds and loads the dataset matching a training type."""

    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo", "ppo"],
        load_path: str,
        window_size: int,
        stride: Optional[int] = None,
    ) -> BaseDataset:
        """Create the dataset for ``train_type`` and load ``load_path`` into it.

        Args:
            train_type: One of ``"seq"``, ``"sft"``, ``"dpo"`` or ``"ppo"``.
            load_path: Path handed to the dataset's ``load`` method.
            window_size: Window length of each sample.
            stride: Step between windows; defaults to ``window_size``.

        Raises:
            KeyError: If ``train_type`` is not a known training type.
        """
        if stride is None:
            stride = window_size

        # Every dataset class shares the (window_size, stride) constructor,
        # so the classes themselves serve as the factories.
        dataset_router: Dict[str, Callable[[int, int], BaseDataset]] = {
            "seq": SeqDataset,
            "sft": SftDataset,
            "dpo": DpoDataset,
            "ppo": PpoDataset,
        }

        try:
            dataset_cls = dataset_router[train_type]
        except KeyError:
            # Same exception type as before, but now names the valid options.
            raise KeyError(
                f"Unknown train_type {train_type!r}; expected one of {sorted(dataset_router)}"
            ) from None

        dataset = dataset_cls(window_size, stride)
        dataset.load(load_path)

        return dataset
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write groups of tensors to ``<file_path>/<file_name>.h5``.

    The directory is created if needed. Each key of ``tensor_group`` becomes
    an HDF5 group; its tensors are stored as datasets named ``data_0``,
    ``data_1``, ... in list order.
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.h5")

    with h5py.File(target, 'w') as h5_out:
        for group_name, tensor_list in tensor_group.items():
            group = h5_out.create_group(group_name)
            for position, item in enumerate(tensor_list):
                # Tensors are moved to CPU before conversion to NumPy.
                group.create_dataset(f'data_{position}', data=item.cpu().numpy())
|
||||||
|
|
||||||
|
def _dataset_order(name: str):
    """Sort key ordering ``<prefix>_<n>`` names by their numeric suffix."""
    prefix, _, suffix = name.rpartition("_")
    if suffix.isdecimal():
        return (prefix, int(suffix))
    # Names without a numeric suffix sort by name, ahead of numbered ones.
    return (name, -1)


def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every HDF5 file under ``file_path`` into lists of tensors.

    Args:
        file_path: A directory searched recursively for ``*.h5`` and
            ``*.hdf5`` files, or the path of a single HDF5 file.
        share_memory: When true, each tensor is moved to shared memory.

    Returns:
        A dict mapping each top-level group name to the tensors of all its
        datasets, accumulated across files. Empty if no file is found.
    """
    tensor_group: Dict[str, List[Tensor]] = {}

    root_path = Path(file_path)
    if root_path.is_file():
        # rglob() on a file matches nothing; accept a direct file path.
        h5_files = [root_path]
    else:
        # rglob() yields entries in filesystem order; sort so the segment
        # order is reproducible across runs and machines.
        h5_files = sorted(root_path.rglob("*.h5")) + sorted(root_path.rglob("*.hdf5"))

    for h5_file in h5_files:
        with h5py.File(h5_file, 'r') as f:
            for key in f.keys():
                grp = f[key]
                dsets = []
                # Order by numeric suffix so that data_2 precedes data_10,
                # i.e. datasets come back in the order they were numbered.
                for dset_name in sorted(grp.keys(), key=_dataset_order):
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    dsets.append(tensor)

                tensor_group.setdefault(key, []).extend(dsets)

    return tensor_group
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from torch.utils.data import Dataset, Sampler
|
from torch.utils.data import Dataset, Sampler
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class ResumableDistributedSampler(Sampler[int]):
|
class ResumableDistributedSampler(Sampler[int]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_source: Dataset,
|
data_source: Dataset,
|
||||||
start_epoch: int = 0,
|
start_epoch: int=0,
|
||||||
start_iter: int = 0,
|
start_iter: int=0,
|
||||||
seed: int = 42,
|
seed: int=42,
|
||||||
drop_last: bool = False,
|
drop_last: bool=False,
|
||||||
shuffle: bool = True,
|
shuffle: bool=True,
|
||||||
process_group: Optional[dist.ProcessGroup] = None,
|
process_group: Optional[dist.ProcessGroup]=None,
|
||||||
):
|
):
|
||||||
self.epoch = start_epoch
|
self.epoch = start_epoch
|
||||||
self.iter = start_iter
|
self.iter = start_iter
|
||||||
|
|
@ -40,10 +40,9 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
self.iter = self.iter % self.num_samples_per_replica
|
|
||||||
|
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
|
@ -59,10 +58,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
padding_size = self.total_size - len(indices)
|
padding_size = self.total_size - len(indices)
|
||||||
indices += indices[:padding_size]
|
indices += indices[:padding_size]
|
||||||
|
|
||||||
local_indices = indices[self.rank : self.total_size : self.num_replicas]
|
local_indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||||
|
|
||||||
self.iter = self.iter % self.num_samples_per_replica
|
self.iter = self.iter % self.num_samples_per_replica
|
||||||
self._indices = local_indices[self.iter :]
|
self._indices = local_indices[self.iter:]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
if self._indices is None:
|
if self._indices is None:
|
||||||
|
|
@ -75,10 +74,5 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
@property
|
|
||||||
def _remaining(self):
|
|
||||||
remaining = self.num_samples_per_replica - self.iter
|
|
||||||
return max(remaining, 0)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._remaining
|
return self.num_samples_per_replica
|
||||||
|
|
@ -0,0 +1,106 @@
|
||||||
|
from tokenizers import Tokenizer, Encoding
|
||||||
|
from tokenizers import decoders, processors, normalizers, pre_tokenizers
|
||||||
|
from tokenizers.models import BPE
|
||||||
|
from tokenizers.trainers import BpeTrainer
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
|
||||||
|
class BpeTokenizer:
    """Byte-level BPE tokenizer built on the ``tokenizers`` library.

    Control tokens are ``<bos>``, ``<eos>`` and ``<pad>``; special tokens are
    ``<|im_start|>`` and ``<|im_end|>``.
    """

    def __init__(self, path=None):
        """Load a tokenizer from ``path``, or build an untrained one."""
        self._control_tokens = ["<bos>", "<eos>", "<pad>"]
        self._special_tokens = ["<|im_start|>", "<|im_end|>"]

        if path is not None:
            # A saved tokenizer replaces the default pipeline entirely, so
            # there is no point in building one first.
            self._tokenizer = Tokenizer.from_file(path)
            return

        self._tokenizer = Tokenizer(BPE())
        self._tokenizer.normalizer = normalizers.Sequence([
            normalizers.NFC(),
            normalizers.Strip()
        ])
        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.UnicodeScripts(),
            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)
        ])
        self._tokenizer.decoder = decoders.ByteLevel()
        self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)

    def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int, max_token_length=18) -> tuple:
        """Build the BPE trainer and the list of reserved placeholder tokens.

        ``reserved_token_size`` slots are set aside for the special tokens
        plus ``<|reserveNN|>`` placeholders; the trainer gets the rest of
        ``vocab_size``.

        Returns:
            ``(trainer, detail_vocab_size, reserved_tokens)``.

        Raises:
            ValueError: If ``reserved_token_size`` does not exceed the number
                of special tokens, or the remaining vocabulary does not
                exceed the byte alphabet plus the control tokens.
        """
        # Explicit raises instead of asserts: these validate caller input
        # and must not disappear under ``python -O``.
        if reserved_token_size <= len(self._special_tokens):
            raise ValueError(
                f"reserved_token_size must be greater than {len(self._special_tokens)}, "
                f"got {reserved_token_size}"
            )
        reserved_tokens = [f"<|reserve{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
        detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))

        alphabet = pre_tokenizers.ByteLevel.alphabet()
        min_size = len(alphabet) + len(self._control_tokens)
        if detail_vocab_size <= min_size:
            raise ValueError(
                f"vocab_size leaves {detail_vocab_size} trainable entries; "
                f"more than {min_size} are required"
            )

        trainer = BpeTrainer(
            vocab_size=detail_vocab_size,
            min_frequency=min_freq,
            limit_alphabet=detail_vocab_size // 6,
            max_token_length=max_token_length,
            special_tokens=self._control_tokens,
            initial_alphabet=alphabet,
            show_progress=True,
        )

        return trainer, detail_vocab_size, reserved_tokens

    def train(self, files, vocab_size, min_freq, reserved_token_size=100):
        """Train on ``files``, then register special and reserved tokens."""
        trainer, _, reserved_tokens = self._prepare_trainer(
            vocab_size=vocab_size,
            min_freq=min_freq,
            reserved_token_size=reserved_token_size
        )
        self._tokenizer.train(files=files, trainer=trainer)
        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)

    def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
        """Train on ``iterator``, then register special and reserved tokens."""
        trainer, _, reserved_tokens = self._prepare_trainer(
            vocab_size=vocab_size,
            min_freq=min_freq,
            reserved_token_size=reserved_token_size
        )
        self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)

    def save(self, path):
        """Save the tokenizer to ``path``."""
        self._tokenizer.save(path)

    def load(self, path):
        """Replace the current tokenizer with the one stored at ``path``."""
        self._tokenizer = Tokenizer.from_file(path)

    def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
        """Encode one string or a list of strings.

        Args:
            tokens: A single string, or a list of strings encoded as a batch.
            out_ids: Return token ids when true, token strings otherwise.
            add_special_tokens: Forwarded to the underlying tokenizer.

        Returns:
            A flat list for a single string; a list of lists for a list.

        Raises:
            TypeError: If ``tokens`` is neither a ``str`` nor a ``list``.
        """
        if isinstance(tokens, str):
            encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
            return encoded.ids if out_ids else encoded.tokens

        if isinstance(tokens, list):
            encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
            return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]

        # Previously fell through and returned None, hiding the caller's bug.
        raise TypeError(f"tokens must be a str or a list of str, got {type(tokens).__name__}")

    def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
        """Decode a list of token ids back into text."""
        return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def __len__(self) -> int:
        return self._tokenizer.get_vocab_size()

    @property
    def stop_ids(self) -> List[int]:
        """Ids of all control and special tokens, in that order."""
        stop_token = self._control_tokens + self._special_tokens
        return [self._tokenizer.token_to_id(token) for token in stop_token]

    @property
    def bos_id(self) -> int:
        return self._tokenizer.token_to_id("<bos>")

    @property
    def eos_id(self) -> int:
        return self._tokenizer.token_to_id("<eos>")

    @property
    def pad_id(self) -> int:
        return self._tokenizer.token_to_id("<pad>")
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue