Compare commits
218 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
01ce1fb9e3 | |
|
|
14f83cbdac | |
|
|
dbe5891201 | |
|
|
2a65c3314c | |
|
|
1c2ff05a6d | |
|
|
31ae2deeba | |
|
|
69207e2c57 | |
|
|
138c5bcc08 | |
|
|
a923e0a23a | |
|
|
f521a30b22 | |
|
|
d4451f6afb | |
|
|
a3275423a4 | |
|
|
b37c3d000c | |
|
|
6031020e37 | |
|
|
c424dfc293 | |
|
|
3a28e52e98 | |
|
|
e371908b54 | |
|
|
7c99da155c | |
|
|
629e72385b | |
|
|
0a708fff24 | |
|
|
6e150ea6d0 | |
|
|
cb8dcb97ea | |
|
|
2d5dc93b3d | |
|
|
4145d35e3c | |
|
|
34c6c45bd6 | |
|
|
e9def84ce7 | |
|
|
836e02a166 | |
|
|
b558e61f63 | |
|
|
65ab69543b | |
|
|
1d26aa2e93 | |
|
|
a548d4553e | |
|
|
dd1b39f435 | |
|
|
94d6e713e9 | |
|
|
47c37e4876 | |
|
|
737585a32a | |
|
|
a4688021bf | |
|
|
7df6eb9211 | |
|
|
82a3f2626f | |
|
|
7fa69572c0 | |
|
|
3ab4f237e5 | |
|
|
8cbf3f36e2 | |
|
|
0594ce1017 | |
|
|
ff509ff39f | |
|
|
785d65436c | |
|
|
64be81b7b3 | |
|
|
45479b5731 | |
|
|
e0a3337c22 | |
|
|
812238060b | |
|
|
14b0d56197 | |
|
|
6c8533f1d2 | |
|
|
2c2697390d | |
|
|
7621f05d3f | |
|
|
10ebd7211f | |
|
|
42a391f0fb | |
|
|
97c7ac0f4f | |
|
|
8f1b32f2b6 | |
|
|
c241a5dcef | |
|
|
44dab27fdc | |
|
|
a44fd22a99 | |
|
|
8a11a7d444 | |
|
|
1d54491809 | |
|
|
ad9f4d9cf6 | |
|
|
e1638a7ade | |
|
|
f91bfee33e | |
|
|
d7a7f570ed | |
|
|
7dea929788 | |
|
|
026d1fc33d | |
|
|
7242eedbf4 | |
|
|
04c0dc7a47 | |
|
|
48a53121ba | |
|
|
0ba8c70ce1 | |
|
|
3d12a03909 | |
|
|
c169659611 | |
|
|
e12f1a7ee5 | |
|
|
ef25efffa2 | |
|
|
19532440b4 | |
|
|
9096e413c3 | |
|
|
9d5e9fa6c4 | |
|
|
08dde46778 | |
|
|
513f1f7826 | |
|
|
e3382f6bb5 | |
|
|
f0339022c1 | |
|
|
d8da2cf17c | |
|
|
205b40bd28 | |
|
|
18fe6e9339 | |
|
|
2196c34c52 | |
|
|
466c2e1efd | |
|
|
7e26d848ab | |
|
|
ed95ef245c | |
|
|
6d6ef99e66 | |
|
|
a8e2a1ba45 | |
|
|
6269bacfc3 | |
|
|
c0effc9f5b | |
|
|
df0845e916 | |
|
|
7440e9c809 | |
|
|
7d4029c2a4 | |
|
|
0ca6c9e6eb | |
|
|
6e49d27057 | |
|
|
5203b7f53e | |
|
|
5889179c54 | |
|
|
38e18fdfd3 | |
|
|
4753958f92 | |
|
|
73d6cc0f26 | |
|
|
317ed90bac | |
|
|
951df8155c | |
|
|
a58fab8d6e | |
|
|
a3c8296135 | |
|
|
c95ace41aa | |
|
|
3da428e0e4 | |
|
|
133a9de98f | |
|
|
523eacf5fe | |
|
|
cffedaad5e | |
|
|
3583c46b66 | |
|
|
ca4e6b907c | |
|
|
db99d8b254 | |
|
|
b98c9cefdc | |
|
|
283bcaf2ff | |
|
|
bc7c82977e | |
|
|
34a511e36e | |
|
|
d73f52a2f8 | |
|
|
9d96b0431d | |
|
|
f81e2b4a73 | |
|
|
4e324d8f26 | |
|
|
6ed0506491 | |
|
|
30cc2d67a4 | |
|
|
7ddebf2cd9 | |
|
|
78dc2bd41c | |
|
|
44d7a4e959 | |
|
|
c4401512f2 | |
|
|
a6f5ff3b37 | |
|
|
ffff05b2c6 | |
|
|
b89f8436ea | |
|
|
123f25e339 | |
|
|
520de3ebe8 | |
|
|
466c34d7a8 | |
|
|
6831a15424 | |
|
|
0f9e5c5049 | |
|
|
cb0e7f2a80 | |
|
|
296db909aa | |
|
|
a2ae742988 | |
|
|
29beb174a5 | |
|
|
bbeaff4c60 | |
|
|
ab5e207f42 | |
|
|
b0eff02446 | |
|
|
408f0cb513 | |
|
|
64b78ecce3 | |
|
|
f2ffdf60d0 | |
|
|
ace8f6ee68 | |
|
|
a57a16430d | |
|
|
3fee87897d | |
|
|
3f67e53088 | |
|
|
bf7adb35b3 | |
|
|
feaa3fca36 | |
|
|
39766aa1dc | |
|
|
9b22b1651e | |
|
|
e58dbd7c57 | |
|
|
d2fe8afbd1 | |
|
|
23ce4bc3ae | |
|
|
d2b36cc85d | |
|
|
fc278d17ab | |
|
|
ff43a2fab8 | |
|
|
2b26f03bd3 | |
|
|
861d33b1a1 | |
|
|
99b821ebf5 | |
|
|
c94a246c71 | |
|
|
2dc9545d7f | |
|
|
9c31d78a22 | |
|
|
bd9741dc5f | |
|
|
b531232a9b | |
|
|
3346c75584 | |
|
|
aa5e03d7f6 | |
|
|
073baf105c | |
|
|
e97536758f | |
|
|
7861af12e4 | |
|
|
7f0552013a | |
|
|
3535de5cc4 | |
|
|
26989e54aa | |
|
|
70d52935f0 | |
|
|
c0e0e6afd9 | |
|
|
0852b852f8 | |
|
|
3a7d98a950 | |
|
|
c5560740b6 | |
|
|
94c6a015c8 | |
|
|
8b6509b305 | |
|
|
912d7c7f54 | |
|
|
475de51c7d | |
|
|
9f1561afe7 | |
|
|
80c0b20877 | |
|
|
e7721eafc6 | |
|
|
4ead0a20cf | |
|
|
b1527d9575 | |
|
|
2e009cf59a | |
|
|
780b9e1855 | |
|
|
aef7615abd | |
|
|
50488bd659 | |
|
|
eb57e55fca | |
|
|
426af2d75f | |
|
|
345fd2f091 | |
|
|
e1f9901384 | |
|
|
0e7fc623b4 | |
|
|
3e33c14376 | |
|
|
60f4df95bd | |
|
|
c01791ff54 | |
|
|
980299cd54 | |
|
|
3e8f2eba81 | |
|
|
361cdeb296 | |
|
|
50f76cd7c7 | |
|
|
0f518473af | |
|
|
a5574f92e2 | |
|
|
abcedf892e | |
|
|
abc3a06266 | |
|
|
62fba9a298 | |
|
|
e23a5ca426 | |
|
|
e55b57d771 | |
|
|
c4feab96fe | |
|
|
e35cb0d84a | |
|
|
6d6ef6dbb6 | |
|
|
493fe4e84b |
|
|
@ -0,0 +1,9 @@
|
||||||
|
# Ignore everything
|
||||||
|
*
|
||||||
|
|
||||||
|
# Allow necessary files
|
||||||
|
!astrai/
|
||||||
|
!scripts/
|
||||||
|
!assets/
|
||||||
|
!pyproject.toml
|
||||||
|
!README.md
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Auto detect text files
|
||||||
|
* text=auto
|
||||||
|
|
||||||
|
# Files that MUST use LF (Unix/Linux execution)
|
||||||
|
*.sh text eol=lf
|
||||||
|
*.py text eol=lf
|
||||||
|
*.md text eol=lf
|
||||||
|
*.yml text eol=lf
|
||||||
|
|
||||||
|
Dockerfile text eol=lf
|
||||||
|
.dockerignore text eol=lf
|
||||||
|
|
||||||
|
.gitignore text eol=lf
|
||||||
|
.gitattributes text eol=lf
|
||||||
|
|
||||||
|
# Windows scripts - use CRLF
|
||||||
|
*.bat text eol=crlf
|
||||||
|
*.cmd text eol=crlf
|
||||||
|
*.ps1 text eol=crlf
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
---
|
||||||
|
name: Bug report
|
||||||
|
about: Create a report to help us improve
|
||||||
|
title: "[BUG]"
|
||||||
|
labels: bug
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Description
|
||||||
|
A clear and concise description of what the bug is.
|
||||||
|
## Steps to Reproduce
|
||||||
|
1. ...
|
||||||
|
2. ...
|
||||||
|
3. ...
|
||||||
|
## Expected Behavior
|
||||||
|
What you expected to happen.
|
||||||
|
## Actual Behavior
|
||||||
|
What actually happened.
|
||||||
|
## Environment
|
||||||
|
- Python version:
|
||||||
|
- AstrAI version (or commit hash):
|
||||||
|
- Operating System:
|
||||||
|
- GPU (if applicable):
|
||||||
|
- CUDA/cuDNN version (if applicable):
|
||||||
|
## Additional Context
|
||||||
|
Add any other context, screenshots, or logs here.
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
---
|
||||||
|
name: Custom issue template
|
||||||
|
about: Describe this issue template's purpose here.
|
||||||
|
title: ''
|
||||||
|
labels: ''
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
---
|
||||||
|
name: Feature request
|
||||||
|
about: Suggest an idea for this project
|
||||||
|
title: "[FEAT]"
|
||||||
|
labels: ''
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Description
|
||||||
|
A clear and concise description of the feature you'd like to see.
|
||||||
|
## Problem Statement
|
||||||
|
What problem does this feature solve? Why is it needed?
|
||||||
|
## Proposed Solution
|
||||||
|
Describe the solution you'd like. Include any design ideas, API changes, or implementation details.
|
||||||
|
## Alternatives Considered
|
||||||
|
Describe any alternative solutions or features you've considered.
|
||||||
|
## Additional Context
|
||||||
|
Add any other context, screenshots, or references here.
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
## Description
|
||||||
|
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context.
|
||||||
|
|
||||||
|
Fixes # (issue number)
|
||||||
|
|
||||||
|
## Type of Change
|
||||||
|
Please delete options that are not relevant.
|
||||||
|
|
||||||
|
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||||
|
- [ ] New feature (non-breaking change which adds functionality)
|
||||||
|
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||||
|
- [ ] Documentation update
|
||||||
|
- [ ] Other (please describe):
|
||||||
|
|
||||||
|
## How Has This Been Tested?
|
||||||
|
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
|
||||||
|
|
||||||
|
## Checklist:
|
||||||
|
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
|
||||||
|
- [ ] I have performed a self-review of my own code
|
||||||
|
- [ ] Code is self-documenting (no unnecessary comments)
|
||||||
|
- [ ] I have made corresponding changes to the documentation
|
||||||
|
- [ ] My changes generate no new warnings
|
||||||
|
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||||
|
- [ ] New and existing unit tests pass locally with my changes
|
||||||
|
- [ ] Any dependent changes have been merged and published in downstream modules
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
name: Build and Push Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Login to GitHub Container Registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Extract metadata
|
||||||
|
id: meta
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
with:
|
||||||
|
images: ghcr.io/${{ github.repository }}
|
||||||
|
tags: |
|
||||||
|
type=ref,event=tag
|
||||||
|
type=raw,value=latest
|
||||||
|
|
||||||
|
- name: Build and push
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: linux/amd64
|
||||||
|
push: true
|
||||||
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
|
cache-from: type=gha
|
||||||
|
cache-to: type=gha,mode=max
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
name: Lint
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python 3.12
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install .[dev]
|
||||||
|
|
||||||
|
- name: Check formatting with ruff
|
||||||
|
run: |
|
||||||
|
ruff format --check .
|
||||||
|
|
||||||
|
- name: Check import sorting
|
||||||
|
run: |
|
||||||
|
ruff check . --select I
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
name: Spell Check
|
|
||||||
on: [push, pull_request]
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
spellcheck:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Check spelling in specific files
|
|
||||||
uses: codespell-project/actions-codespell@v2
|
|
||||||
with:
|
|
||||||
check_filenames: true
|
|
||||||
only_warn: false
|
|
||||||
path: "**/*.{md, py}"
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
name: Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.12"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install .[dev]
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
run: |
|
||||||
|
python -m pytest tests/ -v
|
||||||
|
|
@ -6,7 +6,18 @@
|
||||||
|
|
||||||
# Allow specific file types and root files
|
# Allow specific file types and root files
|
||||||
!*.py
|
!*.py
|
||||||
!*.md
|
!*.sh
|
||||||
!*.png
|
|
||||||
!LICENSE
|
# Allow GitHub files
|
||||||
!pyproject.toml
|
!/.github/**
|
||||||
|
|
||||||
|
# Allow root files
|
||||||
|
!/.gitattributes
|
||||||
|
!/.dockerignore
|
||||||
|
!/Dockerfile
|
||||||
|
!/docker-compose.yml
|
||||||
|
!/assets/**
|
||||||
|
!/CONTRIBUTING.md
|
||||||
|
!/LICENSE
|
||||||
|
!/pyproject.toml
|
||||||
|
!/README.md
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
# Contributing to AstrAI
|
||||||
|
|
||||||
|
Thank you for your interest in contributing! This document provides step-by-step guidelines.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/your-username/AstrAI.git
|
||||||
|
cd AstrAI
|
||||||
|
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Before You Commit
|
||||||
|
|
||||||
|
Run the following checks **in order** — CI will reject if any fail.
|
||||||
|
|
||||||
|
### 1. Format
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ruff format .
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note**: `ruff format` may rename parameters (e.g. `mask` → `attn_mask`).
|
||||||
|
> Always review the diff after formatting.
|
||||||
|
|
||||||
|
### 2. Import sorting
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ruff check . --select I
|
||||||
|
```
|
||||||
|
|
||||||
|
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ruff check . --select I --fix .
|
||||||
|
ruff format . # re-format after fix
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Run tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -u -m pytest tests/ -v
|
||||||
|
```
|
||||||
|
|
||||||
|
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
|
||||||
|
|
||||||
|
### 4. (Optional) Full pre-commit check
|
||||||
|
|
||||||
|
If you have Git Bash available:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash scripts/pre_commit.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
This runs format check, import sort check, and tests in one go.
|
||||||
|
|
||||||
|
## Commit Style
|
||||||
|
|
||||||
|
```
|
||||||
|
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
|
||||||
|
|
||||||
|
- bullet point body (each ~60 chars)
|
||||||
|
```
|
||||||
|
|
||||||
|
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
|
||||||
|
- **Subject line** ends with no period.
|
||||||
|
- **Body** uses bullet points starting with `-`.
|
||||||
|
- No `(scope)` parentheses.
|
||||||
|
|
||||||
|
## Common Issues
|
||||||
|
|
||||||
|
| Problem | Cause | Fix |
|
||||||
|
|---------|-------|-----|
|
||||||
|
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
|
||||||
|
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
|
||||||
|
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
|
||||||
|
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
|
||||||
|
|
||||||
|
## Submitting Changes
|
||||||
|
|
||||||
|
1. Fork the repo.
|
||||||
|
2. Create a feature branch: `git checkout -b feat/my-feature`
|
||||||
|
3. Make changes following the steps above.
|
||||||
|
4. Commit with the commit style above.
|
||||||
|
5. Push: `git push origin feat/my-feature`
|
||||||
|
6. Open a Pull Request against `main`.
|
||||||
|
|
||||||
|
## Code Review
|
||||||
|
|
||||||
|
- All PRs are reviewed. We may request changes.
|
||||||
|
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
|
||||||
|
- Ensure all tests pass.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
# AstrAI Dockerfile - Multi-stage Build (Optimized)
|
||||||
|
|
||||||
|
# Build stage - use base image with minimal build tools
|
||||||
|
FROM ubuntu:24.04 AS builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install Python 3.12 and minimal build dependencies
|
||||||
|
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||||
|
python3.12 \
|
||||||
|
python3.12-dev \
|
||||||
|
python3.12-venv \
|
||||||
|
gcc \
|
||||||
|
g++ \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Create isolated virtual environment
|
||||||
|
RUN python3.12 -m venv --copies /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
|
# Copy source code and install (deps read from pyproject.toml)
|
||||||
|
COPY astrai/ ./astrai/
|
||||||
|
COPY pyproject.toml .
|
||||||
|
RUN pip install --no-cache-dir --upgrade pip \
|
||||||
|
&& pip install --no-cache-dir . \
|
||||||
|
--extra-index-url https://download.pytorch.org/whl/cu126
|
||||||
|
|
||||||
|
# Production stage
|
||||||
|
FROM ubuntu:24.04 AS production
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install Python 3.12 runtime and healthcheck dependency
|
||||||
|
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||||
|
python3.12 \
|
||||||
|
curl \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy virtual environment from builder
|
||||||
|
COPY --from=builder /opt/venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY astrai/ ./astrai/
|
||||||
|
COPY scripts/ ./scripts/
|
||||||
|
COPY assets/ ./assets/
|
||||||
|
COPY pyproject.toml .
|
||||||
|
COPY README.md .
|
||||||
|
|
||||||
|
# Create non-root user
|
||||||
|
RUN useradd -m astrai && chown -R astrai:astrai /app
|
||||||
|
USER astrai
|
||||||
|
|
||||||
|
ENV PYTHONUNBUFFERED=1 \
|
||||||
|
PYTHONDONTWRITEBYTECODE=1
|
||||||
453
README.md
453
README.md
|
|
@ -1,286 +1,255 @@
|
||||||

|
<div align="center">
|
||||||
|
|
||||||
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; text-align: center; font-size: 16px; font-weight: bold; margin-top: 50px;">
|
<img src="assets/images/logo.png" width="auto" alt="Logo">
|
||||||
|
<p>
|
||||||
<div>
|
<strong>A lightweight Transformer training & inference framework</strong>
|
||||||
<a href="#english" style="text-decoration: none; margin: 0 10px; color: blue;">English</a> |
|
</p>
|
||||||
<a href="#chinese" style="text-decoration: none; margin: 0 10px; color: blue;">中文</a>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<h1 style="margin: 20px 0 0 0; font-size: 2.5em; font-weight: bold;">KHAOSZ </h1>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<h2 id="english">English Version</h2>
|
<div align="center">
|
||||||
|
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||||
|
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
|
||||||
|
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
|
||||||
|
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
|
||||||
|
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
|
||||||
|
</div>
|
||||||
|
<br>
|
||||||
|
|
||||||
A training and inference framework for autoregressive Transformer language models.
|
<div align="center">
|
||||||
|
<a href="#english">English</a> •
|
||||||
|
<a href="assets/docs/README-zh-CN.md">中文</a> •
|
||||||
|
<a href="https://github.com/ViperEkura/AstrAI/issues">Issue Tracker</a> •
|
||||||
|
<a href="https://github.com/ViperEkura/AstrAI/discussions">Discussions</a> •
|
||||||
|
<a href="https://huggingface.co/ViperEk/">HuggingFace</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
**Model Download Options (choose one):**
|
<br>
|
||||||
|
|
||||||
1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) and check **Files and versions**
|
## 📖 Table of Contents
|
||||||
2. Run `scripts/download.py` to download model parameters
|
|
||||||
|
|
||||||
**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
|
- [Features](#features)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Documentation](#documentation)
|
||||||
|
- [Contributing](#contributing)
|
||||||
|
- [Community](#community)
|
||||||
|
- [License](#license)
|
||||||
|
|
||||||
For training data sources, please refer to the **Model Card** section on the HuggingFace download page.
|
---
|
||||||
|
|
||||||
**License:** The code follows the GPL-3.0 license. Please provide attribution when using it.
|
<a id="english"></a>
|
||||||
|
## English
|
||||||
|
|
||||||
- **📊 Device Selection:** Uses CUDA for training by default
|
### Features
|
||||||
- **🌐 Performance Optimization:** Enable `dtype=torch.bfloat16` to accelerate training and reduce memory usage. Ensure your hardware supports this feature
|
|
||||||
- **🤖 Language Support:** The model supports training in Chinese and English. Since the BBPE tokenizer hasn't been trained on multilingual text, OOV (Out-of-Vocabulary) issues are minimal for Chinese and English, but may exist for other languages
|
|
||||||
|
|
||||||
|
- 🚀 **High Performance**: Optimized for both training and inference with efficient parallelization.
|
||||||
|
- 🔧 **Flexible**: Support for seq/sft/dpo/grpo training, customizable model architectures.
|
||||||
|
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
|
||||||
|
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
|
||||||
|
- 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas.
|
||||||
|
- 🤗 **HuggingFace-Style API**: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
|
||||||
|
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
|
||||||
|
|
||||||
### 📌 Training Guide
|
### Quick Start
|
||||||
|
|
||||||
To train this Transformer model, follow these steps:
|
#### Installation
|
||||||
|
|
||||||
**(1). Prepare the Dataset:**
|
|
||||||
|
|
||||||
Place the dataset in the specified root directory. This system uses the BBPE tokenizer for tokenization and requires training with pre-tokenized segments (stored as *.h5 format files).
|
|
||||||
|
|
||||||
**(2). Install Dependencies:**
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
git clone https://github.com/ViperEkura/AstrAI.git
|
||||||
|
cd AstrAI
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
**(3). Run the Training Script:**
|
For development dependencies:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python train.py \
|
pip install -e ".[dev]"
|
||||||
--train_type=train_type[seq, sft, dpo] \
|
|
||||||
--data_root_path=/path/to/dataset \
|
|
||||||
--param_path=/path/to/param_path \
|
|
||||||
--n_epoch=5 \
|
|
||||||
--batch_size=8 \
|
|
||||||
--max_lr=2e-4 \
|
|
||||||
--checkpoint_interval=10000 \
|
|
||||||
--checkpoint_dir=checkpoints
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Parameter Explanation:**
|
#### Download Pre-trained Model
|
||||||
- `--train_type`: Training type (seq, sft, dpo)
|
|
||||||
- `--data_root_path`: Dataset root directory
|
|
||||||
- `--param_path`: Path to model training parameters
|
|
||||||
- `--n_epoch`: Total number of training epochs
|
|
||||||
- `--batch_size`: Batch size
|
|
||||||
- `--accumulation_steps`: Number of batches per training step
|
|
||||||
- `--warmup_steps`: Warmup steps
|
|
||||||
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
|
|
||||||
- `--checkpoint_interval`: Checkpoint saving interval
|
|
||||||
- `--checkpoint_dir`: Checkpoint saving directory
|
|
||||||
- `--resume_dir`: Resume training from specified path
|
|
||||||
|
|
||||||
|
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
|
||||||
|
|
||||||
### 👉 Usage Guide
|
|
||||||
|
|
||||||
**(1). Chat with the Model:**
|
|
||||||
|
|
||||||
Open `chat.py` or use the streaming/non-streaming interfaces:
|
|
||||||
|
|
||||||
**Streaming Output:**
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
model_dir = "your_model_parameter_dir"
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
history = []
|
|
||||||
|
|
||||||
while True:
|
|
||||||
query = input(">> ")
|
|
||||||
if query == "!exit":
|
|
||||||
break
|
|
||||||
|
|
||||||
response_size = 0
|
|
||||||
for response, history in model.stream_generate(
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
temperature=0.85,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50
|
|
||||||
):
|
|
||||||
print(response[response_size:], end="")
|
|
||||||
response_size = len(response)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Non-streaming Output:**
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
model_dir = "your_model_parameter_dir"
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
history = []
|
|
||||||
|
|
||||||
while True:
|
|
||||||
query = input(">> ")
|
|
||||||
if query == "!exit":
|
|
||||||
break
|
|
||||||
|
|
||||||
response = model.generate(
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
temperature=0.85,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50
|
|
||||||
)
|
|
||||||
print(response)
|
|
||||||
```
|
|
||||||
|
|
||||||
**(2). Retrieval-Augmented Generation (RAG):**
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
model_dir = "your_model_parameter_dir"
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
retrieved_content = model.retrieve_generate(
|
|
||||||
query=query,
|
|
||||||
retrieve_top_k=5,
|
|
||||||
temperature=0.6,
|
|
||||||
top_k=30,
|
|
||||||
top_p=0.95
|
|
||||||
)
|
|
||||||
print(retrieved_content)
|
|
||||||
```
|
|
||||||
|
|
||||||
<h2 id="chinese">中文版本</h2>
|
|
||||||
这是一个支持基于自回归模式的 Transfomer 语言模型训练以及推理框架
|
|
||||||
|
|
||||||
**模型下载选项(任选其一):**
|
|
||||||
|
|
||||||
1. 访问 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 查看 **Files and versions**
|
|
||||||
2. 运行 `scripts/download.py` 下载模型参数
|
|
||||||
|
|
||||||
**演示视频:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
|
|
||||||
|
|
||||||
训练数据来源请参见 HuggingFace 下载页面中的 **Model Card** 部分。
|
|
||||||
|
|
||||||
**许可证:** 代码遵循 GPL-3.0 协议,使用时请注明出处。
|
|
||||||
|
|
||||||
- **📊 设备选择:** 默认使用 CUDA 进行训练
|
|
||||||
- **🌐 性能优化:** 启用 `dtype=torch.bfloat16` 以加速训练并减少内存占用,请确保硬件支持该特性
|
|
||||||
- **🤖 语言支持:** 模型支持中文和英文训练。由于 BBPE 分词器未使用多语言文本训练,因此中英文的 OOV(未登录词)问题较少,其他语言可能存在 OOV 问题
|
|
||||||
|
|
||||||
|
|
||||||
### 📌 训练指南
|
|
||||||
|
|
||||||
要训练该 Transformer 模型,请按照以下步骤操作:
|
|
||||||
|
|
||||||
**(1). 准备数据集:**
|
|
||||||
|
|
||||||
将数据集放置在指定的根目录下, 本系统采用 BBPE 分词器进行分词,并且要求使用已经经过分词的 token 分段训练(分段存储为 *.h5 格式)
|
|
||||||
|
|
||||||
**(2). 安装依赖:**
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -e .
|
python scripts/demo/download.py
|
||||||
```
|
```
|
||||||
|
|
||||||
**(3). 运行训练脚本:**
|
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
|
||||||
|
|
||||||
|
#### Train a Model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python train.py \
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
--train_type=train_type[seq, sft, dpo] \
|
|
||||||
--data_root_path=/path/to/dataset \
|
nohup python scripts/tools/train.py \
|
||||||
--param_path=/path/to/param_path \
|
--nprocs=4 \
|
||||||
--n_epoch=5 \
|
--parallel_mode=ddp \
|
||||||
--batch_size=8 \
|
--train_type=seq \
|
||||||
--max_lr=2e-4 \
|
--data_root_path=/path/to/dataset \
|
||||||
--checkpoint_interval=10000 \
|
--param_path=/path/to/model \
|
||||||
--checkpoint_dir=checkpoints
|
--batch_per_device=4 \
|
||||||
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.9 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=0.01 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.05 \
|
||||||
|
> out.log 2> err.log &
|
||||||
```
|
```
|
||||||
|
|
||||||
**参数说明:**
|
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||||
- `--train_type`: 训练类型(seq, sft, dpo)
|
|
||||||
- `--data_root_path`: 数据集根目录
|
|
||||||
- `--param_path`: 模型训练参数路径
|
|
||||||
- `--n_epoch`: 总训练轮数
|
|
||||||
- `--batch_size`: 批量大小
|
|
||||||
- `--accumulation_steps`: 每个训练步骤的 batch 数量
|
|
||||||
- `--warmup_steps`: 预热步数(warmup steps)
|
|
||||||
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
|
|
||||||
- `--checkpoint_interval`: 检查点保存间隔
|
|
||||||
- `--checkpoint_dir`: 检查点保存目录
|
|
||||||
- `--resume_dir`: 从指定路径恢复训练
|
|
||||||
|
|
||||||
|
#### Generate Text
|
||||||
|
|
||||||
|
```bash
|
||||||
### 👉 使用指南
|
python scripts/tools/generate.py \
|
||||||
|
--param_path /path/to/model \
|
||||||
**(1). 与模型对话:**
|
--input_json_file /path/to/input.jsonl \
|
||||||
|
--output_json_file /path/to/output.jsonl
|
||||||
打开 `chat.py` 或使用流式/非流式接口:
|
|
||||||
|
|
||||||
**流式输出:**
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
model_dir = "your_model_parameter_dir"
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
history = []
|
|
||||||
|
|
||||||
while True:
|
|
||||||
query = input(">> ")
|
|
||||||
if query == "!exit":
|
|
||||||
break
|
|
||||||
|
|
||||||
response_size = 0
|
|
||||||
for response, history in model.stream_generate(
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
temperature=0.85,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50
|
|
||||||
):
|
|
||||||
print(response[response_size:], end="")
|
|
||||||
response_size = len(response)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**非流式输出:**
|
#### Docker
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
model_dir = "your_model_parameter_dir"
|
Build and run with Docker (recommended for GPU environments):
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
history = []
|
|
||||||
|
|
||||||
while True:
|
```bash
|
||||||
query = input(">> ")
|
# Build image
|
||||||
if query == "!exit":
|
docker build -t astrai:latest .
|
||||||
break
|
|
||||||
|
|
||||||
response = model.generate(
|
# Run with GPU support
|
||||||
query=query,
|
docker run --gpus all -it astrai:latest
|
||||||
history=history,
|
|
||||||
temperature=0.85,
|
# Run with specific GPUs
|
||||||
top_p=0.95,
|
docker run --gpus '"device=0,1"' -it astrai:latest
|
||||||
top_k=50
|
|
||||||
)
|
# Run inference server
|
||||||
print(response)
|
docker run --gpus all -p 8000:8000 astrai:latest \
|
||||||
|
python -m scripts.tools.server --port 8000 --device cuda
|
||||||
|
|
||||||
|
# Run with volume mount for data
|
||||||
|
docker run --gpus all -v /path/to/data:/data -it astrai:latest
|
||||||
|
|
||||||
|
# Docker Compose (GPU, default)
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Docker Compose (CPU only)
|
||||||
|
docker compose --profile cpu up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
**(2). 基于检索的生成(RAG):**
|
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
|
||||||
|
|
||||||
```python
|
#### Start HTTP Server
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
model_dir = "your_model_parameter_dir"
|
Start the inference server with OpenAI and Anthropic-compatible HTTP API:
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
retrieved_content = model.retrieve_generate(
|
```bash
|
||||||
query=query,
|
python -m scripts.tools.server --port 8000 --device cuda
|
||||||
retrieve_top_k=5,
|
|
||||||
temperature=0.6,
|
|
||||||
top_k=30,
|
|
||||||
top_p=0.95
|
|
||||||
)
|
|
||||||
print(retrieved_content)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Make requests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# OpenAI-compatible
|
||||||
|
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 512
|
||||||
|
}'
|
||||||
|
|
||||||
|
# OpenAI-compatible streaming
|
||||||
|
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [{"role": "user", "content": "Tell a story"}],
|
||||||
|
"stream": true,
|
||||||
|
"max_tokens": 500
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Anthropic-compatible
|
||||||
|
curl -X POST http://localhost:8000/v1/messages \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "astrai",
|
||||||
|
"system": "You are a helpful assistant.",
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 512
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Anthropic-compatible streaming with stop sequences
|
||||||
|
curl -X POST http://localhost:8000/v1/messages \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "astrai",
|
||||||
|
"messages": [{"role": "user", "content": "Write a story"}],
|
||||||
|
"max_tokens": 500,
|
||||||
|
"stream": true,
|
||||||
|
"stop_sequences": ["The end"]
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
curl http://localhost:8000/health
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Demo
|
||||||
|
|
||||||
|
Check out the demos in the `scripts/demo/` folder:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Download pre‑processed data (required before running demos)
|
||||||
|
python scripts/demo/download.py
|
||||||
|
|
||||||
|
# Interactive streaming chat
|
||||||
|
python scripts/demo/stream_chat.py
|
||||||
|
|
||||||
|
# Batch generation
|
||||||
|
python scripts/demo/generate_batch.py
|
||||||
|
|
||||||
|
# Auto‑regressive generation
|
||||||
|
python scripts/demo/generate_ar.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
| Document | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
|
||||||
|
| [Architecture](./assets/docs/architecture.md) | System architecture, class diagram & design patterns |
|
||||||
|
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
||||||
|
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
||||||
|
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
||||||
|
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
|
||||||
|
|
||||||
|
### Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.
|
||||||
|
|
||||||
|
1. Fork the repository.
|
||||||
|
2. Create a feature branch.
|
||||||
|
3. Commit your changes.
|
||||||
|
4. Open a Pull Request.
|
||||||
|
|
||||||
|
For major changes, please open an issue first to discuss what you would like to change.
|
||||||
|
|
||||||
|
### Community
|
||||||
|
|
||||||
|
- **GitHub Issues**: [Issue Tracker](https://github.com/ViperEkura/AstrAI/issues)
|
||||||
|
- **Discussions**: [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions)
|
||||||
|
- **HuggingFace**: [Model Hub](https://huggingface.co/ViperEk)
|
||||||
|
|
||||||
|
### License
|
||||||
|
|
||||||
|
This project is licensed under the [GPL-3.0 License](LICENSE).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<em>A lightweight Transformer framework designed for both high performance and ease of use.</em>
|
||||||
|
</div>
|
||||||
|
|
@ -0,0 +1,261 @@
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
<img src="../images/logo.png" width="auto" alt="Logo">
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<a href="../../README.md">English</a> •
|
||||||
|
<a href="#chinese">中文</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p>
|
||||||
|
<strong>轻量级 Transformer 训练与推理框架</strong>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||||
|
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
|
||||||
|
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
|
||||||
|
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
|
||||||
|
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<a href="../../README.md">English</a> •
|
||||||
|
<a href="#chinese">中文</a> •
|
||||||
|
<a href="https://github.com/ViperEkura/AstrAI/issues">问题追踪</a> •
|
||||||
|
<a href="https://github.com/ViperEkura/AstrAI/discussions">讨论区</a> •
|
||||||
|
<a href="https://huggingface.co/ViperEk">HuggingFace</a>
|
||||||
|
</div>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## 📖 目录
|
||||||
|
|
||||||
|
- [特性](#特性)
|
||||||
|
- [快速开始](#快速开始)
|
||||||
|
- [文档](#文档)
|
||||||
|
- [贡献](#贡献)
|
||||||
|
- [社区](#社区)
|
||||||
|
- [许可证](#许可证)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<a id="chinese"></a>
|
||||||
|
## 中文
|
||||||
|
|
||||||
|
### 特性
|
||||||
|
|
||||||
|
- 🚀 **高性能**: 训练与推理双向优化,高效并行。
|
||||||
|
- 🔧 **灵活**: 支持 seq/sft/dpo/grpo 多种训练方式,可定制模型架构。
|
||||||
|
- 💡 **易用**: 简洁的 API 与丰富的示例、演示。
|
||||||
|
- 📦 **轻量**: 依赖少,部署简单。
|
||||||
|
- 🔬 **研究友好**: 模块化设计,便于实验新想法。
|
||||||
|
- 🤗 **HuggingFace 风格 API**: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。
|
||||||
|
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。
|
||||||
|
|
||||||
|
### 快速开始
|
||||||
|
|
||||||
|
#### 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/ViperEkura/AstrAI.git
|
||||||
|
cd AstrAI
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
安装开发依赖:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 下载预训练模型
|
||||||
|
|
||||||
|
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/demo/download.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`。
|
||||||
|
|
||||||
|
#### 训练模型
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
|
nohup python scripts/tools/train.py \
|
||||||
|
--nprocs=4 \
|
||||||
|
--parallel_mode=ddp \
|
||||||
|
--train_type=seq \
|
||||||
|
--data_root_path=/path/to/dataset \
|
||||||
|
--param_path=/path/to/model \
|
||||||
|
--batch_per_device=4 \
|
||||||
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.9 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=0.01 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.05 \
|
||||||
|
> out.log 2> err.log &
|
||||||
|
```
|
||||||
|
|
||||||
|
完整参数列表见[参数说明](./params.md)。
|
||||||
|
|
||||||
|
#### 文本生成
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/tools/generate.py \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--input_json_file /path/to/input.jsonl \
|
||||||
|
--output_json_file /path/to/output.jsonl
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Docker
|
||||||
|
|
||||||
|
使用 Docker 构建和运行(推荐用于 GPU 环境):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 构建镜像
|
||||||
|
docker build -t astrai:latest .
|
||||||
|
|
||||||
|
# 启用 GPU 运行
|
||||||
|
docker run --gpus all -it astrai:latest
|
||||||
|
|
||||||
|
# 指定特定 GPU
|
||||||
|
docker run --gpus '"device=0,1"' -it astrai:latest
|
||||||
|
|
||||||
|
# 运行推理服务
|
||||||
|
docker run --gpus all -p 8000:8000 astrai:latest \
|
||||||
|
python -m scripts.tools.server --port 8000 --device cuda
|
||||||
|
|
||||||
|
# 挂载数据卷
|
||||||
|
docker run --gpus all -v /path/to/data:/data -it astrai:latest
|
||||||
|
|
||||||
|
# Docker Compose(GPU,默认)
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Docker Compose(仅 CPU)
|
||||||
|
docker compose --profile cpu up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
> **注意**: 必须使用 `--gpus all` 才能启用 CUDA 支持,否则 `torch.cuda.is_available()` 将返回 `False`。
|
||||||
|
|
||||||
|
#### 启动 HTTP 服务
|
||||||
|
|
||||||
|
启动推理服务器,支持 OpenAI 和 Anthropic 兼容的 HTTP API:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m scripts.tools.server --port 8000 --device cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
发起请求:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# OpenAI 兼容
|
||||||
|
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [{"role": "user", "content": "你好"}],
|
||||||
|
"max_tokens": 512
|
||||||
|
}'
|
||||||
|
|
||||||
|
# OpenAI 兼容流式
|
||||||
|
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [{"role": "user", "content": "讲个故事"}],
|
||||||
|
"stream": true,
|
||||||
|
"max_tokens": 500
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Anthropic 兼容
|
||||||
|
curl -X POST http://localhost:8000/v1/messages \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "astrai",
|
||||||
|
"system": "你是一个乐于助人的助手。",
|
||||||
|
"messages": [{"role": "user", "content": "你好"}],
|
||||||
|
"max_tokens": 512
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Anthropic 兼容流式并设置停止序列
|
||||||
|
curl -X POST http://localhost:8000/v1/messages \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "astrai",
|
||||||
|
"messages": [{"role": "user", "content": "写个故事"}],
|
||||||
|
"max_tokens": 500,
|
||||||
|
"stream": true,
|
||||||
|
"stop_sequences": ["结束"]
|
||||||
|
}'
|
||||||
|
|
||||||
|
# 健康检查
|
||||||
|
curl http://localhost:8000/health
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 演示
|
||||||
|
|
||||||
|
查看 `scripts/demo/` 文件夹中的演示:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 下载预处理数据(运行演示前必需)
|
||||||
|
python scripts/demo/download.py
|
||||||
|
|
||||||
|
# 交互式流式聊天
|
||||||
|
python scripts/demo/stream_chat.py
|
||||||
|
|
||||||
|
# 批量生成
|
||||||
|
python scripts/demo/generate_batch.py
|
||||||
|
|
||||||
|
# 自回归生成
|
||||||
|
python scripts/demo/generate_ar.py
|
||||||
|
```
|
||||||
|
|
||||||
|
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
|
||||||
|
|
||||||
|
### 文档
|
||||||
|
|
||||||
|
| 文档 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| [参数说明](./params.md) | 训练与推理参数配置 |
|
||||||
|
| [架构文档](./architecture.md) | 系统架构、类图与设计模式 |
|
||||||
|
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
||||||
|
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
||||||
|
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
||||||
|
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
|
||||||
|
|
||||||
|
### 贡献
|
||||||
|
|
||||||
|
我们欢迎贡献!请参阅[贡献指南](../../CONTRIBUTING.md)了解详情。
|
||||||
|
|
||||||
|
1. Fork 本仓库。
|
||||||
|
2. 创建功能分支。
|
||||||
|
3. 提交更改。
|
||||||
|
4. 发起 Pull Request。
|
||||||
|
|
||||||
|
重大更改请先开 issue 讨论。
|
||||||
|
|
||||||
|
### 社区
|
||||||
|
|
||||||
|
- **GitHub Issues**: [问题追踪](https://github.com/ViperEkura/AstrAI/issues)
|
||||||
|
- **Discussions**: [GitHub 讨论区](https://github.com/ViperEkura/AstrAI/discussions)
|
||||||
|
- **HuggingFace**: [模型中心](https://huggingface.co/ViperEk)
|
||||||
|
|
||||||
|
### 许可证
|
||||||
|
|
||||||
|
本项目采用 [GPL-3.0 许可证](../../LICENSE)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<em>专为高性能与易用性设计的轻量级 Transformer 框架。</em>
|
||||||
|
</div>
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,64 @@
|
||||||
|
# Data Flow
|
||||||
|
|
||||||
|
This document describes the data pipeline: from raw text to model input tensors.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
```
|
||||||
|
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
|
||||||
|
```
|
||||||
|
|
||||||
|
## Data Preparation
|
||||||
|
|
||||||
|
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||||
|
|
||||||
|
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||||
|
|
||||||
|
```
|
||||||
|
StoreFactory.create("h5") → H5Store
|
||||||
|
StoreFactory.create("bin") → MmapStore
|
||||||
|
```
|
||||||
|
|
||||||
|
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||||
|
|
||||||
|
## Data Keys by Training Type
|
||||||
|
|
||||||
|
| Type | Storage Keys |
|
||||||
|
|------|-------------|
|
||||||
|
| `seq` | `sequence` (→ input_ids, target_ids via offset-by-1) |
|
||||||
|
| `sft` | `sequence`, `loss_mask` |
|
||||||
|
| `dpo` | `chosen`, `rejected`, `chosen_mask`, `rejected_mask` |
|
||||||
|
| `grpo` | `prompts`, `responses`, `masks`, `rewards` |
|
||||||
|
|
||||||
|
## Dataset Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
|
||||||
|
→ BaseDataset.load(load_path, storage_type=None)
|
||||||
|
→ detect_format(load_path)
|
||||||
|
→ StoreFactory.create(storage_type)
|
||||||
|
→ Store.load(load_path)
|
||||||
|
→ H5Store._normalize() / MmapStore._normalize()
|
||||||
|
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||||
|
→ BaseDataset.__getitem__(idx)
|
||||||
|
→ get_index(idx) → [begin, end)
|
||||||
|
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
|
||||||
|
```
|
||||||
|
|
||||||
|
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
|
||||||
|
|
||||||
|
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
|
||||||
|
|
||||||
|
## Sampler
|
||||||
|
|
||||||
|
`ResumableDistributedSampler` supports checkpoint-aware distributed sampling:
|
||||||
|
|
||||||
|
- Tracks `start_epoch` / `start_iter` for resume
|
||||||
|
- Shuffle via `torch.Generator(seed + epoch)`
|
||||||
|
- Per-replica index slicing for DDP
|
||||||
|
|
||||||
|
## DataLoader
|
||||||
|
|
||||||
|
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||||
|
|
||||||
|
> Document Update Time: 2026-05-30
|
||||||
|
|
@ -0,0 +1,152 @@
|
||||||
|
# Inference
|
||||||
|
|
||||||
|
## KV Cache
|
||||||
|
|
||||||
|
At decode time, only the last query token matters. All previous K/V are cached to avoid recomputation:
|
||||||
|
|
||||||
|
$$
|
||||||
|
o_n = \sum_j \text{softmax}\left(\frac{q_n k_j}{\sqrt{d_k}}\right) v_j
|
||||||
|
$$
|
||||||
|
|
||||||
|
RoPE is applied **before** KV cache write, not after — otherwise position encoding drift occurs.
|
||||||
|
|
||||||
|
## KVCache System
|
||||||
|
|
||||||
|
Six classes (plus two helpers) working together:
|
||||||
|
|
||||||
|
```
|
||||||
|
KVCache (facade)
|
||||||
|
├── PagePool orchestrates page allocation + prefix matching
|
||||||
|
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
|
||||||
|
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
|
||||||
|
├── TaskTable maps task_id → page_table + cached token count
|
||||||
|
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
||||||
|
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
|
||||||
|
```
|
||||||
|
|
||||||
|
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
||||||
|
|
||||||
|
## Continuous Batching
|
||||||
|
|
||||||
|
`InferenceScheduler` runs a daemon thread with a 4-phase loop:
|
||||||
|
|
||||||
|
```
|
||||||
|
1. Cleanup → Remove finished tasks, free KV pages
|
||||||
|
2. Refill → Pop from waiting_queue, task_alloc pages, activate
|
||||||
|
3. Prefill → Group by (prompt_len, start_pos), run full forward
|
||||||
|
4. Decode → Pick largest same-position group, single-token forward
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sampling (Strategy Pattern)
|
||||||
|
|
||||||
|
```
|
||||||
|
BaseSamplingStrategy (ABC)
|
||||||
|
├── TemperatureStrategy
|
||||||
|
├── TopKStrategy
|
||||||
|
├── TopPStrategy
|
||||||
|
└── SamplingPipeline
|
||||||
|
```
|
||||||
|
|
||||||
|
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||||
|
`sample()` is a convenience shortcut for one-shot usage.
|
||||||
|
|
||||||
|
## Protocol Handlers (Strategy Pattern)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ProtocolHandler: # concrete orchestrator
|
||||||
|
def __init__(self, request, engine, builder): ...
|
||||||
|
async def handle(self):
|
||||||
|
prompt, ctx, stops = builder.prepare(request, engine)
|
||||||
|
agen = engine.generate_async(prompt, ...)
|
||||||
|
if stream: self._handle_stream(agen, ctx, stops)
|
||||||
|
else: return await self._handle_non_stream(agen, ctx, stops)
|
||||||
|
```
|
||||||
|
|
||||||
|
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||||
|
|
||||||
|
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
||||||
|
|
||||||
|
Adding a protocol = one builder file, no handler subclassing needed.
|
||||||
|
|
||||||
|
## Engine & GenerateResult
|
||||||
|
|
||||||
|
```
|
||||||
|
InferenceEngine
|
||||||
|
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
||||||
|
├── generate_with_request(req) → same
|
||||||
|
├── generate_async(prompt, ...) → AsyncGenerator
|
||||||
|
├── get_stats() → Dict
|
||||||
|
└── shutdown()
|
||||||
|
```
|
||||||
|
|
||||||
|
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
||||||
|
|
||||||
|
## HTTP API
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /v1/chat/completions OpenAI
|
||||||
|
POST /v1/messages Anthropic
|
||||||
|
GET /health {"status":"ok","model_loaded":true}
|
||||||
|
GET /stats scheduler statistics
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAI
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Response:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-abc123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1717000000,
|
||||||
|
"model": "astrai",
|
||||||
|
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
||||||
|
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
|
||||||
|
|
||||||
|
### Anthropic
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/v1/messages \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
||||||
|
|
||||||
|
### GenerationRequest Parameters
|
||||||
|
|
||||||
|
| Param | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||||
|
| `top_k` | int | 50 | Top-k count |
|
||||||
|
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||||
|
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
|
||||||
|
| `max_tokens` | Optional[int] | None | Max generation length |
|
||||||
|
| `stream` | bool | False | Stream output |
|
||||||
|
|
||||||
|
## Engine API
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Non-streaming
|
||||||
|
engine.generate("Hello", stream=False) # -> str
|
||||||
|
engine.generate(["A", "B"], stream=False) # -> List[str]
|
||||||
|
|
||||||
|
# Streaming
|
||||||
|
engine.generate("Hello", stream=True) # -> Generator[str]
|
||||||
|
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||||
|
|
||||||
|
# Async
|
||||||
|
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
|
||||||
|
print(token)
|
||||||
|
```
|
||||||
|
|
||||||
|
> Document Update Time: 2026-05-30
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
## 模型介绍
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### 1. 模型搭建
|
|
||||||
|
|
||||||
本模型采用Transformer架构, 使用GQA(q_head=24, kv_head=4) 机制,相较于传统的MHA可以节省KV cache 的显存占用(但是目前没有做KV cache),通过堆叠24层Transformer实现模型的搭建, 参数量为1.0b。Transformer 是自回归模型, 是通过计算前面所有的token的关系得到下一个token的概率分布
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
什么是自回归模型呢, 在把句子拆分成token之后, 模型会预测下一个token的概率分布。这意味着模型会根据给定的上下文(即已经出现的tokens序列),计算出下一个可能的token及其对应的概率。
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#### 1. 自回归
|
|
||||||
|
|
||||||
假设我们有一个句子被拆分成如下tokens列表:
|
|
||||||
|
|
||||||
```
|
|
||||||
["你好", "," "今天", "天气"]
|
|
||||||
```
|
|
||||||
|
|
||||||
接下来,模型会基于这个序列预测下一个可能出现的token。这通常以概率分布的形式给出,比如:
|
|
||||||
|
|
||||||
```
|
|
||||||
-> {"token": "不错", "probability": 0.4}
|
|
||||||
-> {"token": "晴朗", "probability": 0.2}
|
|
||||||
-> ......
|
|
||||||
```
|
|
||||||
|
|
||||||
这里,“不错”和“晴朗”是两个可能跟随在“天气”之后的tokens,并且给出了每个token成为下一个token的可能性大小。
|
|
||||||
|
|
||||||
之后,我们通过采样(通过top_k, top_p, temperature参数调整采样后的结果)得到下一个token并且将下一个token加入序列作为输入
|
|
||||||
|
|
||||||
```
|
|
||||||
["你好", "," "今天", "天气", "不错"]
|
|
||||||
```
|
|
||||||
|
|
||||||
之后都是在重复这个流程, 直到遇到控制流程结束的token(<|end_of_seqence|>)模型停止处理(一般模型都会设置控制token, 不然模型会一直输出到显存爆炸)。
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#### 2. 因果掩码
|
|
||||||
|
|
||||||
transformer 中采用注意力机制,输入的形状一般为[bsz, seq_len], 输出为[bsz, seq_len,n_dim], 为了实现预测下一个token, 模型的输入和输出必须错开来一个位置。模型预测的target必须错开一个位置, 在训练的时候我们也采用错开一个位置的方法
|
|
||||||
|
|
||||||
```
|
|
||||||
sequence : [[1, 2, 3, 4, 5, 6]]
|
|
||||||
input_ids: [[1, 2, 3, 4, 5]]
|
|
||||||
target_ids: [[2, 3, 4, 5, 6]]
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
注意力得分计算的公式为
|
|
||||||
|
|
||||||
|
|
||||||
$$ s_{ij} = softmax(\frac{q_i^Tk_j}{\sqrt{d_k}}) $$
|
|
||||||
$$ s_{ij} := s_{ij} + mask_{ij} $$
|
|
||||||
|
|
||||||
|
|
||||||
其中注意力得分代表了模型对两个token之间相似程度的关注程度
|
|
||||||
|
|
||||||
对于decoder only结构的模型, 为了防止模型从未来的位置偷到信息, 在注意力的计算过程中需要增加掩码,我们需要在注意力得分计算之前应用一个掩码。这个掩码通常是一个下三角矩阵,对于长度为n的序列,它的形状是[n, n]。下面以一个长度为5的序列为例,展示如何创建这样的因果掩码矩阵:
|
|
||||||
|
|
||||||
```
|
|
||||||
[[0, -inf, -inf, -inf, -inf],
|
|
||||||
[0, 0, -inf, -inf, -inf],
|
|
||||||
[0, 0, 0, -inf, -inf],
|
|
||||||
[0, 0, 0, 0, -inf],
|
|
||||||
[0, 0, 0, 0, 0]]
|
|
||||||
```
|
|
||||||
|
|
||||||
在这个矩阵中,0表示可以注意到的位置,而-inf表示应该被掩盖(即不应注意到)的位置。因为这个句子保证了注意力得分中 $j > i$ 的部分通过softmax 之后由`inf` 变成0, 也就是模型不能看到未来的信息
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#### 3. 旋转位置编码
|
|
||||||
|
|
||||||
旋转位置编码(Rotary Position Embedding, RoPE)是一种为了解决Transformer模型中缺乏对序列位置信息直接建模的问题而设计的位置编码方法。与传统的位置编码(如正弦和余弦函数的位置编码)不同,RoPE通过将位置信息直接嵌入到查询(Query, Q)和键(Key, K)向量中来实现,使得模型能够更自然地处理序列中的相对位置关系。
|
|
||||||
|
|
||||||
|
|
||||||
$$ q_i = R_i W_q x_i $$
|
|
||||||
$$ k_j = R_j W_k x_j $$
|
|
||||||
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
|
|
||||||
|
|
||||||
其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
## kv_cache 实现
|
|
||||||
|
|
||||||
根据注意力的计算公式
|
|
||||||
|
|
||||||
$$
|
|
||||||
\begin{align*}
|
|
||||||
o_i &= \sum_j s_{ij} v_{j} \newline
|
|
||||||
s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
|
|
||||||
\end{align*}
|
|
||||||
$$
|
|
||||||
|
|
||||||
由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $
|
|
||||||
|
|
||||||
$$
|
|
||||||
\begin{align*}
|
|
||||||
o_n &= \sum_j s_{j}v_{j} \newline
|
|
||||||
s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
|
|
||||||
\end{align*}
|
|
||||||
$$
|
|
||||||
|
|
||||||
如果我们把式子展开
|
|
||||||
|
|
||||||
$$
|
|
||||||
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
|
|
||||||
$$
|
|
||||||
|
|
||||||
以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误
|
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
# Parameter Documentation
|
||||||
|
|
||||||
|
## Training Parameters
|
||||||
|
|
||||||
|
### Basic Parameters
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
||||||
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
|
| `--param_path` | Model parameters or checkpoint path | required |
|
||||||
|
| `--n_epoch` | Total training epochs | 1 |
|
||||||
|
| `--batch_per_device` | Batch size per device | 1 |
|
||||||
|
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||||
|
|
||||||
|
### Learning Rate Scheduling
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
|
||||||
|
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||||
|
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||||
|
|
||||||
|
### Optimizer (AdamW)
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||||
|
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||||
|
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||||
|
|
||||||
|
### Data Loading
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--window_size` | Max input sequence length | model config `max_len` |
|
||||||
|
| `--stride` | Stride for sliding window over sequences | None |
|
||||||
|
| `--random_seed` | Random seed for reproducibility | 3407 |
|
||||||
|
| `--num_workers` | DataLoader worker processes | 4 |
|
||||||
|
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
|
||||||
|
|
||||||
|
### Checkpoint & Resume
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
|
||||||
|
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
||||||
|
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
|
||||||
|
| `--start_batch` | Resume from batch iteration | 0 |
|
||||||
|
|
||||||
|
### Distributed Training
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||||
|
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
||||||
|
| `--device_type` | Device type | cuda |
|
||||||
|
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
||||||
|
|
||||||
|
### Strategy-specific
|
||||||
|
|
||||||
|
| Parameter | Description | Default | Used by |
|
||||||
|
|-----------|-------------|---------|---------|
|
||||||
|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||||
|
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
|
||||||
|
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||||
|
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||||
|
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||||
|
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
|
||||||
|
|
||||||
|
### Usage Example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
|
nohup python scripts/tools/train.py \
|
||||||
|
--nprocs=4 \
|
||||||
|
--parallel_mode=ddp \
|
||||||
|
--train_type=seq \
|
||||||
|
--data_root_path=/path/to/dataset \
|
||||||
|
--param_path=/path/to/model \
|
||||||
|
--batch_per_device=4 \
|
||||||
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.9 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=0.01 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.05 \
|
||||||
|
> out.log 2> err.log &
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
> Document Update Time: 2026-05-24
|
||||||
|
|
@ -0,0 +1,283 @@
|
||||||
|
# Preprocessing Pipeline
|
||||||
|
|
||||||
|
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
|
||||||
|
|
||||||
|
## Philosophy
|
||||||
|
|
||||||
|
| Component | Responsibility |
|
||||||
|
|-----------|---------------|
|
||||||
|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
||||||
|
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
||||||
|
|
||||||
|
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### SFT Chat
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"input": {
|
||||||
|
"type": "chat",
|
||||||
|
"messages_key": "messages"
|
||||||
|
},
|
||||||
|
"mask": {
|
||||||
|
"system": "mask",
|
||||||
|
"user": "mask",
|
||||||
|
"assistant": "train"
|
||||||
|
},
|
||||||
|
"mask_default": "mask",
|
||||||
|
"preprocessing": {
|
||||||
|
"max_seq_len": 2048,
|
||||||
|
"deduplicate": true
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"domain_key": "source",
|
||||||
|
"storage_format": "bin",
|
||||||
|
"max_tokens_per_shard": 100000000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
|
||||||
|
|
||||||
|
### Instruction Tuning
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"input": {
|
||||||
|
"type": "instruction",
|
||||||
|
"prompt_key": "instruction",
|
||||||
|
"response_key": "output"
|
||||||
|
},
|
||||||
|
"mask": {
|
||||||
|
"prompt": "mask",
|
||||||
|
"response": "train"
|
||||||
|
},
|
||||||
|
"mask_default": "mask",
|
||||||
|
"preprocessing": {
|
||||||
|
"max_seq_len": 2048
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"storage_format": "bin"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Mask splits at the prompt/response field boundary.
|
||||||
|
|
||||||
|
### Pretraining
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"input": {
|
||||||
|
"type": "text",
|
||||||
|
"text_key": "content"
|
||||||
|
},
|
||||||
|
"mask": {},
|
||||||
|
"preprocessing": {
|
||||||
|
"max_seq_len": 2048,
|
||||||
|
"min_chars": 50
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"storage_format": "bin"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
No mask -- train on all tokens.
|
||||||
|
|
||||||
|
### Run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Reference
|
||||||
|
|
||||||
|
### `input`
|
||||||
|
|
||||||
|
| Field | Type | Required | Default | Description |
|
||||||
|
|-------|------|----------|---------|-------------|
|
||||||
|
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
|
||||||
|
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
|
||||||
|
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
|
||||||
|
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
|
||||||
|
| `text_key` | string | no | `"text"` | JSON key for text field |
|
||||||
|
|
||||||
|
### `mask`
|
||||||
|
|
||||||
|
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
|
||||||
|
|
||||||
|
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
|
||||||
|
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
|
||||||
|
|
||||||
|
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
|
||||||
|
For instruction mode, keys are `"prompt"` and `"response"`.
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `mask` | dict | `{}` | Role/field to action mapping |
|
||||||
|
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
|
||||||
|
|
||||||
|
### `preprocessing`
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
|
||||||
|
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
|
||||||
|
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
|
||||||
|
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
|
||||||
|
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
|
||||||
|
|
||||||
|
### `output`
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
|
||||||
|
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
|
||||||
|
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
|
||||||
|
|
||||||
|
## Mask Algorithm
|
||||||
|
|
||||||
|
### Chat Mode (role-span tracking)
|
||||||
|
|
||||||
|
For each message in the `messages` array:
|
||||||
|
|
||||||
|
1. Prepend BOS token (position 0, always masked)
|
||||||
|
2. Render through the chat template for that single message
|
||||||
|
3. Encode the rendered text, record token span `(start, end, role)`
|
||||||
|
4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries
|
||||||
|
5. Fill `loss_mask` from the mask rules
|
||||||
|
|
||||||
|
**Multi-turn example**:
|
||||||
|
|
||||||
|
```
|
||||||
|
Data:
|
||||||
|
[system: "You are helpful."]
|
||||||
|
[user: "What is 2+2?"]
|
||||||
|
[assistant: "4"]
|
||||||
|
[user: "What is 3+3?"]
|
||||||
|
[assistant: "6"]
|
||||||
|
|
||||||
|
Config:
|
||||||
|
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
|
||||||
|
|
||||||
|
Result:
|
||||||
|
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
|
||||||
|
mask: 0 0 0 1 0 1
|
||||||
|
```
|
||||||
|
|
||||||
|
Both assistant turns are trained. All system and user tokens are masked.
|
||||||
|
|
||||||
|
### Instruction Mode (field boundary)
|
||||||
|
|
||||||
|
Encode the prompt and response fields independently, then split the mask at the field boundary.
|
||||||
|
|
||||||
|
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
|
||||||
|
- `"prompt": "train", "response": "mask"` -- the reverse
|
||||||
|
|
||||||
|
### Text Mode (no mask)
|
||||||
|
|
||||||
|
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
|
||||||
|
|
||||||
|
## Output Layout
|
||||||
|
|
||||||
|
### Single-Shard (`bin`)
|
||||||
|
|
||||||
|
```
|
||||||
|
output_dir/
|
||||||
|
__default__/ # when domain_key is null
|
||||||
|
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
|
||||||
|
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
|
||||||
|
loss_mask.bin # int64 raw bytes
|
||||||
|
wiki/ # when domain_key="source" and item["source"]="wiki"
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-Shard (`bin`)
|
||||||
|
|
||||||
|
When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories:
|
||||||
|
|
||||||
|
```
|
||||||
|
output_dir/
|
||||||
|
__default__/
|
||||||
|
shard_0000/
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
shard_0001/
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
`MmapStore` automatically discovers and merges all shards under the domain directory.
|
||||||
|
|
||||||
|
### H5 Output
|
||||||
|
|
||||||
|
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
|
||||||
|
|
||||||
|
```
|
||||||
|
output_dir/
|
||||||
|
__default__/
|
||||||
|
data_0000.h5 # each H5 contains key→dataset groups
|
||||||
|
data_0001.h5
|
||||||
|
wiki/
|
||||||
|
data_0000.h5
|
||||||
|
```
|
||||||
|
|
||||||
|
## Python API Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from astrai.preprocessing.pipeline import Pipeline
|
||||||
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
|
|
||||||
|
config = PipelineConfig.from_json("sft_pipeline.json")
|
||||||
|
Pipeline(
|
||||||
|
config,
|
||||||
|
["data_part1.jsonl", "data_part2.jsonl"],
|
||||||
|
output_dir="output/",
|
||||||
|
tokenizer_path="params"
|
||||||
|
).run()
|
||||||
|
```
|
||||||
|
|
||||||
|
Or from the CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Extension
|
||||||
|
|
||||||
|
Register a custom builder for new formats:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
|
||||||
|
|
||||||
|
@MaskBuilderFactory.register("my_format")
|
||||||
|
class MyFormatBuilder(BaseMaskBuilder):
|
||||||
|
def build(self, item: dict, config, tokenizer) -> dict | None:
|
||||||
|
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
|
||||||
|
# Return None to skip this item
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Then set `"input": {"type": "my_format"}` in your config.
|
||||||
|
|
||||||
|
## Compared to Old Pipeline
|
||||||
|
|
||||||
|
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|
||||||
|
|---|---|
|
||||||
|
| Configured via constructor arguments | Configured via JSON file |
|
||||||
|
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
|
||||||
|
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
|
||||||
|
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
|
||||||
|
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
|
||||||
|
|
||||||
|
> Document Update Time: 2026-05-30
|
||||||
|
|
@ -0,0 +1,201 @@
|
||||||
|
# Training
|
||||||
|
|
||||||
|
### Autoregression
|
||||||
|
|
||||||
|
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
||||||
|
|
||||||
|
### Causal Mask
|
||||||
|
|
||||||
|
```
|
||||||
|
sequence : [[1, 2, 3, 4, 5, 6]]
|
||||||
|
input_ids: [[1, 2, 3, 4, 5]]
|
||||||
|
target_ids: [[2, 3, 4, 5, 6]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Lower-triangular mask prevents attending to future positions:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[0, -inf, -inf, -inf, -inf],
|
||||||
|
[0, 0, -inf, -inf, -inf],
|
||||||
|
[0, 0, 0, -inf, -inf],
|
||||||
|
[0, 0, 0, 0, -inf],
|
||||||
|
[0, 0, 0, 0, 0]]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rotary Position Embedding (RoPE)
|
||||||
|
|
||||||
|
RoPE embeds position into Q/K vectors via complex rotation:
|
||||||
|
|
||||||
|
$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_{i-j} W_k x_j $$
|
||||||
|
|
||||||
|
The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` multiplies Q/K as complex numbers.
|
||||||
|
|
||||||
|
## Training Loop
|
||||||
|
|
||||||
|
Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches.
|
||||||
|
|
||||||
|
```
|
||||||
|
on_train_begin
|
||||||
|
model.train()
|
||||||
|
on_epoch_begin
|
||||||
|
for batch in dataloader:
|
||||||
|
on_batch_begin
|
||||||
|
with executor.accumulate(model):
|
||||||
|
loss = strategy.compute_loss(batch)
|
||||||
|
context.loss = loss.item()
|
||||||
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
|
executor.backward(stand_loss)
|
||||||
|
context.iteration += 1
|
||||||
|
on_batch_end
|
||||||
|
|
||||||
|
if executor.sync_gradients:
|
||||||
|
on_optimizer_step
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
if scheduler:
|
||||||
|
scheduler.step()
|
||||||
|
on_epoch_end
|
||||||
|
on_train_end
|
||||||
|
```
|
||||||
|
|
||||||
|
### Callback Lifecycle
|
||||||
|
|
||||||
|
| Hook | Fires | Default callback |
|
||||||
|
|------|-------|-----------------|
|
||||||
|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||||
|
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
|
||||||
|
| `on_batch_begin` | Every batch | — |
|
||||||
|
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
||||||
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
|
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
|
||||||
|
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
|
||||||
|
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
|
||||||
|
|
||||||
|
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||||
|
|
||||||
|
## Strategies
|
||||||
|
|
||||||
|
### SEQ (Pre-training)
|
||||||
|
|
||||||
|
Next-token cross-entropy with optional label smoothing:
|
||||||
|
|
||||||
|
$$
|
||||||
|
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
||||||
|
$$
|
||||||
|
|
||||||
|
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
|
||||||
|
|
||||||
|
### SFT (Supervised Fine-Tuning)
|
||||||
|
|
||||||
|
Masked cross-entropy (`ignore_index=-100`) over response tokens:
|
||||||
|
|
||||||
|
$$
|
||||||
|
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
||||||
|
$$
|
||||||
|
|
||||||
|
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
|
||||||
|
|
||||||
|
### DPO (Direct Preference Optimization)
|
||||||
|
|
||||||
|
Frozen reference model, preference margin via log-ratio:
|
||||||
|
|
||||||
|
$$
|
||||||
|
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
||||||
|
$$
|
||||||
|
|
||||||
|
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
||||||
|
|
||||||
|
### GRPO (Group Relative Policy Optimization)
|
||||||
|
|
||||||
|
On-policy PPO with group-normalized advantages:
|
||||||
|
|
||||||
|
$$
|
||||||
|
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
|
||||||
|
$$
|
||||||
|
|
||||||
|
$$
|
||||||
|
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
||||||
|
$$
|
||||||
|
|
||||||
|
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
|
||||||
|
|
||||||
|
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
|
|
||||||
|
## LR Schedulers
|
||||||
|
|
||||||
|
| Type | Class | Description |
|
||||||
|
|------|-------|-------------|
|
||||||
|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
||||||
|
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
||||||
|
|
||||||
|
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
||||||
|
|
||||||
|
## Gradient Checkpointing
|
||||||
|
|
||||||
|
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
|
||||||
|
```
|
||||||
|
|
||||||
|
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
|
||||||
|
|
||||||
|
## Checkpoint
|
||||||
|
|
||||||
|
```
|
||||||
|
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
||||||
|
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
||||||
|
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
|
||||||
|
```
|
||||||
|
|
||||||
|
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||||
|
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
|
||||||
|
|
||||||
|
## TrainContextBuilder (Builder Pattern)
|
||||||
|
|
||||||
|
```python
|
||||||
|
context = (
|
||||||
|
TrainContextBuilder(config)
|
||||||
|
.with_resume_dir(resume_dir)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
||||||
|
```
|
||||||
|
|
||||||
|
- Loads checkpoint weights if provided
|
||||||
|
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
||||||
|
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
||||||
|
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||||
|
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
|
||||||
|
|
||||||
|
## Training CLI
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
|
nohup python scripts/tools/train.py \
|
||||||
|
--nprocs=4 \
|
||||||
|
--parallel_mode=ddp \
|
||||||
|
--train_type=seq \
|
||||||
|
--data_root_path=/path/to/dataset \
|
||||||
|
--param_path=/path/to/model \
|
||||||
|
--batch_per_device=4 \
|
||||||
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.9 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=0.01 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.05 \
|
||||||
|
> out.log 2> err.log &
|
||||||
|
```
|
||||||
|
|
||||||
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
|
> Document Update Time: 2026-05-30
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 281 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 21 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 11 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 590 KiB |
|
|
@ -0,0 +1,34 @@
|
||||||
|
__version__ = "1.3.7"
|
||||||
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
|
from astrai.config import (
|
||||||
|
AutoRegressiveLMConfig,
|
||||||
|
EncoderConfig,
|
||||||
|
TrainConfig,
|
||||||
|
)
|
||||||
|
from astrai.dataset import DatasetFactory
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.inference import (
|
||||||
|
GenerationRequest,
|
||||||
|
InferenceEngine,
|
||||||
|
)
|
||||||
|
from astrai.model import AutoModel, AutoRegressiveLM
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AutoRegressiveLM",
|
||||||
|
"AutoRegressiveLMConfig",
|
||||||
|
"EncoderConfig",
|
||||||
|
"TrainConfig",
|
||||||
|
"DatasetFactory",
|
||||||
|
"AutoTokenizer",
|
||||||
|
"GenerationRequest",
|
||||||
|
"InferenceEngine",
|
||||||
|
"Trainer",
|
||||||
|
"CallbackFactory",
|
||||||
|
"StrategyFactory",
|
||||||
|
"SchedulerFactory",
|
||||||
|
"BaseFactory",
|
||||||
|
"AutoModel",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
from astrai.config.model_config import (
|
||||||
|
AutoRegressiveLMConfig,
|
||||||
|
BaseModelConfig,
|
||||||
|
ConfigFactory,
|
||||||
|
EncoderConfig,
|
||||||
|
)
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
OutputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
ProcessingConfig,
|
||||||
|
)
|
||||||
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseModelConfig",
|
||||||
|
"AutoRegressiveLMConfig",
|
||||||
|
"EncoderConfig",
|
||||||
|
"ConfigFactory",
|
||||||
|
"TrainConfig",
|
||||||
|
"InputConfig",
|
||||||
|
"OutputConfig",
|
||||||
|
"PipelineConfig",
|
||||||
|
"ProcessingConfig",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
import json
|
||||||
|
from dataclasses import MISSING, dataclass, fields
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Self, Union, get_type_hints
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseConfig:
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
d = {}
|
||||||
|
for fld in fields(self):
|
||||||
|
v = getattr(self, fld.name)
|
||||||
|
if isinstance(v, (str, int, float, bool)):
|
||||||
|
d[fld.name] = v
|
||||||
|
elif v is None:
|
||||||
|
d[fld.name] = None
|
||||||
|
elif isinstance(v, (dict, list, tuple)):
|
||||||
|
try:
|
||||||
|
val = list(v) if isinstance(v, tuple) else v
|
||||||
|
json.dumps(val)
|
||||||
|
d[fld.name] = val
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
elif isinstance(v, BaseConfig):
|
||||||
|
d[fld.name] = v.to_dict()
|
||||||
|
elif hasattr(v, "__dataclass_fields__"):
|
||||||
|
sub = {}
|
||||||
|
for f in fields(v):
|
||||||
|
a = getattr(v, f.name)
|
||||||
|
sub[f.name] = list(a) if isinstance(a, tuple) else a
|
||||||
|
d[fld.name] = sub
|
||||||
|
return d
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: Dict[str, Any]) -> Self:
|
||||||
|
hints = get_type_hints(cls)
|
||||||
|
inst = cls.__new__(cls)
|
||||||
|
for fld in fields(cls):
|
||||||
|
if fld.name in d:
|
||||||
|
v = d[fld.name]
|
||||||
|
target = cls._unwrap_optional(hints.get(fld.name))
|
||||||
|
if target is not None:
|
||||||
|
try:
|
||||||
|
v = cls._coerce(v, target)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
object.__setattr__(inst, fld.name, v)
|
||||||
|
elif fld.default is not MISSING:
|
||||||
|
object.__setattr__(inst, fld.name, fld.default)
|
||||||
|
elif fld.default_factory is not MISSING:
|
||||||
|
object.__setattr__(inst, fld.name, fld.default_factory())
|
||||||
|
else:
|
||||||
|
object.__setattr__(inst, fld.name, None)
|
||||||
|
return inst
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unwrap_optional(tp) -> Optional[type]:
|
||||||
|
if tp is None:
|
||||||
|
return None
|
||||||
|
origin = getattr(tp, "__origin__", None)
|
||||||
|
if origin is not None:
|
||||||
|
args = getattr(tp, "__args__", ())
|
||||||
|
non_none = [a for a in args if a is not type(None)]
|
||||||
|
return non_none[0] if non_none else None
|
||||||
|
return tp
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce(value: Any, target_type: type) -> Any:
|
||||||
|
if target_type is bool and isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if (
|
||||||
|
target_type is int
|
||||||
|
and isinstance(value, (int, float))
|
||||||
|
and not isinstance(value, bool)
|
||||||
|
):
|
||||||
|
return int(value)
|
||||||
|
if (
|
||||||
|
target_type is float
|
||||||
|
and isinstance(value, (int, float))
|
||||||
|
and not isinstance(value, bool)
|
||||||
|
):
|
||||||
|
return float(value)
|
||||||
|
if target_type is str and isinstance(value, str):
|
||||||
|
return value
|
||||||
|
if isinstance(value, target_type):
|
||||||
|
return value
|
||||||
|
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
|
||||||
|
return target_type.from_dict(value)
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, path: Union[str, Path]) -> Self:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
return cls.from_dict(json.load(f))
|
||||||
|
|
||||||
|
def to_json(self, path: Union[str, Path]):
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Optional, Self
|
||||||
|
|
||||||
|
from astrai.config.base import BaseConfig
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigFactory(BaseFactory[BaseConfig]):
|
||||||
|
"""Factory that dispatches config classes by ``model_type``."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
|
||||||
|
model_type = raw.get("model_type") or "autoregressive_lm"
|
||||||
|
config_cls = cls.get_component_class(model_type)
|
||||||
|
return config_cls.from_dict(raw)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseModelConfig(BaseConfig):
|
||||||
|
"""Base config with ``model_type`` dispatch and file I/O."""
|
||||||
|
|
||||||
|
model_type: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(cls, config_path: str) -> Self:
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
raw: Dict[str, Any] = json.load(f)
|
||||||
|
return cls.from_dict(raw)
|
||||||
|
|
||||||
|
def to_file(self, config_path: str):
|
||||||
|
d = self.to_dict()
|
||||||
|
config_dict = {k: v for k, v in d.items() if v is not None}
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ConfigFactory.register("autoregressive_lm")
|
||||||
|
class AutoRegressiveLMConfig(BaseModelConfig):
|
||||||
|
"""Configuration for autoregressive language model."""
|
||||||
|
|
||||||
|
vocab_size: Optional[int] = None
|
||||||
|
dim: Optional[int] = None
|
||||||
|
n_layers: Optional[int] = None
|
||||||
|
norm_eps: Optional[float] = None
|
||||||
|
dim_ffn: Optional[int] = None
|
||||||
|
tie_weight: Optional[bool] = None
|
||||||
|
|
||||||
|
max_len: Optional[int] = None
|
||||||
|
rope_theta: Optional[float] = None
|
||||||
|
rope_scaling: Optional[dict] = None
|
||||||
|
|
||||||
|
attn_type: str = "gqa"
|
||||||
|
n_heads: Optional[int] = None
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
use_qk_norm: Optional[bool] = None
|
||||||
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
kv_lora_rank: Optional[int] = None
|
||||||
|
qk_nope_head_dim: Optional[int] = None
|
||||||
|
qk_rope_head_dim: Optional[int] = None
|
||||||
|
|
||||||
|
ffn_type: str = "mlp"
|
||||||
|
n_routed_experts: Optional[int] = None
|
||||||
|
n_shared_experts: Optional[int] = None
|
||||||
|
n_activated_experts: Optional[int] = None
|
||||||
|
topk_method: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ConfigFactory.register("embedding")
|
||||||
|
class EncoderConfig(BaseModelConfig):
|
||||||
|
"""Configuration for embedding encoder model."""
|
||||||
|
|
||||||
|
vocab_size: Optional[int] = None
|
||||||
|
dim: Optional[int] = None
|
||||||
|
n_layers: Optional[int] = None
|
||||||
|
norm_eps: Optional[float] = None
|
||||||
|
dim_ffn: Optional[int] = None
|
||||||
|
|
||||||
|
max_len: Optional[int] = None
|
||||||
|
rope_theta: Optional[float] = None
|
||||||
|
rope_scaling: Optional[dict] = None
|
||||||
|
|
||||||
|
n_heads: Optional[int] = None
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
use_qk_norm: Optional[bool] = None
|
||||||
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
pooling_type: Optional[str] = None
|
||||||
|
normalize_embeddings: Optional[bool] = None
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
"""Pipeline configuration for JSONL preprocessing."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from astrai.config.base import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InputConfig(BaseConfig):
|
||||||
|
sections: Optional[List[Dict]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessingConfig(BaseConfig):
|
||||||
|
max_seq_len: int = 2048
|
||||||
|
min_chars: int = 50
|
||||||
|
max_chars: int = 2_000_000
|
||||||
|
max_items: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OutputConfig(BaseConfig):
|
||||||
|
domain_key: Optional[str] = None
|
||||||
|
storage_format: str = "bin"
|
||||||
|
max_tokens_per_shard: int = 100_000_000
|
||||||
|
dtype: Dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineConfig(BaseConfig):
|
||||||
|
version: int = 1
|
||||||
|
input: InputConfig = field(default_factory=InputConfig)
|
||||||
|
mask: Dict[str, str] = field(default_factory=dict)
|
||||||
|
mask_default: str = "mask"
|
||||||
|
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
||||||
|
output: OutputConfig = field(default_factory=OutputConfig)
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
from dataclasses import dataclass, field, fields
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from astrai.config.base import BaseConfig
|
||||||
|
from astrai.model.components.lora import LoRAConfig
|
||||||
|
|
||||||
|
|
||||||
|
def required(**kw):
|
||||||
|
return {"required": True, **kw}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainConfig(BaseConfig):
|
||||||
|
# basic setting
|
||||||
|
model_fn: Callable[[], nn.Module] = field(
|
||||||
|
default=None, metadata=required(help="Model factory for training.")
|
||||||
|
)
|
||||||
|
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||||
|
dataset: Dataset = field(
|
||||||
|
default=None, metadata=required(help="Dataset for training.")
|
||||||
|
)
|
||||||
|
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||||
|
default=None, metadata=required(help="Optimizer factory for training.")
|
||||||
|
)
|
||||||
|
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||||
|
default=None, metadata=required(help="Scheduler factory for training.")
|
||||||
|
)
|
||||||
|
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||||
|
batch_per_device: int = field(
|
||||||
|
default=4, metadata={"help": "Batch size per device."}
|
||||||
|
)
|
||||||
|
grad_accum_steps: int = field(
|
||||||
|
default=1, metadata={"help": "Number of iterations between steps."}
|
||||||
|
)
|
||||||
|
max_grad_norm: float = field(
|
||||||
|
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||||
|
)
|
||||||
|
gradient_checkpointing_modules: list = field(
|
||||||
|
default_factory=list,
|
||||||
|
metadata={"help": "Module types to enable activation checkpointing for."},
|
||||||
|
)
|
||||||
|
|
||||||
|
# checkpoint setting
|
||||||
|
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||||
|
start_batch: int = field(
|
||||||
|
default=0, metadata={"help": "Start batch iteration for training."}
|
||||||
|
)
|
||||||
|
ckpt_dir: str = field(
|
||||||
|
default="./checkpoint", metadata={"help": "Checkpoint directory."}
|
||||||
|
)
|
||||||
|
ckpt_interval: int = field(
|
||||||
|
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||||
|
)
|
||||||
|
|
||||||
|
# lora setting
|
||||||
|
lora: Optional[LoRAConfig] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "LoRA config. None means full fine-tuning."},
|
||||||
|
)
|
||||||
|
|
||||||
|
# metric setting
|
||||||
|
log_dir: str = field(
|
||||||
|
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
||||||
|
)
|
||||||
|
log_interval: int = field(
|
||||||
|
default=100,
|
||||||
|
metadata={"help": "Number of batch iterations between metric logs."},
|
||||||
|
)
|
||||||
|
metrics: List[str] = field(
|
||||||
|
default_factory=lambda: ["loss", "lr"],
|
||||||
|
metadata={"help": "Metrics to record during training."},
|
||||||
|
)
|
||||||
|
|
||||||
|
# dataloader setting
|
||||||
|
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
||||||
|
num_workers: int = field(
|
||||||
|
default=0, metadata={"help": "Number of workers for dataloader."}
|
||||||
|
)
|
||||||
|
prefetch_factor: Optional[int] = field(
|
||||||
|
default=None, metadata={"help": "Prefetch factor for dataloader."}
|
||||||
|
)
|
||||||
|
pin_memory: bool = field(
|
||||||
|
default=False, metadata={"help": "Pin memory for dataloader."}
|
||||||
|
)
|
||||||
|
|
||||||
|
# distributed training
|
||||||
|
nprocs: int = field(
|
||||||
|
default=1, metadata={"help": "Number of processes for distributed training."}
|
||||||
|
)
|
||||||
|
backend: str = field(
|
||||||
|
default="nccl", metadata={"help": "Distributed training backend."}
|
||||||
|
)
|
||||||
|
master_addr: str = field(
|
||||||
|
default="localhost",
|
||||||
|
metadata={"help": "Master address for distributed training."},
|
||||||
|
)
|
||||||
|
master_port: str = field(
|
||||||
|
default="29500", metadata={"help": "Master port for distributed training."}
|
||||||
|
)
|
||||||
|
parallel_mode: str = field(
|
||||||
|
default="none",
|
||||||
|
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
|
||||||
|
)
|
||||||
|
start_method: str = field(
|
||||||
|
default="spawn",
|
||||||
|
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
|
||||||
|
)
|
||||||
|
|
||||||
|
# others
|
||||||
|
device_type: str = field(
|
||||||
|
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||||
|
)
|
||||||
|
val_dataset: Optional[Dataset] = field(
|
||||||
|
default=None, metadata={"help": "Dataset for validation."}
|
||||||
|
)
|
||||||
|
val_step: int = field(
|
||||||
|
default=1000,
|
||||||
|
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||||
|
)
|
||||||
|
|
||||||
|
executor_kwargs: dict = field(
|
||||||
|
default_factory=dict,
|
||||||
|
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
||||||
|
)
|
||||||
|
extra_kwargs: dict = field(
|
||||||
|
default_factory=dict, metadata={"help": "Other arguments."}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.validate()
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
for fld in fields(self):
|
||||||
|
if fld.metadata.get("required") and getattr(self, fld.name) is None:
|
||||||
|
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
from astrai.dataset.dataset import (
|
||||||
|
BaseDataset,
|
||||||
|
DatasetFactory,
|
||||||
|
)
|
||||||
|
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||||
|
from astrai.dataset.storage import (
|
||||||
|
H5Store,
|
||||||
|
MmapStore,
|
||||||
|
Store,
|
||||||
|
StoreFactory,
|
||||||
|
detect_format,
|
||||||
|
load_bin,
|
||||||
|
load_h5,
|
||||||
|
save_bin,
|
||||||
|
save_h5,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseDataset",
|
||||||
|
"DatasetFactory",
|
||||||
|
"Store",
|
||||||
|
"StoreFactory",
|
||||||
|
"H5Store",
|
||||||
|
"MmapStore",
|
||||||
|
"detect_format",
|
||||||
|
"save_h5",
|
||||||
|
"load_h5",
|
||||||
|
"save_bin",
|
||||||
|
"load_bin",
|
||||||
|
"ResumableDistributedSampler",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,308 @@
|
||||||
|
"""Dataset implementations with factory pattern for training."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from astrai.dataset.storage import (
|
||||||
|
Store,
|
||||||
|
StoreFactory,
|
||||||
|
detect_format,
|
||||||
|
)
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDataset(Dataset, ABC):
|
||||||
|
"""Abstract base class for all dataset types.
|
||||||
|
|
||||||
|
Implements common functionality for window-based data fetching.
|
||||||
|
Uses a storage abstraction for format-agnostic data loading.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int, stride: int):
|
||||||
|
super().__init__()
|
||||||
|
self.window_size = window_size
|
||||||
|
self.stride = stride
|
||||||
|
self.storage: Optional[Store] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
"""Return required storage keys for this dataset type.
|
||||||
|
|
||||||
|
Subclasses should override to specify expected keys.
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _validate_keys(self):
|
||||||
|
if not self.required_keys:
|
||||||
|
return
|
||||||
|
actual_keys = set(self.storage.keys)
|
||||||
|
missing = [k for k in self.required_keys if k not in actual_keys]
|
||||||
|
if missing:
|
||||||
|
raise KeyError(
|
||||||
|
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
|
||||||
|
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
|
||||||
|
f"Missing: {missing}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def load(self, load_path: str, storage_type: Optional[str] = None):
|
||||||
|
"""Load dataset from the given path.
|
||||||
|
|
||||||
|
Auto-detects the storage format if not specified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
load_path: Path to the data directory or file
|
||||||
|
storage_type: Force a specific storage type ("h5", "bin"),
|
||||||
|
or None for auto-detection
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the loaded storage is missing required keys.
|
||||||
|
"""
|
||||||
|
if storage_type is None:
|
||||||
|
storage_type = detect_format(load_path)
|
||||||
|
self.storage = StoreFactory.create(storage_type)
|
||||||
|
self._load_path = load_path
|
||||||
|
self.storage.load(load_path)
|
||||||
|
self._validate_keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||||
|
if self.storage is None:
|
||||||
|
return 0
|
||||||
|
return len(self.storage)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def keys(self) -> List[str]:
|
||||||
|
"""Return the available data keys."""
|
||||||
|
if self.storage is None:
|
||||||
|
return []
|
||||||
|
return self.storage.keys
|
||||||
|
|
||||||
|
def get_index(self, index: int) -> tuple:
|
||||||
|
"""Calculate begin and end indices for a sample.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: Sample index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (begin_idx, end_idx)
|
||||||
|
"""
|
||||||
|
if self.storage is None:
|
||||||
|
raise RuntimeError("Dataset not loaded, call load() first")
|
||||||
|
total = len(self.storage)
|
||||||
|
if total <= self.window_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Data too short: {total} tokens <= window_size {self.window_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
begin_idx = min(index * self.stride, total - 1 - self.window_size)
|
||||||
|
end_idx = min(begin_idx + self.window_size, total - 1)
|
||||||
|
|
||||||
|
return begin_idx, end_idx
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
|
"""Get a single sample by index.
|
||||||
|
|
||||||
|
Must be implemented by subclasses.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if self.storage is None:
|
||||||
|
return 0
|
||||||
|
total = len(self.storage)
|
||||||
|
if total <= self.window_size:
|
||||||
|
return 0
|
||||||
|
return (total - 1 - self.window_size) // self.stride + 1
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
|
"""Factory class for creating dataset instances.
|
||||||
|
|
||||||
|
Supports decorator-based registration for extensible dataset types.
|
||||||
|
All default dataset types (seq, sft, dpo, grpo) are registered automatically
|
||||||
|
when their classes are defined with the decorator.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
@DatasetFactory.register("custom")
|
||||||
|
class CustomDataset(BaseDataset):
|
||||||
|
...
|
||||||
|
|
||||||
|
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, dataset_cls: type):
|
||||||
|
"""Validate that the dataset class inherits from BaseDataset."""
|
||||||
|
if not issubclass(dataset_cls, BaseDataset):
|
||||||
|
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
||||||
|
"""Create a dataset instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||||
|
window_size: Window size for data sampling
|
||||||
|
stride: Stride between consecutive samples
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset instance
|
||||||
|
"""
|
||||||
|
return super().create(train_type, window_size, stride)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
train_type: str,
|
||||||
|
load_path: str,
|
||||||
|
window_size: int,
|
||||||
|
stride: Optional[int] = None,
|
||||||
|
storage_type: Optional[str] = None,
|
||||||
|
) -> "BaseDataset":
|
||||||
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_type: Type of training dataset
|
||||||
|
load_path: Path to the data file
|
||||||
|
window_size: Window size for data sampling
|
||||||
|
stride: Stride between consecutive samples (default: same as window_size)
|
||||||
|
storage_type: Storage type ("h5", "bin") or None for auto-detection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded dataset instance
|
||||||
|
"""
|
||||||
|
if stride is None:
|
||||||
|
stride = window_size
|
||||||
|
|
||||||
|
dataset = cls.create(train_type, window_size, stride)
|
||||||
|
dataset.load(load_path, storage_type=storage_type)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available_types(cls) -> list:
|
||||||
|
"""Return list of registered dataset type names."""
|
||||||
|
return cls.list_registered()
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetFactory.register("seq")
|
||||||
|
class SEQDataset(BaseDataset):
|
||||||
|
"""Dataset for sequential next-token prediction training."""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int, stride: int):
|
||||||
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["sequence"]
|
||||||
|
|
||||||
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
|
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
|
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||||
|
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||||
|
|
||||||
|
return {"input_ids": x, "target_ids": y}
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetFactory.register("sft")
|
||||||
|
class SFTDataset(BaseDataset):
|
||||||
|
"""Dataset for supervised fine-tuning with loss masking."""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int, stride: int):
|
||||||
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["sequence", "loss_mask"]
|
||||||
|
|
||||||
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
|
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||||
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
||||||
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetFactory.register("dpo")
|
||||||
|
class DPODataset(BaseDataset):
|
||||||
|
"""Dataset for Direct Preference Optimization training."""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int, stride: int):
|
||||||
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
|
||||||
|
|
||||||
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int):
|
||||||
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
|
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||||
|
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||||
|
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
|
||||||
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
|
||||||
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"chosen": chosen,
|
||||||
|
"rejected": rejected,
|
||||||
|
"chosen_mask": chosen_mask,
|
||||||
|
"rejected_mask": rejected_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetFactory.register("grpo")
|
||||||
|
class GRPODataset(BaseDataset):
|
||||||
|
"""Dataset for Group Relative Policy Optimization training."""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int, stride: int):
|
||||||
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["prompts", "responses", "masks", "rewards"]
|
||||||
|
|
||||||
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
|
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
|
||||||
|
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
|
||||||
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"prompts": prompts,
|
||||||
|
"responses": responses,
|
||||||
|
"masks": masks,
|
||||||
|
"rewards": rewards,
|
||||||
|
}
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from torch.utils.data import Dataset, Sampler
|
from torch.utils.data import Dataset, Sampler
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class ResumableDistributedSampler(Sampler[int]):
|
class ResumableDistributedSampler(Sampler[int]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_source: Dataset,
|
data_source: Dataset,
|
||||||
start_epoch: int=0,
|
start_epoch: int = 0,
|
||||||
start_iter: int=0,
|
start_iter: int = 0,
|
||||||
seed: int=42,
|
seed: int = 42,
|
||||||
drop_last: bool=False,
|
drop_last: bool = False,
|
||||||
shuffle: bool=True,
|
shuffle: bool = True,
|
||||||
process_group: Optional[dist.ProcessGroup]=None,
|
process_group: Optional[dist.ProcessGroup] = None,
|
||||||
):
|
):
|
||||||
self.epoch = start_epoch
|
self.epoch = start_epoch
|
||||||
self.iter = start_iter
|
self.iter = start_iter
|
||||||
|
|
@ -40,9 +40,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
|
self.iter = self.iter % self.num_samples_per_replica
|
||||||
|
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
|
@ -58,10 +59,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
padding_size = self.total_size - len(indices)
|
padding_size = self.total_size - len(indices)
|
||||||
indices += indices[:padding_size]
|
indices += indices[:padding_size]
|
||||||
|
|
||||||
local_indices = indices[self.rank:self.total_size:self.num_replicas]
|
local_indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||||
|
|
||||||
self.iter = self.iter % self.num_samples_per_replica
|
self.iter = self.iter % self.num_samples_per_replica
|
||||||
self._indices = local_indices[self.iter:]
|
self._indices = local_indices[self.iter :]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
if self._indices is None:
|
if self._indices is None:
|
||||||
|
|
@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _remaining(self):
|
||||||
|
remaining = self.num_samples_per_replica - self.iter
|
||||||
|
return max(remaining, 0)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples_per_replica
|
return self._remaining
|
||||||
|
|
@ -0,0 +1,264 @@
|
||||||
|
"""Storage backends for different data formats.
|
||||||
|
|
||||||
|
Layers:
|
||||||
|
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
|
||||||
|
return Dict[str, List[Tensor]] — format-specific, no state
|
||||||
|
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||||
|
Dict[str, List[Tensor]] per key via _normalize(),
|
||||||
|
fetch() uses bisect across segments — no forced concat
|
||||||
|
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
||||||
|
|
||||||
|
Key properties:
|
||||||
|
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
||||||
|
datasets larger than RAM
|
||||||
|
- Explicit length: _length = min(total elements across keys), set at load,
|
||||||
|
__len__ returns O(1)
|
||||||
|
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
||||||
|
workers share OS page-cache pages
|
||||||
|
"""
|
||||||
|
|
||||||
|
import bisect
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
|
os.makedirs(file_path, exist_ok=True)
|
||||||
|
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||||
|
with h5py.File(full_file_path, "w") as f:
|
||||||
|
for key, tensors in tensor_group.items():
|
||||||
|
grp = f.create_group(key)
|
||||||
|
for idx, tensor in enumerate(tensors):
|
||||||
|
arr = tensor.cpu().numpy()
|
||||||
|
grp.create_dataset(f"data_{idx}", data=arr)
|
||||||
|
|
||||||
|
|
||||||
|
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
|
tensor_group: Dict[str, List[Tensor]] = {}
|
||||||
|
|
||||||
|
root_path = Path(file_path)
|
||||||
|
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||||
|
|
||||||
|
for h5_file in h5_files:
|
||||||
|
with h5py.File(h5_file, "r") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
grp = f[key]
|
||||||
|
dsets = []
|
||||||
|
for dset_name in grp.keys():
|
||||||
|
dset = grp[dset_name]
|
||||||
|
tensor = torch.from_numpy(dset[:])
|
||||||
|
if share_memory:
|
||||||
|
tensor = tensor.share_memory_()
|
||||||
|
dsets.append(tensor)
|
||||||
|
|
||||||
|
if tensor_group.get(key) is None:
|
||||||
|
tensor_group[key] = []
|
||||||
|
tensor_group[key].extend(dsets)
|
||||||
|
|
||||||
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
|
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
|
os.makedirs(file_path, exist_ok=True)
|
||||||
|
meta = {}
|
||||||
|
for key, tensors in tensor_group.items():
|
||||||
|
cat = torch.cat(tensors, dim=0)
|
||||||
|
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
||||||
|
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
||||||
|
with open(os.path.join(file_path, "meta.json"), "w") as f:
|
||||||
|
json.dump(meta, f)
|
||||||
|
|
||||||
|
|
||||||
|
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||||
|
with open(os.path.join(file_path, "meta.json"), "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
segments: Dict[str, List[Tensor]] = {}
|
||||||
|
for key, info in meta.items():
|
||||||
|
arr = np.memmap(
|
||||||
|
os.path.join(file_path, f"{key}.bin"),
|
||||||
|
dtype=info["dtype"],
|
||||||
|
mode="r+",
|
||||||
|
shape=tuple(info["shape"]),
|
||||||
|
)
|
||||||
|
segments[key] = [torch.from_numpy(arr)]
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
|
def detect_format(load_path: str) -> str:
|
||||||
|
"""Auto-detect storage format from files in the directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
load_path: Directory or file path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Format string ("h5" or "bin")
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If no supported data files are found
|
||||||
|
"""
|
||||||
|
root = Path(load_path)
|
||||||
|
if root.is_file():
|
||||||
|
suffix = root.suffix.lower()
|
||||||
|
if suffix in (".h5", ".hdf5"):
|
||||||
|
return "h5"
|
||||||
|
raise ValueError(f"Unsupported file format: {suffix}")
|
||||||
|
|
||||||
|
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||||
|
if h5_files:
|
||||||
|
return "h5"
|
||||||
|
bin_files = list(root.rglob("*.bin"))
|
||||||
|
if bin_files:
|
||||||
|
has_meta = (root / "meta.json").exists() or len(
|
||||||
|
list(root.rglob("meta.json"))
|
||||||
|
) > 0
|
||||||
|
if has_meta:
|
||||||
|
return "bin"
|
||||||
|
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
||||||
|
class Store(ABC):
|
||||||
|
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
|
||||||
|
|
||||||
|
Each key maps to one or more tensor segments (no forced concatenation).
|
||||||
|
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
|
||||||
|
total element count across all keys.
|
||||||
|
|
||||||
|
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
|
||||||
|
via ``_normalize()``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._data: Dict[str, List[Tensor]] = {}
|
||||||
|
self._cum: Dict[str, List[int]] = {}
|
||||||
|
self._length: int = 0
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self, path: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def keys(self) -> List[str]:
|
||||||
|
return list(self._data.keys())
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def fetch(
|
||||||
|
self,
|
||||||
|
begin: int,
|
||||||
|
end: int,
|
||||||
|
keys: Union[str, List[str]],
|
||||||
|
):
|
||||||
|
if not self._data:
|
||||||
|
raise RuntimeError("Store not loaded")
|
||||||
|
if not (0 <= begin < self._length and 0 <= end <= self._length):
|
||||||
|
raise ValueError(
|
||||||
|
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
|
||||||
|
)
|
||||||
|
if isinstance(keys, str):
|
||||||
|
return self._fetch_key(keys, begin, end)
|
||||||
|
return {k: self._fetch_key(k, begin, end) for k in keys}
|
||||||
|
|
||||||
|
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
|
||||||
|
"""Fetch slice [begin, end) across potentially multiple segments."""
|
||||||
|
segments = self._data[key]
|
||||||
|
cum = self._cum[key]
|
||||||
|
seg_start = bisect.bisect_right(cum, begin)
|
||||||
|
seg_end = bisect.bisect_left(cum, end)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i in range(seg_start, seg_end + 1):
|
||||||
|
prev = cum[i - 1] if i > 0 else 0
|
||||||
|
s = max(begin - prev, 0)
|
||||||
|
e = min(end - prev, segments[i].shape[0])
|
||||||
|
results.append(segments[i][s:e])
|
||||||
|
|
||||||
|
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
|
||||||
|
|
||||||
|
def _normalize(self, raw: Dict[str, List[Tensor]]):
|
||||||
|
"""Register segments and pre-compute cumulative lengths.
|
||||||
|
|
||||||
|
Does NOT concatenate — segments are kept as-is to avoid OOM on
|
||||||
|
large datasets. Sets ``self._length`` to the minimum total
|
||||||
|
element count across all keys.
|
||||||
|
"""
|
||||||
|
for key, tensors in raw.items():
|
||||||
|
self._data[key] = tensors
|
||||||
|
cum = []
|
||||||
|
total = 0
|
||||||
|
for t in tensors:
|
||||||
|
total += t.shape[0]
|
||||||
|
cum.append(total)
|
||||||
|
self._cum[key] = cum
|
||||||
|
self._length = (
|
||||||
|
min((cum[-1] if cum else 0) for cum in self._cum.values())
|
||||||
|
if self._cum
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StoreFactory(BaseFactory["Store"]):
|
||||||
|
"""Factory for creating Store instances by type name.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
@StoreFactory.register("custom")
|
||||||
|
class CustomStore(Store):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, store_cls: type):
|
||||||
|
if not issubclass(store_cls, Store):
|
||||||
|
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
||||||
|
|
||||||
|
|
||||||
|
@StoreFactory.register("h5")
|
||||||
|
class H5Store(Store):
|
||||||
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
|
def load(self, path: str):
|
||||||
|
self._normalize(load_h5(path))
|
||||||
|
|
||||||
|
|
||||||
|
@StoreFactory.register("bin")
|
||||||
|
class MmapStore(Store):
|
||||||
|
"""Memory-mapped binary storage backend.
|
||||||
|
|
||||||
|
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
|
||||||
|
No per-process memory duplication — all DataLoader workers share the
|
||||||
|
same OS page-cache pages.
|
||||||
|
|
||||||
|
Format on disk::
|
||||||
|
|
||||||
|
data_root/
|
||||||
|
meta.json # {key: {shape, dtype}, ...}
|
||||||
|
<key>.bin # raw numpy array, one per key
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load(self, path: str):
|
||||||
|
self._mmap_refs = []
|
||||||
|
root = Path(path)
|
||||||
|
all_raw: Dict[str, List[Tensor]] = {}
|
||||||
|
meta_paths = list(root.rglob("meta.json"))
|
||||||
|
for meta_path in meta_paths:
|
||||||
|
raw = load_bin(str(meta_path.parent))
|
||||||
|
for key, tensors in raw.items():
|
||||||
|
if key not in all_raw:
|
||||||
|
all_raw[key] = []
|
||||||
|
all_raw[key].extend(tensors)
|
||||||
|
if not meta_paths:
|
||||||
|
raise FileNotFoundError(f"No meta.json found under {path}")
|
||||||
|
self._normalize(all_raw)
|
||||||
|
for tensors in self._data.values():
|
||||||
|
self._mmap_refs.extend(tensors)
|
||||||
|
|
@ -0,0 +1,226 @@
|
||||||
|
"""Base factory class for extensible component registration."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from abc import ABC
|
||||||
|
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Registry:
|
||||||
|
"""Flexible registry for component classes with category and priority support.
|
||||||
|
|
||||||
|
This registry stores component classes with optional metadata (category, priority).
|
||||||
|
It provides methods for registration, retrieval, and listing with filtering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._entries = {} # name -> (component_cls, category, priority)
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
component_cls: Type,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
priority: int = 0,
|
||||||
|
):
|
||||||
|
"""Register a component class with optional category and priority."""
|
||||||
|
if name in self._entries:
|
||||||
|
raise ValueError(f"Component '{name}' is already registered")
|
||||||
|
self._entries[name] = (component_cls, category, priority)
|
||||||
|
|
||||||
|
def get(self, name: str) -> Type:
|
||||||
|
"""Get component class by name."""
|
||||||
|
if name not in self._entries:
|
||||||
|
raise KeyError(f"Component '{name}' not found in registry")
|
||||||
|
return self._entries[name][0]
|
||||||
|
|
||||||
|
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
|
||||||
|
"""Get component class with its metadata."""
|
||||||
|
entry = self._entries.get(name)
|
||||||
|
if entry is None:
|
||||||
|
raise KeyError(f"Component '{name}' not found in registry")
|
||||||
|
return entry
|
||||||
|
|
||||||
|
def contains(self, name: str) -> bool:
|
||||||
|
"""Check if a name is registered."""
|
||||||
|
return name in self._entries
|
||||||
|
|
||||||
|
def list_names(self) -> List[str]:
|
||||||
|
"""Return list of registered component names."""
|
||||||
|
return sorted(self._entries.keys())
|
||||||
|
|
||||||
|
def list_by_category(self, category: str) -> List[str]:
|
||||||
|
"""Return names of components belonging to a specific category."""
|
||||||
|
return sorted(
|
||||||
|
name for name, (_, cat, _) in self._entries.items() if cat == category
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_by_priority(self, reverse: bool = False) -> List[str]:
|
||||||
|
"""Return names sorted by priority (default ascending)."""
|
||||||
|
return sorted(
|
||||||
|
self._entries.keys(),
|
||||||
|
key=lambda name: self._entries[name][2],
|
||||||
|
reverse=reverse,
|
||||||
|
)
|
||||||
|
|
||||||
|
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
|
||||||
|
"""Return raw entries dictionary."""
|
||||||
|
return self._entries.copy()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFactory(ABC, Generic[T]):
|
||||||
|
"""Generic factory class for component registration and creation.
|
||||||
|
|
||||||
|
This base class provides a decorator-based registration pattern
|
||||||
|
for creating extensible component factories.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
class MyFactory(BaseFactory[MyBaseClass]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@MyFactory.register("custom")
|
||||||
|
class CustomComponent(MyBaseClass):
|
||||||
|
...
|
||||||
|
|
||||||
|
component = MyFactory.create("custom", *args, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Registry
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
cls._registry = Registry()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(
|
||||||
|
cls, name: str, category: Optional[str] = None, priority: int = 0
|
||||||
|
) -> Callable[[Type[T]], Type[T]]:
|
||||||
|
"""Decorator to register a component class with optional category and priority.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registration name for the component
|
||||||
|
category: Optional category for grouping components
|
||||||
|
priority: Priority for ordering (default 0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function that registers the component class
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the decorated class doesn't inherit from the base type
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(component_cls: Type[T]) -> Type[T]:
|
||||||
|
cls._validate_component(component_cls)
|
||||||
|
cls._registry.register(
|
||||||
|
name, component_cls, category=category, priority=priority
|
||||||
|
)
|
||||||
|
return component_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, name: str, *args, **kwargs) -> T:
|
||||||
|
"""Create a component instance by name.
|
||||||
|
|
||||||
|
Filters kwargs to match the component's __init__ signature,
|
||||||
|
so components don't need to declare **kwargs just to absorb
|
||||||
|
parameters meant for other components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registered name of the component
|
||||||
|
*args: Positional arguments passed to component constructor
|
||||||
|
**kwargs: Keyword arguments passed to component constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Component instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the component name is not registered
|
||||||
|
"""
|
||||||
|
if not cls._registry.contains(name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown component: '{name}'. "
|
||||||
|
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||||
|
)
|
||||||
|
component_cls = cls._registry.get(name)
|
||||||
|
sig = inspect.signature(component_cls.__init__)
|
||||||
|
has_var_kwargs = any(
|
||||||
|
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||||
|
)
|
||||||
|
if not has_var_kwargs:
|
||||||
|
valid = {
|
||||||
|
p.name
|
||||||
|
for p in sig.parameters.values()
|
||||||
|
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
|
||||||
|
}
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if k in valid}
|
||||||
|
return component_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, component_cls: Type[T]):
|
||||||
|
"""Validate that the component class is valid for this factory.
|
||||||
|
|
||||||
|
Override this method in subclasses to add custom validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_cls: Component class to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the component class is invalid
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_component_class(cls, name: str) -> Type[T]:
|
||||||
|
"""Get the registered component class by name without instantiating it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registered name of the component
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The component class itself
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the component name is not registered
|
||||||
|
"""
|
||||||
|
if not cls._registry.contains(name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown component: '{name}'. "
|
||||||
|
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||||
|
)
|
||||||
|
return cls._registry.get(name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_registered(cls) -> list:
|
||||||
|
"""List all registered component names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of registered component names
|
||||||
|
"""
|
||||||
|
return cls._registry.list_names()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_registered(cls, name: str) -> bool:
|
||||||
|
"""Check if a component name is registered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Component name to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if registered, False otherwise
|
||||||
|
"""
|
||||||
|
return cls._registry.contains(name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_by_category(cls, category: str) -> List[str]:
|
||||||
|
"""List registered component names in a category."""
|
||||||
|
return cls._registry.list_by_category(category)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_by_priority(cls, reverse: bool = False) -> List[str]:
|
||||||
|
"""List registered component names sorted by priority."""
|
||||||
|
return cls._registry.list_by_priority(reverse)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Registry", "BaseFactory"]
|
||||||
|
|
@ -0,0 +1,85 @@
|
||||||
|
"""Inference module for continuous batching.
|
||||||
|
|
||||||
|
Layers:
|
||||||
|
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||||
|
- api/: HTTP orchestration (ProtocolHandler, server)
|
||||||
|
- protocols/: Response builders (OpenAI, Anthropic)
|
||||||
|
- transport/: SSE transport utilities
|
||||||
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||||
|
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from astrai.inference.api import (
|
||||||
|
AnthropicMessage,
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatMessage,
|
||||||
|
GenContext,
|
||||||
|
MessagesRequest,
|
||||||
|
ProtocolHandler,
|
||||||
|
StopChecker,
|
||||||
|
app,
|
||||||
|
run_server,
|
||||||
|
)
|
||||||
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||||
|
from astrai.inference.core import (
|
||||||
|
STOP,
|
||||||
|
Allocator,
|
||||||
|
Executor,
|
||||||
|
InferenceScheduler,
|
||||||
|
KVCache,
|
||||||
|
KvcacheView,
|
||||||
|
PagePool,
|
||||||
|
PrefixCache,
|
||||||
|
Storage,
|
||||||
|
Task,
|
||||||
|
TaskManager,
|
||||||
|
TaskStatus,
|
||||||
|
TaskTable,
|
||||||
|
page_hash,
|
||||||
|
)
|
||||||
|
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
||||||
|
from astrai.inference.sample import (
|
||||||
|
BaseSamplingStrategy,
|
||||||
|
SamplingPipeline,
|
||||||
|
TemperatureStrategy,
|
||||||
|
TopKStrategy,
|
||||||
|
TopPStrategy,
|
||||||
|
sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"InferenceEngine",
|
||||||
|
"GenerationRequest",
|
||||||
|
"InferenceScheduler",
|
||||||
|
"Executor",
|
||||||
|
"STOP",
|
||||||
|
"Task",
|
||||||
|
"TaskManager",
|
||||||
|
"TaskStatus",
|
||||||
|
"Allocator",
|
||||||
|
"KVCache",
|
||||||
|
"KvcacheView",
|
||||||
|
"PagePool",
|
||||||
|
"PrefixCache",
|
||||||
|
"Storage",
|
||||||
|
"TaskTable",
|
||||||
|
"page_hash",
|
||||||
|
"sample",
|
||||||
|
"BaseSamplingStrategy",
|
||||||
|
"TemperatureStrategy",
|
||||||
|
"TopKStrategy",
|
||||||
|
"TopPStrategy",
|
||||||
|
"SamplingPipeline",
|
||||||
|
"ProtocolHandler",
|
||||||
|
"StopChecker",
|
||||||
|
"GenContext",
|
||||||
|
"OpenAIResponseBuilder",
|
||||||
|
"AnthropicResponseBuilder",
|
||||||
|
"ChatMessage",
|
||||||
|
"ChatCompletionRequest",
|
||||||
|
"AnthropicMessage",
|
||||||
|
"MessagesRequest",
|
||||||
|
"app",
|
||||||
|
"run_server",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||||
|
from astrai.inference.api.server import (
|
||||||
|
AnthropicMessage,
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatMessage,
|
||||||
|
MessagesRequest,
|
||||||
|
app,
|
||||||
|
run_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ProtocolHandler",
|
||||||
|
"StopChecker",
|
||||||
|
"GenContext",
|
||||||
|
"AnthropicMessage",
|
||||||
|
"ChatCompletionRequest",
|
||||||
|
"ChatMessage",
|
||||||
|
"MessagesRequest",
|
||||||
|
"app",
|
||||||
|
"run_server",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
"""Anthropic message completion response builder."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import (
|
||||||
|
GenContext,
|
||||||
|
ResponseBuilder,
|
||||||
|
StopInfo,
|
||||||
|
sse_event,
|
||||||
|
)
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
return block.get("text", "")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicResponseBuilder(ResponseBuilder):
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
messages: List[Dict[str, str]] = []
|
||||||
|
system = getattr(request, "system", None)
|
||||||
|
if system:
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
|
for m in request.messages:
|
||||||
|
text = _extract_text(m.content)
|
||||||
|
if text:
|
||||||
|
messages.append({"role": m.role, "content": text})
|
||||||
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
ctx = GenContext(
|
||||||
|
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||||
|
created=int(time.time()),
|
||||||
|
model=request.model,
|
||||||
|
prompt_tokens=0,
|
||||||
|
)
|
||||||
|
stop_sequences = getattr(request, "stop_sequences", None) or []
|
||||||
|
return prompt, ctx, stop_sequences
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [],
|
||||||
|
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
event="message_start",
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": 0,
|
||||||
|
"content_block": {"type": "text", "text": ""},
|
||||||
|
},
|
||||||
|
event="content_block_start",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
return sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": token},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
events: List[str] = []
|
||||||
|
if stop.matched:
|
||||||
|
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
||||||
|
unyielded = trimmed[len(stop.yielded) :]
|
||||||
|
if unyielded:
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": unyielded},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{"type": "content_block_stop", "index": 0},
|
||||||
|
event="content_block_stop",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": {
|
||||||
|
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||||
|
"stop_sequence": stop.matched,
|
||||||
|
},
|
||||||
|
"usage": {"output_tokens": ctx.completion_tokens},
|
||||||
|
},
|
||||||
|
event="message_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
||||||
|
return events
|
||||||
|
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
if stop.matched:
|
||||||
|
content = content[: content.rfind(stop.matched)]
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [{"type": "text", "text": content}],
|
||||||
|
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||||
|
"stop_sequence": stop.matched,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": ctx.prompt_tokens,
|
||||||
|
"output_tokens": ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
"""OpenAI chat completion response builder."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import (
|
||||||
|
GenContext,
|
||||||
|
ResponseBuilder,
|
||||||
|
StopInfo,
|
||||||
|
sse_event,
|
||||||
|
)
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_UNSUPPORTED_PARAMS = (
|
||||||
|
"n",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
"logit_bias",
|
||||||
|
"user",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||||
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
self._model = request.model
|
||||||
|
|
||||||
|
for param in _UNSUPPORTED_PARAMS:
|
||||||
|
value = getattr(request, param, None)
|
||||||
|
fields = getattr(type(request), "model_fields", {})
|
||||||
|
default = fields[param].default if param in fields else None
|
||||||
|
if value is not None and value != default:
|
||||||
|
logger.warning(
|
||||||
|
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||||
|
param,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
if value is not None and value != default:
|
||||||
|
logger.warning(
|
||||||
|
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||||
|
param,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = GenContext(
|
||||||
|
resp_id=self._resp_id,
|
||||||
|
created=int(time.time()),
|
||||||
|
model=self._model,
|
||||||
|
prompt_tokens=0,
|
||||||
|
)
|
||||||
|
stop = request.stop
|
||||||
|
stop_sequences = (
|
||||||
|
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||||
|
)
|
||||||
|
return prompt, ctx, stop_sequences
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant"},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
return sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 0,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,182 @@
|
||||||
|
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
||||||
|
|
||||||
|
ProtocolHandler orchestrates the async generation loop and delegates
|
||||||
|
protocol-specific formatting to a ResponseBuilder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
|
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
|
lines: List[str] = []
|
||||||
|
if event:
|
||||||
|
lines.append(f"event: {event}")
|
||||||
|
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
|
||||||
|
lines.append("")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def sse_done() -> str:
|
||||||
|
return "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenContext:
|
||||||
|
"""Per-generation metadata passed to builder format methods."""
|
||||||
|
|
||||||
|
resp_id: str
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StopInfo:
|
||||||
|
"""Stop-check result passed to format_stream_end / format_response."""
|
||||||
|
|
||||||
|
matched: Optional[str] = None
|
||||||
|
body: str = ""
|
||||||
|
yielded: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class StopChecker:
|
||||||
|
"""Scans accumulated text for stop sequence matches."""
|
||||||
|
|
||||||
|
def __init__(self, sequences: List[str]):
|
||||||
|
self._sequences = [s for s in sequences if s]
|
||||||
|
|
||||||
|
def check(self, text: str) -> Optional[str]:
|
||||||
|
for seq in self._sequences:
|
||||||
|
if seq in text:
|
||||||
|
return seq
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseBuilder(ABC):
|
||||||
|
"""Interface for protocol-specific response formatting.
|
||||||
|
|
||||||
|
A new protocol requires one concrete builder implementing 5 methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
"""SSE events that open the stream."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
"""SSE event for a single generated token."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
"""SSE events that close the stream."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""JSON response body for non-streaming mode."""
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolHandler:
|
||||||
|
"""Orchestrates the generation loop, delegates formatting to a builder.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||||
|
response = await handler.handle()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
||||||
|
):
|
||||||
|
self.request = request
|
||||||
|
self.engine = engine
|
||||||
|
self.builder = builder
|
||||||
|
|
||||||
|
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||||
|
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
||||||
|
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
||||||
|
|
||||||
|
agen = self.engine.generate_async(
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=self.request.max_tokens,
|
||||||
|
temperature=self.request.temperature,
|
||||||
|
top_p=self.request.top_p,
|
||||||
|
top_k=self.request.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.request.stream:
|
||||||
|
return self._handle_stream(agen, ctx, stop_sequences)
|
||||||
|
else:
|
||||||
|
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
||||||
|
|
||||||
|
def _handle_stream(
|
||||||
|
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||||
|
) -> StreamingResponse:
|
||||||
|
checker = StopChecker(stop_sequences)
|
||||||
|
|
||||||
|
async def event_stream():
|
||||||
|
for event in self.builder.format_stream_start(ctx):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
body = ""
|
||||||
|
yielded = ""
|
||||||
|
matched = None
|
||||||
|
async for token in agen:
|
||||||
|
body += token
|
||||||
|
|
||||||
|
matched = checker.check(body)
|
||||||
|
if matched:
|
||||||
|
break
|
||||||
|
|
||||||
|
ctx.completion_tokens += 1
|
||||||
|
yield self.builder.format_chunk(token)
|
||||||
|
yielded += token
|
||||||
|
|
||||||
|
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||||
|
for event in self.builder.format_stream_end(ctx, stop):
|
||||||
|
yield event
|
||||||
|
yield sse_done()
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_stream(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_non_stream(
|
||||||
|
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
checker = StopChecker(stop_sequences)
|
||||||
|
chunks: List[str] = []
|
||||||
|
body = ""
|
||||||
|
matched = None
|
||||||
|
|
||||||
|
async for token in agen:
|
||||||
|
chunks.append(token)
|
||||||
|
body += token
|
||||||
|
|
||||||
|
matched = checker.check(body)
|
||||||
|
if matched:
|
||||||
|
break
|
||||||
|
|
||||||
|
ctx.completion_tokens += 1
|
||||||
|
|
||||||
|
content = "".join(chunks)
|
||||||
|
stop = StopInfo(matched=matched, body=body)
|
||||||
|
return self.builder.format_response(ctx, content, stop)
|
||||||
|
|
@ -0,0 +1,169 @@
|
||||||
|
"""
|
||||||
|
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
|
||||||
|
|
||||||
|
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||||
|
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||||
|
from astrai.inference.api.protocol import ProtocolHandler
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
"""OpenAI Chat Completion API request body."""
|
||||||
|
|
||||||
|
model: str = "astrai"
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||||
|
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
top_k: Optional[int] = Field(default=50, ge=1)
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
||||||
|
n: Optional[int] = Field(default=1, ge=1)
|
||||||
|
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||||
|
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None
|
||||||
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: Union[str, List[Dict[str, Any]]]
|
||||||
|
|
||||||
|
|
||||||
|
class MessagesRequest(BaseModel):
|
||||||
|
"""Anthropic Messages API request body."""
|
||||||
|
|
||||||
|
model: str = "astrai"
|
||||||
|
max_tokens: int = Field(default=1024, ge=1)
|
||||||
|
messages: List[AnthropicMessage]
|
||||||
|
system: Optional[str] = None
|
||||||
|
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||||
|
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
top_k: Optional[int] = Field(default=50, ge=1)
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
config = app.state.server_config
|
||||||
|
if not config.get("_test", False):
|
||||||
|
try:
|
||||||
|
app.state.engine = _create_engine(**config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model: {e}")
|
||||||
|
raise
|
||||||
|
yield
|
||||||
|
if app.state.engine:
|
||||||
|
app.state.engine.shutdown()
|
||||||
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_engine(
|
||||||
|
param_path: Optional[Path] = None,
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
max_batch_size: int = 16,
|
||||||
|
) -> InferenceEngine:
|
||||||
|
if param_path is None:
|
||||||
|
param_path = _project_root / "params"
|
||||||
|
if not param_path.exists():
|
||||||
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||||
|
model = AutoModel.from_pretrained(param_path)
|
||||||
|
model.to(device=device, dtype=dtype)
|
||||||
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
|
engine = InferenceEngine(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
)
|
||||||
|
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
def _get_engine() -> InferenceEngine:
|
||||||
|
engine = app.state.engine
|
||||||
|
if engine is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"model_loaded": app.state.engine is not None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stats")
|
||||||
|
async def get_stats():
|
||||||
|
return _get_engine().get_stats()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
|
engine = _get_engine()
|
||||||
|
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||||
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/messages")
|
||||||
|
async def create_message(request: MessagesRequest):
|
||||||
|
engine = _get_engine()
|
||||||
|
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||||
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
|
def run_server(
|
||||||
|
host: str = "0.0.0.0",
|
||||||
|
port: int = 8000,
|
||||||
|
reload: bool = False,
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
param_path: Optional[Path] = None,
|
||||||
|
max_batch_size: int = 16,
|
||||||
|
):
|
||||||
|
app.state.server_config = {
|
||||||
|
"device": device,
|
||||||
|
"dtype": dtype,
|
||||||
|
"param_path": param_path,
|
||||||
|
"max_batch_size": max_batch_size,
|
||||||
|
}
|
||||||
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
reload=reload,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
"""Inference core: cache, executor, scheduler, task management."""
|
||||||
|
|
||||||
|
from astrai.inference.core.cache import (
|
||||||
|
Allocator,
|
||||||
|
KVCache,
|
||||||
|
KvcacheView,
|
||||||
|
PagePool,
|
||||||
|
PrefixCache,
|
||||||
|
Storage,
|
||||||
|
TaskTable,
|
||||||
|
page_hash,
|
||||||
|
)
|
||||||
|
from astrai.inference.core.executor import Executor
|
||||||
|
from astrai.inference.core.scheduler import InferenceScheduler
|
||||||
|
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Allocator",
|
||||||
|
"KVCache",
|
||||||
|
"KvcacheView",
|
||||||
|
"PagePool",
|
||||||
|
"PrefixCache",
|
||||||
|
"Storage",
|
||||||
|
"TaskTable",
|
||||||
|
"page_hash",
|
||||||
|
"Executor",
|
||||||
|
"InferenceScheduler",
|
||||||
|
"STOP",
|
||||||
|
"Task",
|
||||||
|
"TaskManager",
|
||||||
|
"TaskStatus",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,368 @@
|
||||||
|
import threading
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||||
|
start = page_idx * page_size
|
||||||
|
end = min(start + page_size, len(token_ids))
|
||||||
|
h = 0
|
||||||
|
for i in range(start, end):
|
||||||
|
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class Allocator:
|
||||||
|
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
|
||||||
|
|
||||||
|
def __init__(self, n_pages: int):
|
||||||
|
self._free_mask = (1 << n_pages) - 1
|
||||||
|
self._refs: List[int] = [0] * n_pages
|
||||||
|
self._lru: OrderedDict[int, None] = OrderedDict()
|
||||||
|
self.on_evict: Optional[Callable[[int], None]] = None
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def alloc(self) -> int:
|
||||||
|
with self._lock:
|
||||||
|
if self._free_mask:
|
||||||
|
lsb = self._free_mask & -self._free_mask
|
||||||
|
idx = lsb.bit_length() - 1
|
||||||
|
self._free_mask ^= lsb
|
||||||
|
self._refs[idx] = 1
|
||||||
|
return idx
|
||||||
|
if self._lru:
|
||||||
|
idx, _ = self._lru.popitem(last=False)
|
||||||
|
if self.on_evict:
|
||||||
|
self.on_evict(idx)
|
||||||
|
self._refs[idx] = 1
|
||||||
|
self._free_mask &= ~(1 << idx)
|
||||||
|
return idx
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def free(self, idx: int, keep_cached: bool = False):
|
||||||
|
with self._lock:
|
||||||
|
self._refs[idx] -= 1
|
||||||
|
if self._refs[idx] == 0:
|
||||||
|
if keep_cached:
|
||||||
|
self._lru[idx] = None
|
||||||
|
else:
|
||||||
|
self._free_mask |= 1 << idx
|
||||||
|
|
||||||
|
def inc_ref(self, idx: int):
|
||||||
|
with self._lock:
|
||||||
|
self._refs[idx] += 1
|
||||||
|
self._lru.pop(idx, None)
|
||||||
|
|
||||||
|
def ref_count(self, idx: int) -> int:
|
||||||
|
with self._lock:
|
||||||
|
return self._refs[idx]
|
||||||
|
|
||||||
|
def touch(self, idx: int):
|
||||||
|
with self._lock:
|
||||||
|
self._lru.move_to_end(idx)
|
||||||
|
|
||||||
|
|
||||||
|
class PrefixCache:
|
||||||
|
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
||||||
|
|
||||||
|
def __init__(self, page_size: int):
|
||||||
|
self._page_size = page_size
|
||||||
|
self._page_to_hash: Dict[int, int] = {}
|
||||||
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def evict(self, idx: int):
|
||||||
|
with self._lock:
|
||||||
|
h = self._page_to_hash.pop(idx, None)
|
||||||
|
if h is not None:
|
||||||
|
self._hash_to_page.pop(h, None)
|
||||||
|
|
||||||
|
def has_page(self, idx: int) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
return idx in self._page_to_hash
|
||||||
|
|
||||||
|
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||||
|
with self._lock:
|
||||||
|
full_pages = len(token_ids) // self._page_size
|
||||||
|
hits: List[int] = []
|
||||||
|
for i in range(full_pages):
|
||||||
|
h = page_hash(token_ids, i, self._page_size)
|
||||||
|
p = self._hash_to_page.get(h)
|
||||||
|
if p is None:
|
||||||
|
break
|
||||||
|
hits.append(p)
|
||||||
|
return hits
|
||||||
|
|
||||||
|
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||||
|
with self._lock:
|
||||||
|
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||||
|
old_h = self._page_to_hash.pop(page_idx, None)
|
||||||
|
if old_h is not None:
|
||||||
|
self._hash_to_page.pop(old_h, None)
|
||||||
|
self._page_to_hash[page_idx] = h
|
||||||
|
self._hash_to_page[h] = page_idx
|
||||||
|
|
||||||
|
|
||||||
|
class PagePool:
|
||||||
|
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
|
||||||
|
|
||||||
|
def __init__(self, allocator: Allocator, prefix: PrefixCache):
|
||||||
|
self._alloc = allocator
|
||||||
|
self._prefix = prefix
|
||||||
|
self._alloc.on_evict = prefix.evict
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allocator(self) -> Allocator:
|
||||||
|
return self._alloc
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefix(self) -> PrefixCache:
|
||||||
|
return self._prefix
|
||||||
|
|
||||||
|
def alloc(self) -> int:
|
||||||
|
return self._alloc.alloc()
|
||||||
|
|
||||||
|
def free(self, idx: int):
|
||||||
|
keep = self._prefix.has_page(idx)
|
||||||
|
self._alloc.free(idx, keep_cached=keep)
|
||||||
|
if not keep:
|
||||||
|
self._prefix.evict(idx)
|
||||||
|
|
||||||
|
def inc_ref(self, idx: int):
|
||||||
|
self._alloc.inc_ref(idx)
|
||||||
|
|
||||||
|
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||||
|
hits = self._prefix.lookup(token_ids)
|
||||||
|
for p in hits:
|
||||||
|
self._alloc.touch(p)
|
||||||
|
return hits
|
||||||
|
|
||||||
|
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||||
|
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskTable:
|
||||||
|
"""Maps task_ids to page tables and cached token counts."""
|
||||||
|
|
||||||
|
def __init__(self, page_size: int):
|
||||||
|
self._page_size = page_size
|
||||||
|
self._pages: Dict[str, List[int]] = {}
|
||||||
|
self._cached: Dict[str, int] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def set(self, task_id: str, page_table: List[int], cached: int):
|
||||||
|
with self._lock:
|
||||||
|
self._pages[task_id] = page_table
|
||||||
|
self._cached[task_id] = cached
|
||||||
|
|
||||||
|
def get(self, task_id: str) -> List[int]:
|
||||||
|
with self._lock:
|
||||||
|
return self._pages.get(task_id, [])
|
||||||
|
|
||||||
|
def get_cached(self, task_id: str) -> int:
|
||||||
|
with self._lock:
|
||||||
|
return self._cached.get(task_id, 0)
|
||||||
|
|
||||||
|
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
||||||
|
with self._lock:
|
||||||
|
pages = self._pages.pop(task_id, [])
|
||||||
|
cached = self._cached.pop(task_id, 0)
|
||||||
|
return pages, cached
|
||||||
|
|
||||||
|
def get_ref(self, task_id: str) -> List[int]:
|
||||||
|
with self._lock:
|
||||||
|
return self._pages.setdefault(task_id, [])
|
||||||
|
|
||||||
|
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||||
|
with self._lock:
|
||||||
|
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||||
|
max_pages = max((len(s) for s in states), default=0)
|
||||||
|
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||||
|
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
class Storage:
|
||||||
|
"""KV-cache tensor storage with paged write/gather."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_layers: int,
|
||||||
|
n_pages: int,
|
||||||
|
page_size: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
self.page_size = page_size
|
||||||
|
self.k_cache = torch.empty(
|
||||||
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.v_cache = torch.empty(
|
||||||
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def write(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
page_table: Tensor,
|
||||||
|
start_pos: int,
|
||||||
|
k: Tensor,
|
||||||
|
v: Tensor,
|
||||||
|
):
|
||||||
|
seq_len = k.size(1)
|
||||||
|
if seq_len == 0:
|
||||||
|
return
|
||||||
|
page_size = self.page_size
|
||||||
|
written = 0
|
||||||
|
first_page = start_pos // page_size
|
||||||
|
last_page = (start_pos + seq_len - 1) // page_size
|
||||||
|
for pi in range(first_page, last_page + 1):
|
||||||
|
phys_pages = page_table[:, pi]
|
||||||
|
page_start = pi * page_size
|
||||||
|
write_start = max(page_start, start_pos)
|
||||||
|
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||||
|
offset = write_start - page_start
|
||||||
|
chunk = write_end - write_start
|
||||||
|
valid = phys_pages >= 0
|
||||||
|
if not valid.all():
|
||||||
|
if valid.any():
|
||||||
|
valid_pages = phys_pages[valid]
|
||||||
|
self.k_cache[layer_id, valid_pages, offset : offset + chunk] = k[
|
||||||
|
valid, written : written + chunk
|
||||||
|
]
|
||||||
|
self.v_cache[layer_id, valid_pages, offset : offset + chunk] = v[
|
||||||
|
valid, written : written + chunk
|
||||||
|
]
|
||||||
|
written += chunk
|
||||||
|
continue
|
||||||
|
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||||
|
:, written : written + chunk
|
||||||
|
]
|
||||||
|
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
|
||||||
|
:, written : written + chunk
|
||||||
|
]
|
||||||
|
written += chunk
|
||||||
|
|
||||||
|
def gather(
|
||||||
|
self, layer_id: int, page_table: Tensor, total_len: int
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
|
safe = page_table.clamp(min=0)
|
||||||
|
k = self.k_cache[layer_id, safe]
|
||||||
|
v = self.v_cache[layer_id, safe]
|
||||||
|
k = k.flatten(1, 2)
|
||||||
|
v = v.flatten(1, 2)
|
||||||
|
if (page_table < 0).any():
|
||||||
|
invalid = (
|
||||||
|
(page_table < 0)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand(-1, -1, self.page_size)
|
||||||
|
.flatten(1, 2)
|
||||||
|
)
|
||||||
|
invalid = invalid[:, :, None, None].expand_as(k)
|
||||||
|
k = k.masked_fill(invalid, 0.0)
|
||||||
|
v = v.masked_fill(invalid, 0.0)
|
||||||
|
k = k[:, :total_len]
|
||||||
|
v = v[:, :total_len]
|
||||||
|
return k, v
|
||||||
|
|
||||||
|
|
||||||
|
class KvcacheView:
|
||||||
|
"""Bundles Storage + page_table + total_len for attention layers."""
|
||||||
|
|
||||||
|
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
|
||||||
|
self._storage = storage
|
||||||
|
self._page_table = page_table
|
||||||
|
self._total_len = total_len
|
||||||
|
|
||||||
|
def write(self, layer_id: int, k: Tensor, v: Tensor):
|
||||||
|
start_pos = self._total_len - k.size(1)
|
||||||
|
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||||
|
|
||||||
|
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
|
return self._storage.gather(layer_id, self._page_table, self._total_len)
|
||||||
|
|
||||||
|
|
||||||
|
class KVCache:
|
||||||
|
"""Facade: page management + KV-cache I/O for continuous batching."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_layers: int,
|
||||||
|
n_pages: int,
|
||||||
|
page_size: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
self.page_size = page_size
|
||||||
|
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
|
||||||
|
self._table = TaskTable(page_size)
|
||||||
|
self._storage = Storage(
|
||||||
|
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||||
|
hits = self._pool.lookup(prompt_ids)
|
||||||
|
cached = len(hits) * self.page_size
|
||||||
|
for p in hits:
|
||||||
|
self._pool.inc_ref(p)
|
||||||
|
|
||||||
|
remaining = len(prompt_ids) - cached
|
||||||
|
n_new = (
|
||||||
|
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||||
|
)
|
||||||
|
new_pages: List[int] = []
|
||||||
|
if n_new > 0:
|
||||||
|
for _ in range(n_new):
|
||||||
|
p = self._pool.alloc()
|
||||||
|
if p < 0:
|
||||||
|
for hp in hits:
|
||||||
|
self._pool.free(hp)
|
||||||
|
for np in new_pages:
|
||||||
|
self._pool.free(np)
|
||||||
|
return False
|
||||||
|
new_pages.append(p)
|
||||||
|
|
||||||
|
self._table.set(task_id, hits + new_pages, cached)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def task_free(self, task_id: str):
|
||||||
|
page_table, _ = self._table.pop(task_id)
|
||||||
|
for idx in page_table:
|
||||||
|
self._pool.free(idx)
|
||||||
|
|
||||||
|
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||||
|
page_table = self._table.get(task_id)
|
||||||
|
needed = (pos + 1 + self.page_size - 1) // self.page_size
|
||||||
|
while len(page_table) < needed:
|
||||||
|
p = self._pool.alloc()
|
||||||
|
if p < 0:
|
||||||
|
return False
|
||||||
|
page_table.append(p)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def task_cached(self, task_id: str) -> int:
|
||||||
|
return self._table.get_cached(task_id)
|
||||||
|
|
||||||
|
def task_record_hashes(
|
||||||
|
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||||
|
):
|
||||||
|
page_table = self._table.get(task_id)
|
||||||
|
full_pages = len(prompt_ids) // self.page_size
|
||||||
|
for i in range(start_logical_page, full_pages):
|
||||||
|
self._pool.record(page_table[i], prompt_ids, i)
|
||||||
|
|
||||||
|
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||||
|
return self._table.table_tensor(task_ids, device)
|
||||||
|
|
||||||
|
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
|
||||||
|
return KvcacheView(self._storage, page_table, total_len)
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.inference.core.cache import KVCache
|
||||||
|
from astrai.inference.core.task import Task
|
||||||
|
from astrai.inference.sample import sample
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Executor:
|
||||||
|
"""Model forward passes for prefill and decode phases."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: AutoModel,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
page_cache: KVCache,
|
||||||
|
device: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.page_cache = page_cache
|
||||||
|
self.device = device or next(model.parameters()).device
|
||||||
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
|
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
|
||||||
|
if start_pos >= prompt_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
|
input_ids = torch.tensor(
|
||||||
|
[t.prompt_ids[start_pos:prompt_len] for t in tasks],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
task_ids = [t.task_id for t in tasks]
|
||||||
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
self.model(
|
||||||
|
input_ids,
|
||||||
|
position_ids=torch.arange(
|
||||||
|
start_pos, prompt_len, dtype=torch.long, device=self.device
|
||||||
|
)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(batch_sz, -1),
|
||||||
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute_decode(self, tasks: List[Task]) -> List[int]:
|
||||||
|
if not tasks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
input_ids = torch.tensor(
|
||||||
|
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
position_ids = torch.tensor(
|
||||||
|
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
|
||||||
|
)
|
||||||
|
total_len = position_ids.max().item() + 1
|
||||||
|
|
||||||
|
task_ids = [t.task_id for t in tasks]
|
||||||
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
|
|
||||||
|
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
||||||
|
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
|
||||||
|
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids.unsqueeze(1),
|
||||||
|
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||||
|
position_ids=position_ids.unsqueeze(1),
|
||||||
|
)
|
||||||
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
|
||||||
|
return sample(
|
||||||
|
logits,
|
||||||
|
temperature=temperatures,
|
||||||
|
top_k=top_ks,
|
||||||
|
top_p=top_ps,
|
||||||
|
).tolist()
|
||||||
|
|
@ -0,0 +1,212 @@
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.inference.core.cache import KVCache
|
||||||
|
from astrai.inference.core.executor import Executor
|
||||||
|
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceScheduler:
|
||||||
|
"""Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: AutoModel,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
max_batch_size: int = 16,
|
||||||
|
max_seq_len: Optional[int] = None,
|
||||||
|
max_prompt_len: int = 2048,
|
||||||
|
page_size: int = 64,
|
||||||
|
device: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
config = model.config
|
||||||
|
|
||||||
|
if max_seq_len is not None:
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
elif config.max_len is not None:
|
||||||
|
self.max_seq_len = config.max_len
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"max_seq_len must be provided either as argument "
|
||||||
|
"or in model config (config.max_len)"
|
||||||
|
)
|
||||||
|
self.device = device or next(model.parameters()).device
|
||||||
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
|
n_pages = (
|
||||||
|
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||||
|
) // page_size
|
||||||
|
|
||||||
|
self._page_cache = KVCache(
|
||||||
|
config.n_layers,
|
||||||
|
n_pages,
|
||||||
|
page_size,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.dim // config.n_heads,
|
||||||
|
self.device,
|
||||||
|
self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._task_mgr = TaskManager(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
max_seq_len=self.max_seq_len,
|
||||||
|
max_prompt_len=max_prompt_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._executor = Executor(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
page_cache=self._page_cache,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._running = False
|
||||||
|
self._fatal_error: Optional[Exception] = None
|
||||||
|
|
||||||
|
def add_task(self, prompt: str, **kwargs) -> str:
|
||||||
|
return self._task_mgr.add_task(prompt, **kwargs)
|
||||||
|
|
||||||
|
def remove_task(self, task_id: str):
|
||||||
|
for task in self._task_mgr.remove_task(task_id):
|
||||||
|
self._page_cache.task_free(task.task_id)
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
return self._task_mgr.get_stats()
|
||||||
|
|
||||||
|
def _run_generation_loop(self):
|
||||||
|
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||||
|
try:
|
||||||
|
while self._running:
|
||||||
|
finished = self._task_mgr.remove_finished_tasks(stop_ids)
|
||||||
|
for task in finished:
|
||||||
|
self._page_cache.task_free(task.task_id)
|
||||||
|
|
||||||
|
active = self._task_mgr.get_active_tasks()
|
||||||
|
available = self._task_mgr.max_batch_size - len(active)
|
||||||
|
if available > 0:
|
||||||
|
candidates = self._task_mgr.pull_candidates(available)
|
||||||
|
failed = []
|
||||||
|
for task in candidates:
|
||||||
|
if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
|
||||||
|
self._task_mgr.activate(task)
|
||||||
|
else:
|
||||||
|
failed.append(task)
|
||||||
|
if failed:
|
||||||
|
self._task_mgr.return_to_waiting(failed)
|
||||||
|
|
||||||
|
if not self._task_mgr.has_work():
|
||||||
|
self._task_mgr.wait_for_tasks(timeout=1.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
to_prefill = [
|
||||||
|
t
|
||||||
|
for t in self._task_mgr.get_active_tasks()
|
||||||
|
if t.output_tokens == 0
|
||||||
|
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
||||||
|
]
|
||||||
|
if to_prefill:
|
||||||
|
for t in to_prefill:
|
||||||
|
t.input_tokens = len(t.prompt_ids)
|
||||||
|
|
||||||
|
groups: Dict[Tuple[int, int], List[Task]] = {}
|
||||||
|
for t in to_prefill:
|
||||||
|
key = (
|
||||||
|
len(t.prompt_ids),
|
||||||
|
self._page_cache.task_cached(t.task_id),
|
||||||
|
)
|
||||||
|
groups.setdefault(key, []).append(t)
|
||||||
|
|
||||||
|
for (prompt_len, start_pos), group in groups.items():
|
||||||
|
self._executor.execute_prefill(group, prompt_len, start_pos)
|
||||||
|
start_logical_page = start_pos // self._page_cache.page_size
|
||||||
|
for t in group:
|
||||||
|
self._page_cache.task_record_hashes(
|
||||||
|
t.task_id,
|
||||||
|
t.prompt_ids,
|
||||||
|
start_logical_page=start_logical_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_groups: Dict[int, List[Task]] = {}
|
||||||
|
for t in self._task_mgr.get_active_tasks():
|
||||||
|
pos_groups.setdefault(t.next_pos, []).append(t)
|
||||||
|
|
||||||
|
if pos_groups:
|
||||||
|
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
|
||||||
|
group = sorted(pos_groups[best_key], key=lambda t: t.task_id)
|
||||||
|
|
||||||
|
valid: List[Task] = []
|
||||||
|
for t in group:
|
||||||
|
if self._page_cache.task_extend(t.task_id, t.next_pos):
|
||||||
|
valid.append(t)
|
||||||
|
else:
|
||||||
|
t.status = TaskStatus.ABORTED
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
|
if valid:
|
||||||
|
next_tokens = self._executor.execute_decode(valid)
|
||||||
|
|
||||||
|
for t, ntok in zip(valid, next_tokens):
|
||||||
|
t.output_ids.append(ntok)
|
||||||
|
t.output_tokens += 1
|
||||||
|
pos = t.input_tokens + t.output_tokens
|
||||||
|
extend_ok = self._page_cache.task_extend(t.task_id, pos)
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(
|
||||||
|
self._task_mgr.tokenizer.decode([ntok])
|
||||||
|
)
|
||||||
|
if not extend_ok:
|
||||||
|
t.status = TaskStatus.ABORTED
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
|
for t in valid:
|
||||||
|
if t.is_finished(stop_ids):
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._fatal_error = e
|
||||||
|
self._running = False
|
||||||
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
|
for task in self._task_mgr.get_active_tasks():
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
|
self._page_cache.task_free(task.task_id)
|
||||||
|
for task in self._task_mgr.get_waiting_tasks():
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
|
self._task_mgr.clear_queues()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
if not self._running:
|
||||||
|
self._running = True
|
||||||
|
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||||
|
t.start()
|
||||||
|
self._loop_thread = t
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._running = False
|
||||||
|
self._task_mgr.wake()
|
||||||
|
if hasattr(self, "_loop_thread"):
|
||||||
|
self._loop_thread.join(timeout=2.0)
|
||||||
|
for task in self._task_mgr.get_active_tasks():
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
|
self._page_cache.task_free(task.task_id)
|
||||||
|
for task in self._task_mgr.get_waiting_tasks():
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
|
self._task_mgr.clear_queues()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
@ -0,0 +1,209 @@
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections import deque
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Deque, Dict, List, Optional
|
||||||
|
|
||||||
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STOP = object()
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
|
"""Task lifecycle states."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
FINISHED = "finished"
|
||||||
|
ABORTED = "aborted"
|
||||||
|
|
||||||
|
|
||||||
|
class Task:
|
||||||
|
"""Single generation request: prompt, sampling params, output state."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
prompt_ids: List[int],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
stream_callback: Optional[Callable[[str], None]] = None,
|
||||||
|
):
|
||||||
|
self.task_id = task_id
|
||||||
|
self.prompt_ids = prompt_ids
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
|
self.status = TaskStatus.PENDING
|
||||||
|
self.output_ids: List[int] = []
|
||||||
|
self.input_tokens: int = 0
|
||||||
|
self.output_tokens: int = 0
|
||||||
|
self.arrival_time = time.time()
|
||||||
|
self.finish_time: Optional[float] = None
|
||||||
|
self.stream_callback = stream_callback
|
||||||
|
|
||||||
|
@property
|
||||||
|
def next_pos(self) -> int:
|
||||||
|
return self.input_tokens + len(self.output_ids)
|
||||||
|
|
||||||
|
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||||
|
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
|
||||||
|
return True
|
||||||
|
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TaskManager:
|
||||||
|
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
max_batch_size: int = 16,
|
||||||
|
max_seq_len: int = 8192,
|
||||||
|
max_prompt_len: int = 512,
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_batch_size = max_batch_size
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.max_prompt_len = max_prompt_len
|
||||||
|
|
||||||
|
self.waiting_queue: Deque[Task] = deque()
|
||||||
|
self.active_tasks: List[Task] = []
|
||||||
|
|
||||||
|
self._task_event = threading.Event()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
self._total_tasks = 0
|
||||||
|
self._total_tokens = 0
|
||||||
|
|
||||||
|
def add_task(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
stream_callback: Optional[Callable[[str], None]] = None,
|
||||||
|
) -> str:
|
||||||
|
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
|
prompt_ids = self.tokenizer.encode(prompt)
|
||||||
|
if len(prompt_ids) > self.max_prompt_len:
|
||||||
|
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
||||||
|
|
||||||
|
if len(prompt_ids) >= self.max_seq_len:
|
||||||
|
if stream_callback:
|
||||||
|
stream_callback(STOP)
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
if max_tokens is None:
|
||||||
|
max_tokens = self.max_seq_len - len(prompt_ids)
|
||||||
|
else:
|
||||||
|
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
task_id=task_id,
|
||||||
|
prompt_ids=prompt_ids,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
stream_callback=stream_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self.waiting_queue.append(task)
|
||||||
|
self._total_tasks += 1
|
||||||
|
|
||||||
|
self._task_event.set()
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
def remove_task(self, task_id: str) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
|
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
||||||
|
self.waiting_queue = deque(
|
||||||
|
t for t in self.waiting_queue if t.task_id != task_id
|
||||||
|
)
|
||||||
|
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||||
|
return removed_active
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"total_tasks": self._total_tasks,
|
||||||
|
"total_tokens": self._total_tokens,
|
||||||
|
"active_tasks": len(self.active_tasks),
|
||||||
|
"waiting_queue": len(self.waiting_queue),
|
||||||
|
}
|
||||||
|
|
||||||
|
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
|
finished = []
|
||||||
|
for task in self.active_tasks:
|
||||||
|
if task.status == TaskStatus.ABORTED:
|
||||||
|
task.finish_time = time.time()
|
||||||
|
finished.append(task)
|
||||||
|
elif task.is_finished(stop_ids):
|
||||||
|
task.status = TaskStatus.FINISHED
|
||||||
|
task.finish_time = time.time()
|
||||||
|
finished.append(task)
|
||||||
|
self._total_tokens += task.output_tokens
|
||||||
|
|
||||||
|
self.active_tasks = [
|
||||||
|
t
|
||||||
|
for t in self.active_tasks
|
||||||
|
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
|
||||||
|
]
|
||||||
|
return finished
|
||||||
|
|
||||||
|
def pull_candidates(self, n: int) -> List[Task]:
|
||||||
|
to_add: List[Task] = []
|
||||||
|
with self._lock:
|
||||||
|
take = min(n, len(self.waiting_queue))
|
||||||
|
for _ in range(take):
|
||||||
|
to_add.append(self.waiting_queue.popleft())
|
||||||
|
return to_add
|
||||||
|
|
||||||
|
def activate(self, task: Task):
|
||||||
|
task.status = TaskStatus.RUNNING
|
||||||
|
with self._lock:
|
||||||
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
|
def return_to_waiting(self, tasks: List[Task]):
|
||||||
|
with self._lock:
|
||||||
|
for task in reversed(tasks):
|
||||||
|
self.waiting_queue.appendleft(task)
|
||||||
|
|
||||||
|
def has_work(self) -> bool:
|
||||||
|
return bool(self.active_tasks or self.waiting_queue)
|
||||||
|
|
||||||
|
def wait_for_tasks(self, timeout: float = 1.0):
|
||||||
|
with self._lock:
|
||||||
|
if self.waiting_queue or self.active_tasks:
|
||||||
|
return
|
||||||
|
self._task_event.clear()
|
||||||
|
self._task_event.wait(timeout=timeout)
|
||||||
|
|
||||||
|
def get_active_tasks(self) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
|
return list(self.active_tasks)
|
||||||
|
|
||||||
|
def get_waiting_tasks(self) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
|
return list(self.waiting_queue)
|
||||||
|
|
||||||
|
def clear_queues(self):
|
||||||
|
with self._lock:
|
||||||
|
self.waiting_queue.clear()
|
||||||
|
self.active_tasks.clear()
|
||||||
|
|
||||||
|
def wake(self):
|
||||||
|
self._task_event.set()
|
||||||
|
|
@ -0,0 +1,288 @@
|
||||||
|
"""Unified inference engine for continuous batching."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import gc
|
||||||
|
import threading
|
||||||
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from astrai.inference.core.scheduler import InferenceScheduler
|
||||||
|
from astrai.inference.core.task import STOP
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateResult:
|
||||||
|
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||||
|
|
||||||
|
def __init__(self, count: int = 1):
|
||||||
|
self._cond = threading.Condition()
|
||||||
|
self._event = threading.Event()
|
||||||
|
self.tokens: List[Tuple[int, str]] = []
|
||||||
|
self.results: List[str] = [""] * count
|
||||||
|
self._done: List[bool] = [False] * count
|
||||||
|
self._completed = 0
|
||||||
|
self._total = count
|
||||||
|
|
||||||
|
def append(self, token: str, idx: int = 0):
|
||||||
|
with self._cond:
|
||||||
|
self.tokens.append((idx, token))
|
||||||
|
if token is not STOP:
|
||||||
|
self.results[idx] += token
|
||||||
|
else:
|
||||||
|
if not self._done[idx]:
|
||||||
|
self._done[idx] = True
|
||||||
|
self._completed += 1
|
||||||
|
self._cond.notify_all()
|
||||||
|
self._event.set()
|
||||||
|
|
||||||
|
def pop_all(self) -> List[Tuple[int, str]]:
|
||||||
|
with self._cond:
|
||||||
|
out = self.tokens.copy()
|
||||||
|
self.tokens.clear()
|
||||||
|
if not out:
|
||||||
|
self._event.clear()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||||
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
|
def wait_completion(self, timeout: float = 300.0):
|
||||||
|
with self._cond:
|
||||||
|
if not self._cond.wait_for(
|
||||||
|
lambda: self._completed >= self._total, timeout=timeout
|
||||||
|
):
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Generation timeout after {timeout}s "
|
||||||
|
f"({self._completed}/{self._total} completed)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_results(self) -> List[str]:
|
||||||
|
with self._cond:
|
||||||
|
return self.results.copy()
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationRequest:
|
||||||
|
"""Request parameters for text generation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
top_k: int = 50,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
):
|
||||||
|
if not (isinstance(top_k, int) and top_k >= 0):
|
||||||
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
if not (0.0 <= top_p <= 1.0):
|
||||||
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
|
if not (isinstance(temperature, (int, float)) and temperature > 0):
|
||||||
|
raise ValueError("temperature must be a positive number")
|
||||||
|
|
||||||
|
self.messages = messages
|
||||||
|
self.top_k = top_k
|
||||||
|
self.top_p = top_p
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.stream = stream
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceEngine:
|
||||||
|
"""Unified inference engine backed by continuous-batching scheduler."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
max_batch_size: int = 1,
|
||||||
|
max_seq_len: Optional[int] = None,
|
||||||
|
max_prompt_len: int = 2048,
|
||||||
|
page_size: int = 128,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.scheduler = InferenceScheduler(
|
||||||
|
model=self.model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_prompt_len=max_prompt_len,
|
||||||
|
page_size=page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scheduler.start()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.shutdown()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
stream: bool = False,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
) -> Union[Generator, str, List[str]]:
|
||||||
|
is_batch = isinstance(prompt, list)
|
||||||
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._generate_streaming(
|
||||||
|
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._generate_non_streaming(
|
||||||
|
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_async(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
sync_gen = self._generate_streaming(
|
||||||
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _agen():
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
while True:
|
||||||
|
token = await loop.run_in_executor(None, self._next_token, sync_gen)
|
||||||
|
if token is None:
|
||||||
|
break
|
||||||
|
yield token
|
||||||
|
|
||||||
|
return _agen()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _next_token(gen: Generator) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
return next(gen)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def generate_with_request(
|
||||||
|
self, request: GenerationRequest
|
||||||
|
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||||
|
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||||
|
return self.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
stream=request.stream,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
temperature=request.temperature,
|
||||||
|
top_p=request.top_p,
|
||||||
|
top_k=request.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _submit_tasks(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
max_tokens: Optional[int],
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
|
) -> Tuple[GenerateResult, List[str]]:
|
||||||
|
n = len(prompts)
|
||||||
|
result = GenerateResult(count=n)
|
||||||
|
task_ids = []
|
||||||
|
for i, p in enumerate(prompts):
|
||||||
|
cb = self._make_callback(result, i)
|
||||||
|
task_id = self.scheduler.add_task(
|
||||||
|
prompt=p,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
stream_callback=cb,
|
||||||
|
)
|
||||||
|
task_ids.append(task_id)
|
||||||
|
return result, task_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_callback(result: GenerateResult, idx: int):
|
||||||
|
def cb(token):
|
||||||
|
result.append(token, idx)
|
||||||
|
|
||||||
|
return cb
|
||||||
|
|
||||||
|
def _generate_streaming(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
is_batch: bool,
|
||||||
|
max_tokens: Optional[int],
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
|
) -> Generator:
|
||||||
|
result, task_ids = self._submit_tasks(
|
||||||
|
prompts, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
n = len(prompts)
|
||||||
|
remaining = n
|
||||||
|
finished = [False] * n
|
||||||
|
|
||||||
|
def gen():
|
||||||
|
nonlocal remaining
|
||||||
|
try:
|
||||||
|
while remaining > 0:
|
||||||
|
items = result.pop_all()
|
||||||
|
for idx, token in items:
|
||||||
|
if token is STOP:
|
||||||
|
if not finished[idx]:
|
||||||
|
finished[idx] = True
|
||||||
|
remaining -= 1
|
||||||
|
else:
|
||||||
|
yield (idx, token) if is_batch else token
|
||||||
|
if remaining > 0:
|
||||||
|
result.wait(timeout=0.05)
|
||||||
|
finally:
|
||||||
|
for tid in task_ids:
|
||||||
|
self.scheduler.remove_task(tid)
|
||||||
|
|
||||||
|
return gen()
|
||||||
|
|
||||||
|
def _generate_non_streaming(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
is_batch: bool,
|
||||||
|
max_tokens: Optional[int],
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
|
) -> Union[str, List[str]]:
|
||||||
|
result, task_ids = self._submit_tasks(
|
||||||
|
prompts, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result.wait_completion()
|
||||||
|
except TimeoutError:
|
||||||
|
for tid in task_ids:
|
||||||
|
self.scheduler.remove_task(tid)
|
||||||
|
raise
|
||||||
|
|
||||||
|
for tid in task_ids:
|
||||||
|
self.scheduler.remove_task(tid)
|
||||||
|
|
||||||
|
res = result.get_results()
|
||||||
|
return res if is_batch else res[0]
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
return self.scheduler.get_stats()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self.scheduler.stop()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
@ -0,0 +1,190 @@
|
||||||
|
"""Composable sampling strategies for logit transformation.
|
||||||
|
|
||||||
|
Implements the Strategy pattern: each sampling technique
|
||||||
|
(temperature, top-k, top-p) is a pluggable strategy that
|
||||||
|
can be composed into a pipeline.
|
||||||
|
|
||||||
|
All strategies accept both scalar and per-sample tensor
|
||||||
|
parameters, so a single pipeline works for any batch size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSamplingStrategy(ABC):
|
||||||
|
"""Abstract base for a logit transformation strategy."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
|
"""Applies the strategy to logits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Raw logits tensor (batch, vocab_size).
|
||||||
|
filter_value: Value assigned to filtered-out positions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed logits tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TemperatureStrategy(BaseSamplingStrategy):
|
||||||
|
"""Divides logits by temperature to control randomness.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temperature: Scalar or ``[batch]`` tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, temperature: Union[float, Tensor] = 1.0):
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
t = self.temperature
|
||||||
|
if isinstance(t, Tensor):
|
||||||
|
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||||
|
t = torch.clamp(t, min=1e-8)
|
||||||
|
if (t != 1.0).any():
|
||||||
|
logits = logits / t
|
||||||
|
elif t != 1.0:
|
||||||
|
logits = logits / max(t, 1e-8)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class TopKStrategy(BaseSamplingStrategy):
|
||||||
|
"""Keeps only the top-k logits, setting the rest to filter_value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_k: Scalar or ``[batch]`` tensor (0 disables).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, top_k: Union[int, Tensor] = 0):
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
tk = self.top_k
|
||||||
|
if isinstance(tk, Tensor):
|
||||||
|
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
|
||||||
|
max_k = int(tk.max().item())
|
||||||
|
if max_k <= 0:
|
||||||
|
return logits
|
||||||
|
max_k = min(max_k, logits.size(-1))
|
||||||
|
values, _ = torch.topk(logits, max_k, dim=-1)
|
||||||
|
per_row_k = tk.clamp(max=max_k)
|
||||||
|
thresholds = torch.full_like(logits[..., -1:], -float("inf"))
|
||||||
|
positive = per_row_k > 0
|
||||||
|
if positive.any():
|
||||||
|
row_idx = torch.arange(logits.size(0), device=logits.device)[positive]
|
||||||
|
thresholds[positive] = values[
|
||||||
|
row_idx, per_row_k[positive] - 1
|
||||||
|
].unsqueeze(-1)
|
||||||
|
logits[logits < thresholds] = filter_value
|
||||||
|
return logits
|
||||||
|
if tk > 0:
|
||||||
|
k = min(tk, logits.size(-1))
|
||||||
|
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
||||||
|
logits[logits < thresholds] = filter_value
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class TopPStrategy(BaseSamplingStrategy):
|
||||||
|
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
|
||||||
|
cumulative probability exceeds top_p.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, top_p: Union[float, Tensor] = 1.0):
|
||||||
|
self.top_p = top_p
|
||||||
|
|
||||||
|
def _apply(self, logits, top_p, filter_value):
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||||
|
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
remove = cum_probs > top_p
|
||||||
|
remove[..., 1:] = remove[..., :-1].clone()
|
||||||
|
remove[..., 0] = False
|
||||||
|
mask = torch.zeros_like(logits, dtype=torch.bool)
|
||||||
|
mask.scatter_(1, sorted_indices, remove)
|
||||||
|
logits[mask] = filter_value
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
tp = self.top_p
|
||||||
|
if isinstance(tp, Tensor):
|
||||||
|
tp = tp.to(logits.device, non_blocking=True)
|
||||||
|
if (tp < 1.0).any():
|
||||||
|
logits = self._apply(logits, tp.view(-1, 1), filter_value)
|
||||||
|
elif tp < 1.0:
|
||||||
|
logits = self._apply(logits, tp, filter_value)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingPipeline(BaseSamplingStrategy):
|
||||||
|
"""Composes multiple sampling strategies into a single transformation.
|
||||||
|
|
||||||
|
Strategies are applied sequentially in the order they are provided,
|
||||||
|
matching the original temperature -> top-k -> top-p ordering.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
pipeline = SamplingPipeline([
|
||||||
|
TemperatureStrategy(0.8),
|
||||||
|
TopKStrategy(50),
|
||||||
|
TopPStrategy(0.95),
|
||||||
|
])
|
||||||
|
logits = pipeline.apply(logits)
|
||||||
|
token = pipeline.sample(logits) # softmax + multinomial
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
||||||
|
self.strategies = strategies
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
for strategy in self.strategies:
|
||||||
|
logits = strategy.apply(logits, filter_value)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
|
"""Apply strategies then sample (softmax + multinomial).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Raw logits ``[batch, vocab_size]``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sampled token IDs ``[batch]``.
|
||||||
|
"""
|
||||||
|
return torch.multinomial(
|
||||||
|
torch.softmax(self.apply(logits, filter_value), dim=-1),
|
||||||
|
num_samples=1,
|
||||||
|
).squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def sample(
|
||||||
|
logits: Tensor,
|
||||||
|
temperature: Union[float, Tensor] = 1.0,
|
||||||
|
top_k: Union[int, Tensor] = 0,
|
||||||
|
top_p: Union[float, Tensor] = 1.0,
|
||||||
|
filter_value: float = -float("inf"),
|
||||||
|
) -> Tensor:
|
||||||
|
"""Apply sampling strategies then sample (softmax + multinomial).
|
||||||
|
|
||||||
|
Shortcut for ``SamplingPipeline(...).sample(logits)``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Raw logits ``[batch, vocab_size]``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sampled token IDs ``[batch]``.
|
||||||
|
"""
|
||||||
|
return SamplingPipeline(
|
||||||
|
[
|
||||||
|
TemperatureStrategy(temperature),
|
||||||
|
TopKStrategy(top_k),
|
||||||
|
TopPStrategy(top_p),
|
||||||
|
]
|
||||||
|
).sample(logits, filter_value)
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.model.components.attention import GQA
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.lora import (
|
||||||
|
LoRAConfig,
|
||||||
|
inject_lora,
|
||||||
|
load_lora,
|
||||||
|
merge_lora,
|
||||||
|
save_lora,
|
||||||
|
)
|
||||||
|
from astrai.model.components.mlp import MLP
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.encoder import EmbeddingEncoder
|
||||||
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Modules
|
||||||
|
"Linear",
|
||||||
|
"RMSNorm",
|
||||||
|
"MLP",
|
||||||
|
"GQA",
|
||||||
|
"DecoderBlock",
|
||||||
|
# Models
|
||||||
|
"AutoRegressiveLM",
|
||||||
|
"EmbeddingEncoder",
|
||||||
|
"AutoModel",
|
||||||
|
# LoRA
|
||||||
|
"LoRAConfig",
|
||||||
|
"inject_lora",
|
||||||
|
"merge_lora",
|
||||||
|
"save_lora",
|
||||||
|
"load_lora",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""
|
||||||
|
AutoModel base class for model loading and saving.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Self, Union
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.serialization import load_model_config, load_model_weights, save_model
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _disable_random_init(enable: bool = True):
|
||||||
|
if not enable:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
names = (
|
||||||
|
"xavier_normal_",
|
||||||
|
"xavier_uniform_",
|
||||||
|
"kaiming_normal_",
|
||||||
|
"kaiming_uniform_",
|
||||||
|
"zeros_",
|
||||||
|
"ones_",
|
||||||
|
"constant_",
|
||||||
|
"normal_",
|
||||||
|
"uniform_",
|
||||||
|
)
|
||||||
|
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
||||||
|
for n in orig:
|
||||||
|
setattr(nn.init, n, lambda *a, **kw: None)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for n, fn in orig.items():
|
||||||
|
setattr(nn.init, n, fn)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
|
"""
|
||||||
|
Autoregressive language model base class.
|
||||||
|
Provides model loading/saving, registration, and generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: BaseModelConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
path: Union[str, Path],
|
||||||
|
disable_random_init: bool = True,
|
||||||
|
strict: bool = True,
|
||||||
|
) -> nn.Module:
|
||||||
|
|
||||||
|
model_path = Path(path)
|
||||||
|
|
||||||
|
config_path = model_path / "config.json"
|
||||||
|
if not config_path.exists():
|
||||||
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
|
raw = load_model_config(str(model_path))
|
||||||
|
config = ConfigFactory.load(raw)
|
||||||
|
model_type = config.model_type or "autoregressive_lm"
|
||||||
|
|
||||||
|
actual_cls = AutoModel.get_component_class(model_type)
|
||||||
|
|
||||||
|
with _disable_random_init(enable=disable_random_init):
|
||||||
|
model = actual_cls(config)
|
||||||
|
|
||||||
|
weights_path = model_path / "model.safetensors"
|
||||||
|
if weights_path.exists():
|
||||||
|
state_dict = load_model_weights(str(model_path))
|
||||||
|
model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, Path],
|
||||||
|
):
|
||||||
|
save_model(
|
||||||
|
config=self.config.to_dict(),
|
||||||
|
state_dict=self.state_dict(),
|
||||||
|
save_directory=str(save_directory),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs) -> Self:
|
||||||
|
"""Move model to device/dtype."""
|
||||||
|
return super().to(*args, **kwargs)
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
from astrai.model.components.attention import GQA, MLA, repeat_kv
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.embedding import Embedding
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.mlp import MLP
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import (
|
||||||
|
RotaryEmbedding,
|
||||||
|
apply_rotary_emb,
|
||||||
|
get_rotary_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Linear",
|
||||||
|
"RMSNorm",
|
||||||
|
"MLP",
|
||||||
|
"Embedding",
|
||||||
|
"GQA",
|
||||||
|
"MLA",
|
||||||
|
"DecoderBlock",
|
||||||
|
"RotaryEmbedding",
|
||||||
|
"apply_rotary_emb",
|
||||||
|
"get_rotary_emb",
|
||||||
|
"repeat_kv",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,212 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import apply_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
|
bs, slen, n_heads, head_dim = x.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return x
|
||||||
|
return (
|
||||||
|
x[:, :, :, None, :]
|
||||||
|
.expand(bs, slen, n_heads, n_rep, head_dim)
|
||||||
|
.reshape(bs, slen, n_heads * n_rep, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnFactory(BaseFactory[nn.Module]):
|
||||||
|
@classmethod
|
||||||
|
def create(cls, attn_type: str, **kwargs) -> nn.Module:
|
||||||
|
return super().create(attn_type, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@AttnFactory.register("gqa")
|
||||||
|
class GQA(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
use_qk_norm: bool,
|
||||||
|
norm_eps: float,
|
||||||
|
use_gated_attention: bool,
|
||||||
|
layer_id: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert dim % n_heads == 0
|
||||||
|
assert n_heads % n_kv_heads == 0
|
||||||
|
|
||||||
|
self.head_dim = dim // n_heads
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.dim = dim
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_kv_heads = n_kv_heads
|
||||||
|
self.n_rep = n_heads // n_kv_heads
|
||||||
|
self.use_qk_norm = use_qk_norm
|
||||||
|
self.use_gated_attention = use_gated_attention
|
||||||
|
|
||||||
|
self.q_proj = Linear(dim, n_heads * self.head_dim)
|
||||||
|
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||||
|
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||||
|
self.o_proj = Linear(dim, dim)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
|
||||||
|
if self.use_gated_attention:
|
||||||
|
self.gate = Linear(dim, dim)
|
||||||
|
|
||||||
|
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
|
||||||
|
batch_size, seq_len, _ = x.shape
|
||||||
|
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
rotary_emb: Tensor,
|
||||||
|
attn_mask: Tensor = None,
|
||||||
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
is_causal = attn_mask is None
|
||||||
|
|
||||||
|
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||||
|
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||||
|
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||||
|
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
|
if paged_cache is not None:
|
||||||
|
paged_cache.write(self.layer_id, k, v)
|
||||||
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
|
|
||||||
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
||||||
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
|
sdqa_out = (
|
||||||
|
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.contiguous()
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_gated_attention:
|
||||||
|
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||||
|
|
||||||
|
out = self.o_proj(sdqa_out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@AttnFactory.register("mla")
|
||||||
|
class MLA(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
kv_lora_rank: int,
|
||||||
|
qk_nope_head_dim: int,
|
||||||
|
qk_rope_head_dim: int,
|
||||||
|
norm_eps: float,
|
||||||
|
use_qk_norm: bool,
|
||||||
|
use_gated_attention: bool,
|
||||||
|
layer_id: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_kv_heads = n_kv_heads
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.n_rep = n_heads // n_kv_heads
|
||||||
|
self.use_qk_norm = use_qk_norm
|
||||||
|
self.use_gated_attention = use_gated_attention
|
||||||
|
|
||||||
|
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
|
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||||
|
|
||||||
|
self.kv_b_proj = Linear(
|
||||||
|
kv_lora_rank,
|
||||||
|
n_kv_heads * (2 * self.head_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = Linear(dim, dim, bias=False)
|
||||||
|
|
||||||
|
if use_gated_attention:
|
||||||
|
self.gate = Linear(dim, dim, bias=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
rotary_emb: Tensor,
|
||||||
|
attn_mask: Tensor = None,
|
||||||
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
bsz, seq_len, _ = x.size()
|
||||||
|
is_causal = attn_mask is None
|
||||||
|
|
||||||
|
q = self.q_proj(x)
|
||||||
|
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||||
|
|
||||||
|
kv_compressed = self.kv_a_proj(x)
|
||||||
|
kv_compressed = self.kv_norm(kv_compressed)
|
||||||
|
|
||||||
|
kv = self.kv_b_proj(kv_compressed)
|
||||||
|
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
|
||||||
|
|
||||||
|
k_nope, k_rope, v = torch.split(
|
||||||
|
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
q_nope, q_rope = (
|
||||||
|
q[..., : self.qk_nope_head_dim],
|
||||||
|
q[..., self.qk_nope_head_dim :],
|
||||||
|
)
|
||||||
|
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||||
|
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||||
|
|
||||||
|
q = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
|
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
if paged_cache is not None:
|
||||||
|
paged_cache.write(self.layer_id, k, v)
|
||||||
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
|
|
||||||
|
q = q.permute(0, 2, 1, 3)
|
||||||
|
k = k.permute(0, 2, 1, 3)
|
||||||
|
v = v.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
attn_out = F.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask, is_causal=is_causal
|
||||||
|
)
|
||||||
|
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||||
|
|
||||||
|
if self.use_gated_attention:
|
||||||
|
attn_out = attn_out * F.sigmoid(self.gate(x))
|
||||||
|
|
||||||
|
out = self.o_proj(attn_out)
|
||||||
|
return out
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.components.attention import AttnFactory
|
||||||
|
from astrai.model.components.mlp import FFNFactory
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
dim_ffn: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
norm_eps: float,
|
||||||
|
use_qk_norm: bool,
|
||||||
|
use_gated_attention: bool,
|
||||||
|
layer_id: int,
|
||||||
|
attn_type: str = "gqa",
|
||||||
|
ffn_type: str = "mlp",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = AttnFactory.create(
|
||||||
|
attn_type,
|
||||||
|
dim=dim,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
use_qk_norm=use_qk_norm,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
use_gated_attention=use_gated_attention,
|
||||||
|
layer_id=layer_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.input_norm = RMSNorm(dim, norm_eps)
|
||||||
|
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||||
|
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
rotary_emb: Tensor,
|
||||||
|
attention_mask: Optional[Tensor] = None,
|
||||||
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
attn_output = self.attention(
|
||||||
|
self.input_norm(x),
|
||||||
|
rotary_emb,
|
||||||
|
attention_mask,
|
||||||
|
paged_cache,
|
||||||
|
)
|
||||||
|
x = attn_output + x
|
||||||
|
x = self.mlp(self.post_attention_norm(x)) + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(nn.Module):
|
||||||
|
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.normal_(self.weight, mean=0.0, std=0.02)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.embedding(x, self.weight)
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
|
||||||
|
if self.bias is not None:
|
||||||
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||||
|
bound = 1 / (fan_in**0.5)
|
||||||
|
nn.init.uniform_(self.bias, -bound, bound)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
|
@ -0,0 +1,194 @@
|
||||||
|
import logging
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Set
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.serialization import (
|
||||||
|
load_json,
|
||||||
|
load_safetensors,
|
||||||
|
save_json,
|
||||||
|
save_safetensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
|
||||||
|
TARGET_MODULES_FFN = {"up", "gate", "down"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoRAConfig:
|
||||||
|
r: int = 16
|
||||||
|
alpha: int = 32
|
||||||
|
target_modules: tuple = ("q_proj", "v_proj")
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
|
||||||
|
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
|
||||||
|
super().__init__()
|
||||||
|
self.register_parameter("weight", base.weight)
|
||||||
|
self.weight.requires_grad_(False)
|
||||||
|
self.bias = base.bias
|
||||||
|
if self.bias is not None:
|
||||||
|
self.bias.requires_grad_(False)
|
||||||
|
|
||||||
|
self.r = r
|
||||||
|
self.scaling = alpha / r
|
||||||
|
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
|
||||||
|
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
|
||||||
|
self._merged = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = F.linear(x, self.weight, self.bias)
|
||||||
|
if not self._merged:
|
||||||
|
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
|
||||||
|
return out
|
||||||
|
|
||||||
|
def merge(self):
|
||||||
|
if self._merged:
|
||||||
|
return
|
||||||
|
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
||||||
|
self._merged = True
|
||||||
|
del self.lora_A
|
||||||
|
del self.lora_B
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_lora_info(model: nn.Module) -> dict:
|
||||||
|
names = {}
|
||||||
|
for n, m in model.named_modules():
|
||||||
|
if isinstance(m, Linear):
|
||||||
|
_, _, child = n.rpartition(".")
|
||||||
|
names.setdefault(child, []).append(n)
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lora_count(model: nn.Module) -> int:
|
||||||
|
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
|
||||||
|
|
||||||
|
|
||||||
|
def inject_lora(
|
||||||
|
model: nn.Module,
|
||||||
|
r: int = 16,
|
||||||
|
alpha: int = 32,
|
||||||
|
target_modules: Optional[Set[str]] = None,
|
||||||
|
) -> LoRAConfig:
|
||||||
|
if target_modules is None:
|
||||||
|
target_modules = TARGET_MODULES_ATTN
|
||||||
|
|
||||||
|
available = _collect_lora_info(model)
|
||||||
|
injected = 0
|
||||||
|
|
||||||
|
for name, module in list(model.named_modules()):
|
||||||
|
if not isinstance(module, Linear):
|
||||||
|
continue
|
||||||
|
parent_name, _, child_name = name.rpartition(".")
|
||||||
|
if child_name not in target_modules:
|
||||||
|
continue
|
||||||
|
parent = model.get_submodule(parent_name) if parent_name else model
|
||||||
|
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
|
||||||
|
injected += 1
|
||||||
|
|
||||||
|
if injected == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No LoRA layers injected. Available Linear child names: %s. "
|
||||||
|
"target_modules: %s. Check model type and target_modules.",
|
||||||
|
sorted(available),
|
||||||
|
sorted(target_modules),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
|
||||||
|
|
||||||
|
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora(model: nn.Module):
|
||||||
|
n = 0
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, LoRALinear):
|
||||||
|
module.merge()
|
||||||
|
n += 1
|
||||||
|
if n == 0:
|
||||||
|
logger.warning("No LoRA layers to merge.")
|
||||||
|
else:
|
||||||
|
logger.info("Merged %d LoRA layers", n)
|
||||||
|
|
||||||
|
|
||||||
|
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
||||||
|
lora_sd = {
|
||||||
|
k: v
|
||||||
|
for k, v in model.state_dict().items()
|
||||||
|
if k.endswith((".lora_A", ".lora_B"))
|
||||||
|
}
|
||||||
|
if not lora_sd:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No LoRA parameters found in model. "
|
||||||
|
"The model may not have been injected or was already merged."
|
||||||
|
)
|
||||||
|
|
||||||
|
path = Path(save_dir)
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_safetensors(lora_sd, path / "adapter_model.safetensors")
|
||||||
|
save_json(asdict(config), path / "adapter_config.json")
|
||||||
|
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||||
|
path = Path(load_dir)
|
||||||
|
raw = load_json(path / "adapter_config.json")
|
||||||
|
config = LoRAConfig(
|
||||||
|
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
||||||
|
)
|
||||||
|
|
||||||
|
existing = _get_lora_count(model)
|
||||||
|
if existing > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Model already has %d LoRA layers. Skipping injection, "
|
||||||
|
"loading weights onto existing layers only.",
|
||||||
|
existing,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inject_lora(
|
||||||
|
model,
|
||||||
|
r=config.r,
|
||||||
|
alpha=config.alpha,
|
||||||
|
target_modules=set(config.target_modules),
|
||||||
|
)
|
||||||
|
|
||||||
|
weights = load_safetensors(path / "adapter_model.safetensors")
|
||||||
|
try:
|
||||||
|
missing, unexpected = model.load_state_dict(weights, strict=False)
|
||||||
|
except RuntimeError as e:
|
||||||
|
msg = str(e)
|
||||||
|
if "size mismatch" in msg:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"LoRA weight shapes do not match the model. "
|
||||||
|
f"The adapter config (r={config.r}) may not match the injected layers. "
|
||||||
|
f"Original error: {msg}"
|
||||||
|
) from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
injected = _get_lora_count(model)
|
||||||
|
if injected == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No LoRA layers found after loading. "
|
||||||
|
"Inject LoRA before calling load_lora, or check the adapter config."
|
||||||
|
)
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
lora_missing = [k for k in missing if "lora" in k]
|
||||||
|
if lora_missing:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"LoRA weight keys not found in model: {lora_missing}. "
|
||||||
|
f"The adapter config (r={config.r}) may not match the model."
|
||||||
|
)
|
||||||
|
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
|
||||||
|
if unexpected:
|
||||||
|
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
|
||||||
|
|
||||||
|
logger.info("LoRA adapter loaded from %s", load_dir)
|
||||||
|
return config
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
|
||||||
|
|
||||||
|
class FFNFactory(BaseFactory[nn.Module]):
|
||||||
|
@classmethod
|
||||||
|
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
|
||||||
|
return super().create(ffn_type, dim, dim_ffn, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@FFNFactory.register("mlp")
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, dim: int, dim_ffn: int):
|
||||||
|
super().__init__()
|
||||||
|
self.up = Linear(dim, dim_ffn)
|
||||||
|
self.gate = Linear(dim, dim_ffn)
|
||||||
|
self.down = Linear(dim_ffn, dim)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
gated = self.up(x) * F.silu(self.gate(x))
|
||||||
|
out = self.down(gated)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@FFNFactory.register("moe")
|
||||||
|
class DeepSeekMoE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
dim_ffn: int,
|
||||||
|
n_routed_experts: int,
|
||||||
|
n_shared_experts: int = 1,
|
||||||
|
n_activated_experts: int = 2,
|
||||||
|
topk_method: str = "greedy",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.n_routed_experts = n_routed_experts
|
||||||
|
self.n_shared_experts = n_shared_experts
|
||||||
|
self.n_activated_experts = n_activated_experts
|
||||||
|
self.topk_method = topk_method
|
||||||
|
|
||||||
|
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||||
|
|
||||||
|
self.shared_experts = nn.ModuleList(
|
||||||
|
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
|
||||||
|
)
|
||||||
|
self.routed_experts = nn.ModuleList(
|
||||||
|
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
bsz, seq_len, dim = x.shape
|
||||||
|
x_flat = x.view(-1, dim)
|
||||||
|
|
||||||
|
shared_out = self._shared_forward(x_flat)
|
||||||
|
routed_out = self._routed_forward(x_flat)
|
||||||
|
|
||||||
|
out = (shared_out + routed_out).view(bsz, seq_len, dim)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _shared_forward(self, x: Tensor) -> Tensor:
|
||||||
|
if self.n_shared_experts == 0:
|
||||||
|
return torch.zeros_like(x)
|
||||||
|
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
|
||||||
|
|
||||||
|
def _routed_forward(self, x: Tensor) -> Tensor:
|
||||||
|
N, D = x.shape
|
||||||
|
K = self.n_activated_experts
|
||||||
|
|
||||||
|
router_logits = self.router(x)
|
||||||
|
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
|
||||||
|
|
||||||
|
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
|
||||||
|
for expert_idx in range(self.n_routed_experts):
|
||||||
|
expert_mask = topk_indices == expert_idx
|
||||||
|
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
|
||||||
|
if token_idx.numel() == 0:
|
||||||
|
continue
|
||||||
|
expert_input = x[token_idx]
|
||||||
|
expert_output = self.routed_experts[expert_idx](expert_input)
|
||||||
|
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
|
||||||
|
output.index_add_(0, token_idx, expert_output * weights)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim, norm_eps):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
self.normalized_shape = (dim,)
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotary_emb(
|
||||||
|
dim: int,
|
||||||
|
max_len: int,
|
||||||
|
base: float = 10000,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||||
|
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||||
|
freqs = torch.outer(t, theta).float()
|
||||||
|
cos = torch.cos(freqs)
|
||||||
|
sin = torch.sin(freqs)
|
||||||
|
return torch.complex(cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
def ntk_base(base: float, dim: int, factor: float) -> float:
|
||||||
|
return base * (factor ** (dim / (dim - 2)))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
dtype = x.dtype
|
||||||
|
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||||
|
x_complex = torch.view_as_complex(x_)
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
|
x_rotated = x_complex * freqs_cis
|
||||||
|
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||||
|
return x_out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
max_len: int,
|
||||||
|
base: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.max_len = max_len
|
||||||
|
self.base = base
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
|
||||||
|
if rope_scaling is not None:
|
||||||
|
scaling_type = rope_scaling.get("type", "ntk")
|
||||||
|
factor = rope_scaling.get("factor", 1.0)
|
||||||
|
if scaling_type == "ntk":
|
||||||
|
self.base = ntk_base(base, dim, factor)
|
||||||
|
|
||||||
|
self._set_rotary_buffer(self.max_len)
|
||||||
|
|
||||||
|
def _set_rotary_buffer(self, max_len: int):
|
||||||
|
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||||
|
freqs_cis = torch.view_as_real(rotary_emb)
|
||||||
|
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = (
|
||||||
|
torch.arange(x.size(1), device=x.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(x.size(0), -1)
|
||||||
|
)
|
||||||
|
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||||
|
return torch.view_as_complex(position_freq_cis)
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.config.model_config import EncoderConfig
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.embedding import Embedding
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import RotaryEmbedding
|
||||||
|
from astrai.model.transformer import process_attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
@AutoModel.register("embedding")
|
||||||
|
class EmbeddingEncoder(AutoModel):
|
||||||
|
def __init__(self, config: EncoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
rope_dim = config.dim // config.n_heads
|
||||||
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
|
self.rotary_embedding = RotaryEmbedding(
|
||||||
|
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||||
|
)
|
||||||
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DecoderBlock(
|
||||||
|
config.dim,
|
||||||
|
config.n_heads,
|
||||||
|
config.dim_ffn,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.norm_eps,
|
||||||
|
config.use_qk_norm,
|
||||||
|
config.use_gated_attention,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
|
|
||||||
|
self.pooling_type = config.pooling_type or "mean"
|
||||||
|
self.normalize_embeddings = config.normalize_embeddings or False
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if hasattr(module, "reset_parameters"):
|
||||||
|
module.reset_parameters()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
|
state_dict = dict(state_dict)
|
||||||
|
state_dict.pop("lm_head.weight", None)
|
||||||
|
return super().load_state_dict(state_dict, strict=strict, assign=assign)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Tensor,
|
||||||
|
input_mask: Optional[Tensor] = None,
|
||||||
|
position_ids: Optional[Tensor] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
assert input_ids.ndim == 2
|
||||||
|
B, S = input_ids.shape
|
||||||
|
|
||||||
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||||
|
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
|
||||||
|
|
||||||
|
hidden_states = self.norm(x)
|
||||||
|
|
||||||
|
if self.pooling_type == "cls":
|
||||||
|
pooled = hidden_states[:, 0]
|
||||||
|
elif self.pooling_type == "last":
|
||||||
|
if input_mask is not None:
|
||||||
|
lengths = input_mask.sum(dim=1) - 1
|
||||||
|
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
|
||||||
|
else:
|
||||||
|
pooled = hidden_states[:, -1]
|
||||||
|
else:
|
||||||
|
if input_mask is not None:
|
||||||
|
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
||||||
|
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
|
||||||
|
min=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pooled = hidden_states.mean(dim=1)
|
||||||
|
|
||||||
|
if self.normalize_embeddings:
|
||||||
|
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
|
||||||
|
|
||||||
|
return pooled
|
||||||
|
|
@ -0,0 +1,152 @@
|
||||||
|
from typing import Any, Dict, Mapping, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.embedding import Embedding
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import RotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
|
def process_attention_mask(
|
||||||
|
input_tensor: Tensor,
|
||||||
|
position_ids: Optional[Tensor],
|
||||||
|
input_mask: Optional[Tensor] = None,
|
||||||
|
is_causal: bool = False,
|
||||||
|
) -> Optional[Tensor]:
|
||||||
|
if position_ids is None:
|
||||||
|
return None
|
||||||
|
if input_mask is not None and input_mask.dim() > 2:
|
||||||
|
return input_mask
|
||||||
|
|
||||||
|
device = input_tensor.device
|
||||||
|
dtype = input_tensor.dtype
|
||||||
|
B, S = input_tensor.size()[:2]
|
||||||
|
T = position_ids.max().item() + 1
|
||||||
|
|
||||||
|
if input_mask is None:
|
||||||
|
if position_ids.min().item() == 0 and is_causal:
|
||||||
|
return None
|
||||||
|
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
||||||
|
else:
|
||||||
|
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
||||||
|
|
||||||
|
attend = pad.view(B, 1, T).expand(B, S, T).clone()
|
||||||
|
if is_causal:
|
||||||
|
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||||
|
|
||||||
|
return torch.full(
|
||||||
|
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
||||||
|
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@AutoModel.register("autoregressive_lm")
|
||||||
|
class AutoRegressiveLM(AutoModel):
|
||||||
|
"""Autoregressive language model with paged KV cache."""
|
||||||
|
|
||||||
|
def __init__(self, config: AutoRegressiveLMConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
rope_dim = (
|
||||||
|
config.qk_rope_head_dim
|
||||||
|
if config.attn_type == "mla"
|
||||||
|
else config.dim // config.n_heads
|
||||||
|
)
|
||||||
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
|
self.rotary_embedding = RotaryEmbedding(
|
||||||
|
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||||
|
)
|
||||||
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DecoderBlock(
|
||||||
|
config.dim,
|
||||||
|
config.n_heads,
|
||||||
|
config.dim_ffn,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.norm_eps,
|
||||||
|
config.use_qk_norm,
|
||||||
|
config.use_gated_attention,
|
||||||
|
layer_id,
|
||||||
|
attn_type=config.attn_type,
|
||||||
|
ffn_type=config.ffn_type,
|
||||||
|
n_routed_experts=config.n_routed_experts,
|
||||||
|
n_shared_experts=config.n_shared_experts,
|
||||||
|
n_activated_experts=config.n_activated_experts,
|
||||||
|
topk_method=config.topk_method,
|
||||||
|
kv_lora_rank=config.kv_lora_rank,
|
||||||
|
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||||
|
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
|
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||||
|
|
||||||
|
if self.config.tie_weight is True:
|
||||||
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if hasattr(module, "reset_parameters"):
|
||||||
|
module.reset_parameters()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
|
lm_head_key = "lm_head.weight"
|
||||||
|
embed_key = "embed_tokens.weight"
|
||||||
|
|
||||||
|
state_dict = dict(state_dict)
|
||||||
|
|
||||||
|
if self.config.tie_weight is True:
|
||||||
|
# same tensor for embed and lm_head
|
||||||
|
if embed_key in state_dict:
|
||||||
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
|
else:
|
||||||
|
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||||
|
# clone to avoid sharing gradients
|
||||||
|
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||||
|
|
||||||
|
return super().load_state_dict(state_dict, strict, assign)
|
||||||
|
|
||||||
|
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||||
|
state_dict = super().state_dict(
|
||||||
|
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.config.tie_weight is True:
|
||||||
|
lm_head_key = prefix + "lm_head.weight"
|
||||||
|
if lm_head_key in state_dict:
|
||||||
|
del state_dict[lm_head_key]
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Tensor,
|
||||||
|
input_mask: Optional[Tensor] = None,
|
||||||
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
|
position_ids: Optional[Tensor] = None,
|
||||||
|
) -> Dict[str, Tensor]:
|
||||||
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
|
x = self.embed_tokens(input_ids)
|
||||||
|
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||||
|
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, rotary_emb, attn_mask, paged_cache)
|
||||||
|
|
||||||
|
hidden_states = self.norm(x)
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
return {"logits": logits, "hidden_states": hidden_states}
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
from astrai.parallel.executor import (
|
||||||
|
AccumOptimizer,
|
||||||
|
AccumScheduler,
|
||||||
|
BaseExecutor,
|
||||||
|
DDPExecutor,
|
||||||
|
ExecutorFactory,
|
||||||
|
FSDPExecutor,
|
||||||
|
GradientState,
|
||||||
|
NoneExecutor,
|
||||||
|
)
|
||||||
|
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
||||||
|
from astrai.parallel.setup import (
|
||||||
|
get_current_device,
|
||||||
|
get_rank,
|
||||||
|
get_world_size,
|
||||||
|
only_on_rank,
|
||||||
|
setup_parallel,
|
||||||
|
spawn_parallel_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_world_size",
|
||||||
|
"get_rank",
|
||||||
|
"get_current_device",
|
||||||
|
"only_on_rank",
|
||||||
|
"setup_parallel",
|
||||||
|
"spawn_parallel_fn",
|
||||||
|
"RowParallelLinear",
|
||||||
|
"ColumnParallelLinear",
|
||||||
|
"ExecutorFactory",
|
||||||
|
"BaseExecutor",
|
||||||
|
"GradientState",
|
||||||
|
"AccumOptimizer",
|
||||||
|
"AccumScheduler",
|
||||||
|
"NoneExecutor",
|
||||||
|
"DDPExecutor",
|
||||||
|
"FSDPExecutor",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,271 @@
|
||||||
|
"""Unified training executor — parallel strategy + gradient accumulation."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.parallel.setup import get_rank, get_world_size
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GradientState:
|
||||||
|
def __init__(self, grad_accum_steps: int = 1):
|
||||||
|
self.num_steps = max(grad_accum_steps, 1)
|
||||||
|
self._step: int = 0
|
||||||
|
self._sync_gradients: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sync_gradients(self) -> bool:
|
||||||
|
return self._sync_gradients
|
||||||
|
|
||||||
|
def _do_sync(self):
|
||||||
|
self._step += 1
|
||||||
|
self._sync_gradients = self._step % self.num_steps == 0
|
||||||
|
|
||||||
|
|
||||||
|
class AccumOptimizer:
|
||||||
|
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.gradient_state = gradient_state
|
||||||
|
|
||||||
|
def step(self, closure=None):
|
||||||
|
if self.gradient_state.sync_gradients:
|
||||||
|
self.optimizer.step(closure)
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
if self.gradient_state.sync_gradients:
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_groups(self):
|
||||||
|
return self.optimizer.param_groups
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return self.optimizer.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, d):
|
||||||
|
self.optimizer.load_state_dict(d)
|
||||||
|
|
||||||
|
|
||||||
|
class AccumScheduler:
|
||||||
|
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.gradient_state = gradient_state
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
if self.gradient_state.sync_gradients:
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return self.scheduler.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, d):
|
||||||
|
self.scheduler.load_state_dict(d)
|
||||||
|
|
||||||
|
def get_last_lr(self):
|
||||||
|
return self.scheduler.get_last_lr()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseExecutor:
|
||||||
|
def __init__(self, grad_accum_steps: int = 1):
|
||||||
|
self.gradient_state = GradientState(grad_accum_steps)
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[Optimizer] = None,
|
||||||
|
dataloader: Optional[DataLoader] = None,
|
||||||
|
scheduler: Optional[LRScheduler] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
|
||||||
|
]:
|
||||||
|
model = self._prepare_model(model)
|
||||||
|
if optimizer is not None:
|
||||||
|
optimizer = AccumOptimizer(optimizer, self.gradient_state)
|
||||||
|
if scheduler is not None:
|
||||||
|
scheduler = AccumScheduler(scheduler, self.gradient_state)
|
||||||
|
return model, optimizer, dataloader, scheduler
|
||||||
|
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def accumulate(self, model: nn.Module):
|
||||||
|
self.gradient_state._do_sync()
|
||||||
|
if not self.gradient_state.sync_gradients:
|
||||||
|
with self._no_sync(model):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
|
def backward(self, loss: torch.Tensor):
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module):
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_distributed(self) -> bool:
|
||||||
|
return get_world_size() > 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sync_gradients(self) -> bool:
|
||||||
|
return self.gradient_state.sync_gradients
|
||||||
|
|
||||||
|
@property
|
||||||
|
def grad_accum_steps(self) -> int:
|
||||||
|
return self.gradient_state.num_steps
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutorFactory(BaseFactory[BaseExecutor]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ExecutorFactory.register("none")
|
||||||
|
class NoneExecutor(BaseExecutor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ExecutorFactory.register("ddp")
|
||||||
|
class DDPExecutor(BaseExecutor):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grad_accum_steps: int = 1,
|
||||||
|
dim: int = 0,
|
||||||
|
broadcast_buffers: bool = True,
|
||||||
|
init_sync: bool = True,
|
||||||
|
process_group=None,
|
||||||
|
bucket_cap_mb: int = 25,
|
||||||
|
find_unused_parameters: bool = False,
|
||||||
|
check_reduction: bool = False,
|
||||||
|
gradient_as_bucket_view: bool = False,
|
||||||
|
static_graph: bool = False,
|
||||||
|
delay_all_reduce_named_params=None,
|
||||||
|
param_to_hook_all_reduce=None,
|
||||||
|
mixed_precision=None,
|
||||||
|
device_mesh=None,
|
||||||
|
):
|
||||||
|
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||||
|
self._ddp_kwargs = dict(
|
||||||
|
dim=dim,
|
||||||
|
broadcast_buffers=broadcast_buffers,
|
||||||
|
init_sync=init_sync,
|
||||||
|
process_group=process_group,
|
||||||
|
bucket_cap_mb=bucket_cap_mb,
|
||||||
|
find_unused_parameters=find_unused_parameters,
|
||||||
|
check_reduction=check_reduction,
|
||||||
|
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||||
|
static_graph=static_graph,
|
||||||
|
delay_all_reduce_named_params=delay_all_reduce_named_params,
|
||||||
|
param_to_hook_all_reduce=param_to_hook_all_reduce,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
if not self.use_distributed:
|
||||||
|
logger.warning("DDP backend selected but world_size=1, model not wrapped")
|
||||||
|
return model
|
||||||
|
local_rank = get_rank()
|
||||||
|
model = DDP(
|
||||||
|
model,
|
||||||
|
device_ids=[local_rank],
|
||||||
|
output_device=local_rank,
|
||||||
|
**self._ddp_kwargs,
|
||||||
|
)
|
||||||
|
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
return model.no_sync()
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module):
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
return model.module.state_dict()
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@ExecutorFactory.register("fsdp")
|
||||||
|
class FSDPExecutor(BaseExecutor):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grad_accum_steps: int = 1,
|
||||||
|
process_group=None,
|
||||||
|
sharding_strategy=None,
|
||||||
|
cpu_offload=None,
|
||||||
|
auto_wrap_policy=None,
|
||||||
|
backward_prefetch=None,
|
||||||
|
mixed_precision=None,
|
||||||
|
ignored_modules=None,
|
||||||
|
param_init_fn=None,
|
||||||
|
sync_module_states: bool = False,
|
||||||
|
forward_prefetch: bool = False,
|
||||||
|
limit_all_gathers: bool = True,
|
||||||
|
ignored_states=None,
|
||||||
|
device_mesh=None,
|
||||||
|
):
|
||||||
|
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||||
|
self._fsdp_kwargs = {
|
||||||
|
k: v
|
||||||
|
for k, v in dict(
|
||||||
|
process_group=process_group,
|
||||||
|
sharding_strategy=sharding_strategy,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
|
auto_wrap_policy=auto_wrap_policy,
|
||||||
|
backward_prefetch=backward_prefetch,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
ignored_modules=ignored_modules,
|
||||||
|
param_init_fn=param_init_fn,
|
||||||
|
sync_module_states=sync_module_states,
|
||||||
|
forward_prefetch=forward_prefetch,
|
||||||
|
limit_all_gathers=limit_all_gathers,
|
||||||
|
use_orig_params=True,
|
||||||
|
ignored_states=ignored_states,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
).items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
self._original_model: Optional[nn.Module] = None
|
||||||
|
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
if not self.use_distributed:
|
||||||
|
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
|
||||||
|
return model
|
||||||
|
self._original_model = model
|
||||||
|
device_id = torch.device("cuda", get_rank())
|
||||||
|
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
|
||||||
|
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
if isinstance(model, FSDP):
|
||||||
|
return model.no_sync()
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module):
|
||||||
|
if isinstance(model, FSDP) and self.use_distributed:
|
||||||
|
with FSDP.state_dict_type(
|
||||||
|
model,
|
||||||
|
StateDictType.FULL_STATE_DICT,
|
||||||
|
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
||||||
|
):
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
return model.state_dict()
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
|
|
||||||
class ParallelModel(nn.Module):
|
class ParallelModel(nn.Module):
|
||||||
|
|
@ -22,7 +22,7 @@ class RowParallelLinear(ParallelModel):
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
reduce_results: bool = True
|
reduce_results: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(process_group)
|
super().__init__(process_group)
|
||||||
|
|
||||||
|
|
@ -32,7 +32,9 @@ class RowParallelLinear(ParallelModel):
|
||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
|
|
||||||
if in_features % self.world_size != 0:
|
if in_features % self.world_size != 0:
|
||||||
raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}")
|
raise ValueError(
|
||||||
|
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
|
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
|
||||||
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
|
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
|
||||||
|
|
@ -49,8 +51,8 @@ class RowParallelLinear(ParallelModel):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||||
full_weight = state_dict.get('weight')
|
full_weight = state_dict.get("weight")
|
||||||
full_bias = state_dict.get('bias')
|
full_bias = state_dict.get("bias")
|
||||||
|
|
||||||
start_idx = self.rank * self.in_features_per_rank
|
start_idx = self.rank * self.in_features_per_rank
|
||||||
end_idx = start_idx + self.in_features_per_rank
|
end_idx = start_idx + self.in_features_per_rank
|
||||||
|
|
@ -68,7 +70,7 @@ class ColumnParallelLinear(ParallelModel):
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
gather_results: bool = True
|
gather_results: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(process_group)
|
super().__init__(process_group)
|
||||||
|
|
||||||
|
|
@ -78,10 +80,16 @@ class ColumnParallelLinear(ParallelModel):
|
||||||
self.gather_results = gather_results
|
self.gather_results = gather_results
|
||||||
|
|
||||||
if out_features % self.world_size != 0:
|
if out_features % self.world_size != 0:
|
||||||
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
|
raise ValueError(
|
||||||
|
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features))
|
self.weight = nn.Parameter(
|
||||||
self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
torch.empty(self.out_features_per_rank, self.in_features)
|
||||||
|
)
|
||||||
|
self.bias = (
|
||||||
|
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
output = F.linear(input, self.weight, self.bias)
|
output = F.linear(input, self.weight, self.bias)
|
||||||
|
|
@ -94,8 +102,8 @@ class ColumnParallelLinear(ParallelModel):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||||
full_weight = state_dict.get('weight')
|
full_weight = state_dict.get("weight")
|
||||||
full_bias = state_dict.get('bias')
|
full_bias = state_dict.get("bias")
|
||||||
|
|
||||||
start_idx = self.rank * self.out_features_per_rank
|
start_idx = self.rank * self.out_features_per_rank
|
||||||
end_idx = start_idx + self.out_features_per_rank
|
end_idx = start_idx + self.out_features_per_rank
|
||||||
|
|
@ -1,28 +1,31 @@
|
||||||
import os
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
from functools import wraps
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Callable, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_device():
|
def get_current_device():
|
||||||
return os.environ["LOCAL_DEVICE"]
|
return os.environ["LOCAL_DEVICE"]
|
||||||
|
|
||||||
|
|
||||||
def get_world_size() -> int:
|
def get_world_size() -> int:
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
return dist.get_world_size()
|
return dist.get_world_size()
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def get_rank() -> int:
|
def get_rank() -> int:
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
return dist.get_rank()
|
return dist.get_rank()
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def setup_parallel(
|
def setup_parallel(
|
||||||
rank: int,
|
rank: int,
|
||||||
|
|
@ -31,7 +34,6 @@ def setup_parallel(
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None
|
|
||||||
):
|
):
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
|
@ -42,30 +44,22 @@ def setup_parallel(
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
if device_ids is None:
|
device_id = torch.device(device_type, rank)
|
||||||
device_ids = [i for i in range(world_size)]
|
|
||||||
|
|
||||||
rank = device_ids[rank % len(device_ids)]
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
device_id = torch.device(device_type, device_ids[rank])
|
os.environ["MASTER_PORT"] = master_port
|
||||||
|
os.environ["LOCAL_RANK"] = str(rank)
|
||||||
os.environ['MASTER_ADDR'] = master_addr
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
os.environ['MASTER_PORT'] = master_port
|
|
||||||
|
|
||||||
os.environ['LOCAL_RANK'] = str(rank)
|
|
||||||
os.environ['WORLD_SIZE'] = str(world_size)
|
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
rank=rank,
|
rank=rank, world_size=world_size, backend=backend, device_id=device_id
|
||||||
world_size=world_size,
|
|
||||||
backend=backend,
|
|
||||||
device_id=device_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if backend == "nccl" and torch.cuda.is_available():
|
if backend == "nccl" and torch.cuda.is_available():
|
||||||
torch.cuda.set_device(device_id)
|
torch.cuda.set_device(device_id)
|
||||||
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
|
elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
torch.xpu.set_device(device_id)
|
torch.xpu.set_device(device_id)
|
||||||
|
|
||||||
yield dist.group.WORLD
|
yield dist.group.WORLD
|
||||||
|
|
@ -73,6 +67,7 @@ def setup_parallel(
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def only_on_rank(rank, sync=False):
|
def only_on_rank(rank, sync=False):
|
||||||
"""
|
"""
|
||||||
decorator to run a function only on a specific rank.
|
decorator to run a function only on a specific rank.
|
||||||
|
|
@ -81,15 +76,20 @@ def only_on_rank(rank, sync=False):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
ret_args = None
|
||||||
if get_rank() == rank:
|
if get_rank() == rank:
|
||||||
return func(*args, **kwargs)
|
ret_args = func(*args, **kwargs)
|
||||||
if sync:
|
|
||||||
|
if sync and dist.is_available() and dist.is_initialized():
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
return ret_args
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def wrapper_spawn_func(
|
def wrapper_spawn_func(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
|
@ -97,9 +97,8 @@ def wrapper_spawn_func(
|
||||||
master_addr: str,
|
master_addr: str,
|
||||||
master_port: str,
|
master_port: str,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
device_ids: List[int],
|
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict
|
kwargs: dict,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
with setup_parallel(
|
with setup_parallel(
|
||||||
|
|
@ -109,7 +108,6 @@ def wrapper_spawn_func(
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
device_ids=device_ids
|
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
|
|
@ -117,6 +115,7 @@ def wrapper_spawn_func(
|
||||||
print(f"Error in rank {rank}: {e}")
|
print(f"Error in rank {rank}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def spawn_parallel_fn(
|
def spawn_parallel_fn(
|
||||||
func: Callable,
|
func: Callable,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
|
@ -124,28 +123,44 @@ def spawn_parallel_fn(
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None,
|
start_method: str = "spawn",
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
# clear environment variables
|
# clear environment variables
|
||||||
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']:
|
for key in [
|
||||||
|
"MASTER_ADDR",
|
||||||
|
"MASTER_PORT",
|
||||||
|
"RANK",
|
||||||
|
"WORLD_SIZE",
|
||||||
|
"LOCAL_RANK",
|
||||||
|
"LOCAL_DEVICE",
|
||||||
|
]:
|
||||||
if key in os.environ:
|
if key in os.environ:
|
||||||
del os.environ[key]
|
del os.environ[key]
|
||||||
|
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
device_ids = device_ids or [0]
|
device_id = torch.device(device_type, 0)
|
||||||
deice_id = torch.device(device_type, device_ids[0])
|
os.environ["LOCAL_RANK"] = "0"
|
||||||
os.environ["LOCAL_DEVICE"] = str(deice_id)
|
os.environ["WORLD_SIZE"] = "1"
|
||||||
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
return
|
return
|
||||||
|
|
||||||
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port,
|
wrapper_spawn_func_args = (
|
||||||
device_type, device_ids, func, kwargs)
|
world_size,
|
||||||
|
backend,
|
||||||
mp.spawn(
|
master_addr,
|
||||||
wrapper_spawn_func,
|
master_port,
|
||||||
nprocs=world_size,
|
device_type,
|
||||||
args=wrapper_spawn_func_args,
|
func,
|
||||||
join=True
|
kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
mp.start_processes(
|
||||||
|
wrapper_spawn_func,
|
||||||
|
args=wrapper_spawn_func_args,
|
||||||
|
nprocs=world_size,
|
||||||
|
start_method=start_method,
|
||||||
|
join=True,
|
||||||
)
|
)
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
from astrai.preprocessing.builder import (
|
||||||
|
BaseMaskBuilder,
|
||||||
|
MaskBuilderFactory,
|
||||||
|
SectionedMaskBuilder,
|
||||||
|
)
|
||||||
|
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseMaskBuilder",
|
||||||
|
"MaskBuilderFactory",
|
||||||
|
"SectionedMaskBuilder",
|
||||||
|
"Pipeline",
|
||||||
|
"filter_by_length",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,159 @@
|
||||||
|
"""Mask building strategies for preprocessing pipeline.
|
||||||
|
|
||||||
|
The single :class:`SectionedMaskBuilder` handles all input formats
|
||||||
|
via declarative ``input.sections`` config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMaskBuilder(ABC):
|
||||||
|
"""Convert a JSONL item into token ids and optional loss_mask."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
|
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
||||||
|
|
||||||
|
Returns ``None`` to skip the item entirely.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, component_cls: type):
|
||||||
|
if not issubclass(component_cls, BaseMaskBuilder):
|
||||||
|
raise TypeError(
|
||||||
|
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||||||
|
if not domain_key:
|
||||||
|
return "__default__"
|
||||||
|
val = item.get(domain_key, "__default__")
|
||||||
|
return val if isinstance(val, str) else "__default__"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_action(action: str, role: str, config) -> str:
|
||||||
|
"""Resolve action to "train" or "mask".
|
||||||
|
|
||||||
|
- ``"train"`` / ``"mask"`` → literal
|
||||||
|
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
||||||
|
"""
|
||||||
|
if action == "$role":
|
||||||
|
return config.mask.get(role, config.mask_default)
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
@MaskBuilderFactory.register("sectioned")
|
||||||
|
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
|
"""Config-driven builder: iterates over ``input.sections`` in order.
|
||||||
|
|
||||||
|
Each section specifies a JSONL field + mask action.
|
||||||
|
|
||||||
|
Section spec::
|
||||||
|
|
||||||
|
{
|
||||||
|
"field": "messages", # JSONL key
|
||||||
|
"action": "$role", # "train" | "mask" | "$role"
|
||||||
|
"template": true, # apply chat_template per message (optional)
|
||||||
|
"add_special_tokens": false # override encode flag (optional)
|
||||||
|
}
|
||||||
|
|
||||||
|
Example configs::
|
||||||
|
|
||||||
|
# Chat
|
||||||
|
{"input": {"sections": [
|
||||||
|
{"field": "messages", "action": "$role", "template": true}
|
||||||
|
]}}
|
||||||
|
|
||||||
|
# Instruction
|
||||||
|
{"input": {"sections": [
|
||||||
|
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
||||||
|
{"field": "response", "action": "train"}
|
||||||
|
]}}
|
||||||
|
|
||||||
|
# Text
|
||||||
|
{"input": {"sections": [
|
||||||
|
{"field": "text", "action": "train"}
|
||||||
|
]}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
|
sections = config.input.sections
|
||||||
|
if not sections:
|
||||||
|
return None
|
||||||
|
|
||||||
|
all_ids: list[int] = []
|
||||||
|
loss_mask: list[int] = []
|
||||||
|
|
||||||
|
has_template = any(s.get("template") for s in sections)
|
||||||
|
is_text_config = not has_template and all(
|
||||||
|
s["action"] == "train" for s in sections
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_template and tokenizer.bos_token_id is not None:
|
||||||
|
all_ids.append(tokenizer.bos_token_id)
|
||||||
|
loss_mask.append(0)
|
||||||
|
|
||||||
|
first_section = True
|
||||||
|
for sec in sections:
|
||||||
|
field = sec["field"]
|
||||||
|
action = sec["action"]
|
||||||
|
use_template = sec.get("template", False)
|
||||||
|
add_special = sec.get(
|
||||||
|
"add_special_tokens", not use_template and first_section
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_template:
|
||||||
|
messages = item.get(field)
|
||||||
|
if not isinstance(messages, list) or not messages:
|
||||||
|
continue
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "")
|
||||||
|
act = _resolve_action(action, role, config)
|
||||||
|
rendered = tokenizer.apply_chat_template(
|
||||||
|
[msg], tokenize=False, add_generation_prompt=False
|
||||||
|
)
|
||||||
|
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||||||
|
all_ids.extend(ids)
|
||||||
|
val = 1 if act == "train" else 0
|
||||||
|
loss_mask.extend([val] * len(ids))
|
||||||
|
else:
|
||||||
|
text = str(item.get(field, ""))
|
||||||
|
if not text.strip():
|
||||||
|
continue
|
||||||
|
if is_text_config:
|
||||||
|
pp = config.preprocessing
|
||||||
|
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
||||||
|
continue
|
||||||
|
if len(text) > pp.max_chars:
|
||||||
|
continue
|
||||||
|
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
||||||
|
all_ids.extend(ids)
|
||||||
|
val = 1 if action == "train" else 0
|
||||||
|
loss_mask.extend([val] * len(ids))
|
||||||
|
|
||||||
|
first_section = False
|
||||||
|
|
||||||
|
max_len = config.preprocessing.max_seq_len
|
||||||
|
all_ids = all_ids[:max_len]
|
||||||
|
loss_mask = loss_mask[: len(all_ids)]
|
||||||
|
|
||||||
|
if not all_ids:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if has_template and len(all_ids) <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result: dict = {
|
||||||
|
"sequence": all_ids,
|
||||||
|
"domain": _extract_domain(item, config.output.domain_key),
|
||||||
|
}
|
||||||
|
if not all(m == 1 for m in loss_mask):
|
||||||
|
result["loss_mask"] = loss_mask
|
||||||
|
return result
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
"""Config-driven JSONL preprocessing pipeline.
|
||||||
|
|
||||||
|
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||||
|
sharding and flush to ``.h5`` / ``.bin`` storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
|
from astrai.dataset.storage import save_bin, save_h5
|
||||||
|
from astrai.preprocessing.builder import SectionedMaskBuilder
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||||
|
"bool": torch.bool,
|
||||||
|
"uint8": torch.uint8,
|
||||||
|
"int8": torch.int8,
|
||||||
|
"int16": torch.int16,
|
||||||
|
"int32": torch.int32,
|
||||||
|
"int64": torch.int64,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float64": torch.float64,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
||||||
|
return min_len <= len(text) <= max_len
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
|
||||||
|
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
config = PipelineConfig.from_json("sft_pipeline.json")
|
||||||
|
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PipelineConfig,
|
||||||
|
input_paths: list[str],
|
||||||
|
output_dir: str,
|
||||||
|
tokenizer_path: str,
|
||||||
|
):
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
self.config = config
|
||||||
|
self.paths = input_paths
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.tokenizer_path = tokenizer_path
|
||||||
|
|
||||||
|
self.mask_builder = SectionedMaskBuilder()
|
||||||
|
|
||||||
|
def transform(self, item: dict) -> Optional[dict]:
|
||||||
|
return self.mask_builder.build(item, self.config, self._tokenizer)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
||||||
|
domains: dict = defaultdict(lambda: defaultdict(list))
|
||||||
|
total_tokens = 0
|
||||||
|
shard_idx: dict[str, int] = defaultdict(int)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
pp = self.config.preprocessing
|
||||||
|
|
||||||
|
for item in tqdm.tqdm(
|
||||||
|
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
|
||||||
|
):
|
||||||
|
if pp.max_items and count >= pp.max_items:
|
||||||
|
break
|
||||||
|
|
||||||
|
result = self.transform(item)
|
||||||
|
if result is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ids = result.pop("sequence")
|
||||||
|
if not ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
domain = result.pop("domain", "__default__")
|
||||||
|
result["sequence"] = ids
|
||||||
|
|
||||||
|
bucket = domains[domain]
|
||||||
|
for key in list(bucket.keys()):
|
||||||
|
if key not in result:
|
||||||
|
bucket[key].append([1] * len(ids))
|
||||||
|
for key, val in result.items():
|
||||||
|
bucket[key].append(val)
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
total_tokens += len(ids)
|
||||||
|
|
||||||
|
if total_tokens >= self.config.output.max_tokens_per_shard:
|
||||||
|
self._flush(domains, shard_idx)
|
||||||
|
domains.clear()
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
if total_tokens > 0:
|
||||||
|
self._flush(domains, shard_idx)
|
||||||
|
|
||||||
|
print(f"Done. {count} documents tokenized.")
|
||||||
|
|
||||||
|
def _iter_items(self):
|
||||||
|
for path in self.paths:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
yield json.loads(line)
|
||||||
|
|
||||||
|
def _flush(self, domains, shard_idx):
|
||||||
|
for domain, keys in domains.items():
|
||||||
|
idx = shard_idx[domain]
|
||||||
|
tensors = {}
|
||||||
|
for key, ids_list in keys.items():
|
||||||
|
dt = _STR_TO_DTYPE.get(
|
||||||
|
self.config.output.dtype.get(key, "int32"), torch.int32
|
||||||
|
)
|
||||||
|
tensors[key] = [
|
||||||
|
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||||
|
]
|
||||||
|
chunk_dir = os.path.join(self.output_dir, domain)
|
||||||
|
fmt = self.config.output.storage_format
|
||||||
|
if fmt == "bin":
|
||||||
|
save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
|
||||||
|
else:
|
||||||
|
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||||
|
shard_idx[domain] = idx + 1
|
||||||
|
tqdm.tqdm.write(
|
||||||
|
f" saved {domain}/shard_{idx:04d} "
|
||||||
|
f"({tensors['sequence'][0].numel():,} tokens)"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
||||||
|
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class OptimizerProtocol(Protocol):
|
||||||
|
def step(self, closure=None): ...
|
||||||
|
def zero_grad(self): ...
|
||||||
|
@property
|
||||||
|
def param_groups(self) -> Any: ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, d: dict): ...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SchedulerProtocol(Protocol):
|
||||||
|
def step(self): ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, d: dict): ...
|
||||||
|
def get_last_lr(self): ...
|
||||||
|
|
@ -0,0 +1,182 @@
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from astrai.parallel.setup import get_rank
|
||||||
|
|
||||||
|
_META_FILE = "meta.json"
|
||||||
|
_CONFIG_FILE = "config.json"
|
||||||
|
_WEIGHTS_FILE = "model.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def save_safetensors(state_dict: dict, path: Union[str, Path]):
|
||||||
|
st.save_file(state_dict, str(path))
|
||||||
|
|
||||||
|
|
||||||
|
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
|
return st.load_file(str(path))
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
state_dict = st.load_file(str(path))
|
||||||
|
else:
|
||||||
|
state_dict = {}
|
||||||
|
tmp = [state_dict]
|
||||||
|
dist.broadcast_object_list(tmp, src=0)
|
||||||
|
return tmp[0]
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(data: dict, path: Union[str, Path]):
|
||||||
|
with open(str(path), "w") as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
|
with open(str(path), "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
with open(str(path), "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
else:
|
||||||
|
data = {}
|
||||||
|
tmp = [data]
|
||||||
|
dist.broadcast_object_list(tmp, src=0)
|
||||||
|
return tmp[0]
|
||||||
|
|
||||||
|
|
||||||
|
def save_torch(obj: Any, path: Union[str, Path]):
|
||||||
|
torch.save(obj, str(path))
|
||||||
|
|
||||||
|
|
||||||
|
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
|
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
rank = get_rank()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
raw = f.read()
|
||||||
|
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
||||||
|
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
|
||||||
|
else:
|
||||||
|
num_bytes = torch.tensor([0], dtype=torch.long)
|
||||||
|
|
||||||
|
dist.broadcast(num_bytes, src=0)
|
||||||
|
|
||||||
|
if rank != 0:
|
||||||
|
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
|
||||||
|
|
||||||
|
dist.broadcast(data_tensor, src=0)
|
||||||
|
|
||||||
|
buf = io.BytesIO(data_tensor.numpy().tobytes())
|
||||||
|
return torch.load(buf, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(config: dict, state_dict: dict, save_directory: str):
|
||||||
|
save_path = Path(save_directory)
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_json(config, save_path / _CONFIG_FILE)
|
||||||
|
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_config(save_directory: str) -> dict:
|
||||||
|
return load_json(Path(save_directory) / _CONFIG_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_weights(save_directory: str) -> dict:
|
||||||
|
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
path = Path(path)
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
|
return load_safetensors(path)
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
state_dict = load_safetensors(path)
|
||||||
|
specs = [
|
||||||
|
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
||||||
|
for k in sorted(state_dict)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
state_dict = {}
|
||||||
|
specs = []
|
||||||
|
|
||||||
|
specs_list = [specs]
|
||||||
|
dist.broadcast_object_list(specs_list, src=0)
|
||||||
|
specs = specs_list[0]
|
||||||
|
|
||||||
|
for key, shape, dtype_name in specs:
|
||||||
|
dtype = getattr(torch, dtype_name)
|
||||||
|
if rank != 0:
|
||||||
|
tensor = torch.empty(shape, dtype=dtype, device="cpu")
|
||||||
|
else:
|
||||||
|
tensor = state_dict[key].contiguous().cpu()
|
||||||
|
dist.broadcast(tensor, src=0)
|
||||||
|
if rank != 0:
|
||||||
|
state_dict[key] = tensor
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Checkpoint:
|
||||||
|
state_dict: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
epoch: int = 0
|
||||||
|
iteration: int = 0
|
||||||
|
extra: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
config: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def save(self, save_dir: str):
|
||||||
|
save_path = Path(save_dir)
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if get_rank() != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
"epoch": self.epoch,
|
||||||
|
"iteration": self.iteration,
|
||||||
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
|
**self.meta,
|
||||||
|
}
|
||||||
|
save_json(meta, save_path / _META_FILE)
|
||||||
|
save_json(self.config, save_path / _CONFIG_FILE)
|
||||||
|
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
||||||
|
for key, value in self.extra.items():
|
||||||
|
save_torch(value, save_path / f"{key}.pt")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||||
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
|
meta = load_json(save_path / _META_FILE, broadcast)
|
||||||
|
config = load_json(save_path / _CONFIG_FILE, broadcast)
|
||||||
|
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
|
||||||
|
|
||||||
|
extra = {}
|
||||||
|
for f in sorted(save_path.iterdir()):
|
||||||
|
if f.suffix == ".pt":
|
||||||
|
extra[f.stem] = load_torch(f, broadcast=broadcast)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
state_dict=state_dict,
|
||||||
|
epoch=meta.get("epoch", 0),
|
||||||
|
iteration=meta.get("iteration", 0),
|
||||||
|
extra=extra,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
from astrai.tokenize.chat_template import ChatTemplate, MessageType
|
||||||
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AutoTokenizer",
|
||||||
|
"ChatTemplate",
|
||||||
|
"MessageType",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from jinja2 import Template
|
||||||
|
|
||||||
|
type MessageType = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTemplate:
|
||||||
|
"""A chat template with Jinja2 rendering support.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: Unique identifier for the template.
|
||||||
|
template_str: Jinja2 template string.
|
||||||
|
description: Optional description.
|
||||||
|
default_variables: Optional dictionary of default variable values.
|
||||||
|
special_tokens: Optional dictionary mapping token names to their string values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "",
|
||||||
|
template_str: str = "",
|
||||||
|
description: str = "",
|
||||||
|
default_variables: Optional[Dict[str, Any]] = None,
|
||||||
|
special_tokens: Optional[Dict[str, str]] = None,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.template_str = template_str
|
||||||
|
self.description = description
|
||||||
|
self.default_variables = default_variables or {}
|
||||||
|
self.special_tokens = special_tokens or {}
|
||||||
|
self._compiled: Template = Template(template_str)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(
|
||||||
|
cls,
|
||||||
|
template_str: str,
|
||||||
|
description: str = "",
|
||||||
|
default_variables: Optional[Dict[str, Any]] = None,
|
||||||
|
special_tokens: Optional[Dict[str, str]] = None,
|
||||||
|
) -> "ChatTemplate":
|
||||||
|
"""Create a ChatTemplate instance directly from a template string."""
|
||||||
|
return cls(
|
||||||
|
name="",
|
||||||
|
template_str=template_str,
|
||||||
|
description=description,
|
||||||
|
default_variables=default_variables,
|
||||||
|
special_tokens=special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def render(
|
||||||
|
self,
|
||||||
|
messages: List[MessageType],
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
**extra_variables: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Render the template with given messages and variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'role' and 'content'.
|
||||||
|
system_prompt: Optional system prompt string.
|
||||||
|
**extra_variables: Additional variables to pass to the template.
|
||||||
|
These override default_variables and special_tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered prompt string.
|
||||||
|
"""
|
||||||
|
# Merge default variables, special tokens, and extra variables
|
||||||
|
variables = {**self.default_variables, **self.special_tokens, **extra_variables}
|
||||||
|
variables["messages"] = messages
|
||||||
|
if system_prompt is not None:
|
||||||
|
variables["system_prompt"] = system_prompt
|
||||||
|
|
||||||
|
return self._compiled.render(**variables)
|
||||||
|
|
@ -0,0 +1,264 @@
|
||||||
|
"""
|
||||||
|
Tokenizer module with implementation and auto-loading support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from tokenizers import Tokenizer
|
||||||
|
|
||||||
|
from astrai.tokenize.chat_template import ChatTemplate
|
||||||
|
|
||||||
|
|
||||||
|
class AutoTokenizer:
|
||||||
|
"""Base tokenizer class with automatic loading support"""
|
||||||
|
|
||||||
|
TOKENIZER_CLASSES = {} # Registry for auto-loading
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: Optional[Union[str, Path]] = None,
|
||||||
|
special_token_map: Optional[Dict[str, str]] = None,
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self._tokenizer: Tokenizer = None
|
||||||
|
self._chat_template: Optional[ChatTemplate] = None
|
||||||
|
self._special_token_map: Optional[Dict] = special_token_map or {}
|
||||||
|
|
||||||
|
if chat_template:
|
||||||
|
self.set_chat_template(chat_template)
|
||||||
|
|
||||||
|
if path:
|
||||||
|
self.load(path)
|
||||||
|
|
||||||
|
def load(self, path: Union[str, Path]):
|
||||||
|
"""Load tokenizer from directory."""
|
||||||
|
path = Path(path)
|
||||||
|
tokenizer_file = path / "tokenizer.json"
|
||||||
|
config_file = path / "tokenizer_config.json"
|
||||||
|
self._tokenizer = Tokenizer.from_file(str(tokenizer_file))
|
||||||
|
|
||||||
|
if config_file.exists():
|
||||||
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
if "special_tokens" in config:
|
||||||
|
self._special_token_map.update(config["special_tokens"])
|
||||||
|
|
||||||
|
# Load chat template from config
|
||||||
|
if "chat_template" in config:
|
||||||
|
self.set_chat_template(config["chat_template"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
|
||||||
|
"""Load tokenizer from pretrained directory.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If tokenizer.json is missing.
|
||||||
|
RuntimeError: If tokenizer failed to initialize.
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
tokenizer_file = path / "tokenizer.json"
|
||||||
|
if not tokenizer_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Tokenizer file not found: {tokenizer_file}. "
|
||||||
|
"A valid tokenizer.json is required."
|
||||||
|
)
|
||||||
|
instance = cls(path)
|
||||||
|
if instance._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to load tokenizer from {path}. "
|
||||||
|
"The tokenizer.json may be corrupted or incompatible."
|
||||||
|
)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def save_pretrained(self, save_path: str):
|
||||||
|
"""
|
||||||
|
Save tokenizer to pretrained directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_path: Path to save the tokenizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||||
|
)
|
||||||
|
|
||||||
|
save_path = Path(save_path)
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save tokenizer
|
||||||
|
self._tokenizer.save(str(save_path / "tokenizer.json"))
|
||||||
|
|
||||||
|
# Save tokenizer config
|
||||||
|
config = {}
|
||||||
|
if self._special_token_map is not None:
|
||||||
|
config["special_tokens"] = self._special_token_map
|
||||||
|
if self._chat_template is not None:
|
||||||
|
config["chat_template"] = self._chat_template.template_str
|
||||||
|
|
||||||
|
with open(save_path / "tokenizer_config.json", "w", encoding="utf-8") as f:
|
||||||
|
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_tokenizer(cls, name: str, tokenizer_class: type):
|
||||||
|
"""
|
||||||
|
Register a new tokenizer class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name to register the tokenizer class under
|
||||||
|
tokenizer_class: The tokenizer class to register
|
||||||
|
"""
|
||||||
|
cls.TOKENIZER_CLASSES[name] = tokenizer_class
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
tokens: Union[str, List[str]],
|
||||||
|
out_ids: bool = True,
|
||||||
|
is_pretokenized: bool = False,
|
||||||
|
add_special_tokens: bool = True,
|
||||||
|
) -> List:
|
||||||
|
"""Encode text to tokens or token IDs."""
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(tokens, str):
|
||||||
|
encoded = self._tokenizer.encode(
|
||||||
|
tokens,
|
||||||
|
is_pretokenized=is_pretokenized,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
|
)
|
||||||
|
return encoded.ids if out_ids else encoded.tokens
|
||||||
|
else:
|
||||||
|
encoded_list = self._tokenizer.encode_batch(
|
||||||
|
tokens,
|
||||||
|
is_pretokenized=is_pretokenized,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
|
||||||
|
]
|
||||||
|
|
||||||
|
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
|
||||||
|
"""Decode token IDs to text."""
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if self._tokenizer is None:
|
||||||
|
return 0
|
||||||
|
return self._tokenizer.get_vocab_size()
|
||||||
|
|
||||||
|
def __getattr__(self, key: str):
|
||||||
|
"""
|
||||||
|
Dynamically intercept special token attribute access.
|
||||||
|
Supports three forms:
|
||||||
|
- tokenizer.bos_token → returns string
|
||||||
|
- tokenizer.bos_token_id → returns corresponding integer ID
|
||||||
|
- tokenizer.stop_ids → returns list of corresponding integer IDs for all special tokens
|
||||||
|
"""
|
||||||
|
# Handle stop_ids - return IDs for all special tokens
|
||||||
|
if key == "stop_ids":
|
||||||
|
stop_ids = []
|
||||||
|
|
||||||
|
if self._tokenizer is None:
|
||||||
|
return stop_ids
|
||||||
|
|
||||||
|
for val in self._special_token_map.values():
|
||||||
|
token_id = self._tokenizer.token_to_id(val)
|
||||||
|
if token_id is not None:
|
||||||
|
stop_ids.append(token_id)
|
||||||
|
|
||||||
|
return stop_ids
|
||||||
|
|
||||||
|
# Handle _id suffix (e.g., bos_token_id -> bos_token)
|
||||||
|
if key.endswith("_id"):
|
||||||
|
base_attr = key[:-3] # Remove "_id"
|
||||||
|
token_str = self._special_token_map.get(base_attr)
|
||||||
|
if token_str is None:
|
||||||
|
return None
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
|
||||||
|
return self._tokenizer.token_to_id(token_str)
|
||||||
|
|
||||||
|
# Handle regular string attributes
|
||||||
|
if key in self._special_token_map:
|
||||||
|
return self._special_token_map.get(key)
|
||||||
|
|
||||||
|
# Other attributes trigger default AttributeError
|
||||||
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
def set_chat_template(self, template: Union[str, ChatTemplate]):
|
||||||
|
"""
|
||||||
|
Set the chat template for the tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: Either a template name (str) registered in the global registry,
|
||||||
|
or a ChatTemplate instance, or a Jinja2 template string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If template name is not registered.
|
||||||
|
"""
|
||||||
|
if isinstance(template, str):
|
||||||
|
self._chat_template = ChatTemplate.from_string(template)
|
||||||
|
elif isinstance(template, ChatTemplate):
|
||||||
|
self._chat_template = template
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid template type, must be str or ChatTemplate.")
|
||||||
|
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
tokenize: bool = True,
|
||||||
|
add_generation_prompt: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[str, List[int]]:
|
||||||
|
"""
|
||||||
|
Apply the chat template to messages and optionally tokenize the result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'role' and 'content'.
|
||||||
|
system_prompt: Optional system prompt string (auto-converted to first message).
|
||||||
|
tokenize: Whether to return token IDs (True) or raw string (False).
|
||||||
|
add_generation_prompt: Whether to add the generation prompt (default: True).
|
||||||
|
**kwargs: Additional variables to pass to the template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either the rendered string or list of token IDs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If chat template is not set.
|
||||||
|
"""
|
||||||
|
if self._chat_template is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Chat template not set. Use set_chat_template() to set a template first."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-convert system_prompt to first message if provided
|
||||||
|
if system_prompt:
|
||||||
|
messages = [{"role": "system", "content": system_prompt}] + list(messages)
|
||||||
|
|
||||||
|
# Render the template
|
||||||
|
rendered = self._chat_template.render(
|
||||||
|
messages=messages,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tokenize:
|
||||||
|
return self.encode(rendered)
|
||||||
|
|
||||||
|
return rendered
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
from astrai.trainer.optim import Muon
|
||||||
|
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||||
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
from astrai.trainer.train_callback import (
|
||||||
|
CallbackFactory,
|
||||||
|
TrainCallback,
|
||||||
|
)
|
||||||
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Main trainer
|
||||||
|
"Trainer",
|
||||||
|
# Optimizer
|
||||||
|
"Muon",
|
||||||
|
# Strategy factory
|
||||||
|
"StrategyFactory",
|
||||||
|
"BaseStrategy",
|
||||||
|
# Scheduler factory
|
||||||
|
"SchedulerFactory",
|
||||||
|
"BaseScheduler",
|
||||||
|
# Callback factory
|
||||||
|
"TrainCallback",
|
||||||
|
"CallbackFactory",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def _grad_stat(
|
||||||
|
model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any
|
||||||
|
) -> dict:
|
||||||
|
results = {}
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
results[name] = default
|
||||||
|
if param.grad is not None:
|
||||||
|
results[name] = fn(param.grad.data)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||||
|
return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||||
|
return _grad_stat(model, lambda g: g.std().item(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_max(model: nn.Module) -> Dict[str, float]:
|
||||||
|
return _grad_stat(model, lambda g: g.max().item(), -float("inf"))
|
||||||
|
|
||||||
|
|
||||||
|
def grad_min(model: nn.Module) -> Dict[str, float]:
|
||||||
|
return _grad_stat(model, lambda g: g.min().item(), float("inf"))
|
||||||
|
|
||||||
|
|
||||||
|
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
||||||
|
return _grad_stat(model, lambda g: g.mean().item(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||||
|
return _grad_stat(model, lambda g: g.isnan().sum().item(), 0)
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_loss(ctx):
|
||||||
|
return ctx.loss
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_lr(ctx):
|
||||||
|
return ctx.optimizer.param_groups[-1]["lr"]
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_val_loss(ctx):
|
||||||
|
return ctx.val_loss
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_grad_norm(ctx):
|
||||||
|
return grad_norm(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_grad_std(ctx):
|
||||||
|
return grad_std(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_grad_max(ctx):
|
||||||
|
return grad_max(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_grad_min(ctx):
|
||||||
|
return grad_min(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_grad_mean(ctx):
|
||||||
|
return grad_mean(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_grad_nan_num(ctx):
|
||||||
|
return grad_nan_num(ctx.model)
|
||||||
|
|
@ -0,0 +1,143 @@
|
||||||
|
import torch
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||||
|
assert G.ndim == 2
|
||||||
|
X = G
|
||||||
|
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||||
|
X = X / (X.norm() + 1e-7) * scale
|
||||||
|
if steps == 0:
|
||||||
|
return X
|
||||||
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||||
|
for _ in range(steps):
|
||||||
|
A = X @ X.T
|
||||||
|
B = A @ X
|
||||||
|
X = a * X + b * B + c * (A @ B)
|
||||||
|
return X
|
||||||
|
|
||||||
|
|
||||||
|
class Muon(Optimizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params,
|
||||||
|
lr: float = 2e-3,
|
||||||
|
momentum: float = 0.95,
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
nesterov: bool = True,
|
||||||
|
ns_steps: int = 5,
|
||||||
|
adamw_lr: float = None,
|
||||||
|
adamw_betas: tuple = (0.9, 0.95),
|
||||||
|
adamw_eps: float = 1e-8,
|
||||||
|
adamw_wd: float = 0.0,
|
||||||
|
):
|
||||||
|
defaults = dict(
|
||||||
|
lr=lr,
|
||||||
|
momentum=momentum,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
nesterov=nesterov,
|
||||||
|
ns_steps=ns_steps,
|
||||||
|
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
|
||||||
|
adamw_betas=adamw_betas,
|
||||||
|
adamw_eps=adamw_eps,
|
||||||
|
adamw_wd=adamw_wd,
|
||||||
|
)
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
params_2d, params_1d = [], []
|
||||||
|
grads_2d, grads_1d = [], []
|
||||||
|
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
if p.grad.is_sparse:
|
||||||
|
raise RuntimeError("Muon does not support sparse gradients")
|
||||||
|
if p.ndim >= 2:
|
||||||
|
params_2d.append(p)
|
||||||
|
grads_2d.append(p.grad)
|
||||||
|
else:
|
||||||
|
params_1d.append(p)
|
||||||
|
grads_1d.append(p.grad)
|
||||||
|
|
||||||
|
if params_2d:
|
||||||
|
self._muon_update_foreach(params_2d, grads_2d, group)
|
||||||
|
if params_1d:
|
||||||
|
self._adamw_update_foreach(params_1d, grads_1d, group)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _muon_update_foreach(self, params_2d, grads_2d, group):
|
||||||
|
lr = group["lr"]
|
||||||
|
momentum = group["momentum"]
|
||||||
|
wd = group["weight_decay"]
|
||||||
|
nesterov = group["nesterov"]
|
||||||
|
ns_steps = group["ns_steps"]
|
||||||
|
|
||||||
|
if wd != 0:
|
||||||
|
torch._foreach_mul_(params_2d, 1 - lr * wd)
|
||||||
|
|
||||||
|
if nesterov:
|
||||||
|
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
|
||||||
|
|
||||||
|
bufs = []
|
||||||
|
for p, grad in zip(params_2d, grads_2d):
|
||||||
|
state = self.state[p]
|
||||||
|
if "momentum_buffer" not in state:
|
||||||
|
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||||
|
bufs.append(state["momentum_buffer"])
|
||||||
|
|
||||||
|
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
|
||||||
|
|
||||||
|
for p, buf in zip(params_2d, bufs):
|
||||||
|
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||||
|
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||||
|
p.add_(update, alpha=-lr * scale)
|
||||||
|
|
||||||
|
def _adamw_update_foreach(self, params_1d, grads_1d, group):
|
||||||
|
lr = group["adamw_lr"]
|
||||||
|
betas = group["adamw_betas"]
|
||||||
|
eps = group["adamw_eps"]
|
||||||
|
wd = group["adamw_wd"]
|
||||||
|
|
||||||
|
steps: list[int] = []
|
||||||
|
exp_avgs, exp_avg_sqs = [], []
|
||||||
|
has_state = []
|
||||||
|
for p in params_1d:
|
||||||
|
state = self.state[p]
|
||||||
|
if not state:
|
||||||
|
state["step"] = 0
|
||||||
|
state["exp_avg"] = torch.zeros_like(p)
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||||
|
has_state.append(False)
|
||||||
|
else:
|
||||||
|
has_state.append(True)
|
||||||
|
state["step"] += 1
|
||||||
|
steps.append(state["step"])
|
||||||
|
exp_avgs.append(state["exp_avg"])
|
||||||
|
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||||
|
|
||||||
|
beta1, beta2 = betas
|
||||||
|
|
||||||
|
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
|
||||||
|
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
|
||||||
|
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
|
||||||
|
|
||||||
|
bias_correction1 = [1 - beta1**s for s in steps]
|
||||||
|
bias_correction2 = [1 - beta2**s for s in steps]
|
||||||
|
|
||||||
|
if wd != 0:
|
||||||
|
torch._foreach_mul_(params_1d, 1 - lr * wd)
|
||||||
|
|
||||||
|
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
|
||||||
|
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
|
||||||
|
denom = torch._foreach_sqrt(denom)
|
||||||
|
torch._foreach_add_(denom, eps)
|
||||||
|
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
|
||||||
|
|
@ -0,0 +1,194 @@
|
||||||
|
"""Learning rate scheduler implementations with factory pattern."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Type
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class BaseScheduler(LRScheduler, ABC):
|
||||||
|
"""Base scheduler class for all other schedulers."""
|
||||||
|
|
||||||
|
def __init__(self, optimizer, last_epoch: int = -1):
|
||||||
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_lr(self) -> List[float]:
|
||||||
|
"""Calculate the current learning rate."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def state_dict(self) -> Dict[str, Any]:
|
||||||
|
return super().state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: Dict[str, Any]):
|
||||||
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
|
"""Factory class for creating learning rate schedulers.
|
||||||
|
|
||||||
|
Supports decorator-based registration for extensible scheduler types.
|
||||||
|
Also supports creation from ScheduleConfig objects.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
@SchedulerFactory.register("custom")
|
||||||
|
class CustomScheduler(BaseScheduler):
|
||||||
|
...
|
||||||
|
|
||||||
|
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
||||||
|
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||||
|
if not issubclass(scheduler_cls, BaseScheduler):
|
||||||
|
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls, optimizer, schedule_type: str = "none", **kwargs
|
||||||
|
) -> "BaseScheduler":
|
||||||
|
"""Create a scheduler instance by type name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: PyTorch optimizer
|
||||||
|
schedule_type: Type of scheduler ("cosine", "sgdr")
|
||||||
|
**kwargs: Arguments passed to the scheduler constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scheduler instance
|
||||||
|
"""
|
||||||
|
return super().create(schedule_type, optimizer, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available_types(cls) -> list:
|
||||||
|
"""Return list of registered scheduler type names."""
|
||||||
|
return cls.list_registered()
|
||||||
|
|
||||||
|
|
||||||
|
# ----------- Scheduler implementations -----------
|
||||||
|
|
||||||
|
|
||||||
|
@SchedulerFactory.register("cosine")
|
||||||
|
class CosineScheduler(BaseScheduler):
|
||||||
|
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
optimizer,
|
||||||
|
warmup_steps: int,
|
||||||
|
lr_decay_steps: int,
|
||||||
|
min_rate: float = 0.05,
|
||||||
|
last_epoch: int = -1,
|
||||||
|
):
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.lr_decay_steps = lr_decay_steps
|
||||||
|
self.min_rate = min_rate
|
||||||
|
self.total_steps = warmup_steps + lr_decay_steps
|
||||||
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
def get_lr(self) -> List[float]:
|
||||||
|
# warmup
|
||||||
|
if self.last_epoch < self.warmup_steps:
|
||||||
|
warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
|
||||||
|
return [base_lr * warmup_factor for base_lr in self.base_lrs]
|
||||||
|
|
||||||
|
# cosine decay
|
||||||
|
decay_progress = (self.last_epoch - self.warmup_steps) / self.lr_decay_steps
|
||||||
|
decay_progress = min(decay_progress, 1.0)
|
||||||
|
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
|
||||||
|
decay_factor = max(self.min_rate, cosine_decay)
|
||||||
|
return [base_lr * decay_factor for base_lr in self.base_lrs]
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
state = super().state_dict()
|
||||||
|
state.update(
|
||||||
|
{
|
||||||
|
"warmup_steps": self.warmup_steps,
|
||||||
|
"lr_decay_steps": self.lr_decay_steps,
|
||||||
|
"min_rate": self.min_rate,
|
||||||
|
"total_steps": self.total_steps,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||||
|
self.lr_decay_steps = state_dict.pop("lr_decay_steps")
|
||||||
|
self.min_rate = state_dict.pop("min_rate")
|
||||||
|
self.total_steps = state_dict.pop("total_steps")
|
||||||
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@SchedulerFactory.register("sgdr")
|
||||||
|
class SGDRScheduler(BaseScheduler):
|
||||||
|
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
optimizer,
|
||||||
|
warmup_steps: int,
|
||||||
|
cycle_length: int,
|
||||||
|
min_rate: float = 0.05,
|
||||||
|
t_mult: int = 2,
|
||||||
|
last_epoch: int = -1,
|
||||||
|
):
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.cycle_length = cycle_length
|
||||||
|
self.min_rate = min_rate
|
||||||
|
self.t_mult = t_mult
|
||||||
|
|
||||||
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
# warmup
|
||||||
|
if self.last_epoch < self.warmup_steps:
|
||||||
|
warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
|
||||||
|
return [base_lr * warmup_factor for base_lr in self.base_lrs]
|
||||||
|
|
||||||
|
# SGDR
|
||||||
|
steps_since_warmup = self.last_epoch - self.warmup_steps
|
||||||
|
|
||||||
|
# 1. Calculate current cycle and position within cycle
|
||||||
|
current_cycle_length = self.cycle_length
|
||||||
|
total_cycles_length = 0
|
||||||
|
cycle_num = 0
|
||||||
|
|
||||||
|
while total_cycles_length + current_cycle_length <= steps_since_warmup:
|
||||||
|
total_cycles_length += current_cycle_length
|
||||||
|
current_cycle_length *= self.t_mult
|
||||||
|
cycle_num += 1
|
||||||
|
|
||||||
|
steps_in_cycle = steps_since_warmup - total_cycles_length
|
||||||
|
|
||||||
|
# 2. Cosine annealing within the current cycle
|
||||||
|
cosine_factor = 0.5 * (
|
||||||
|
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
|
||||||
|
)
|
||||||
|
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
|
||||||
|
|
||||||
|
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""Returns the state of the scheduler as a dict."""
|
||||||
|
state = super().state_dict()
|
||||||
|
state.update(
|
||||||
|
{
|
||||||
|
"warmup_steps": self.warmup_steps,
|
||||||
|
"cycle_length": self.cycle_length,
|
||||||
|
"min_rate": self.min_rate,
|
||||||
|
"t_mult": self.t_mult,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
"""Loads the scheduler's state."""
|
||||||
|
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||||
|
self.cycle_length = state_dict.pop("cycle_length")
|
||||||
|
self.min_rate = state_dict.pop("min_rate")
|
||||||
|
self.t_mult = state_dict.pop("t_mult")
|
||||||
|
super().load_state_dict(state_dict)
|
||||||
|
|
@ -0,0 +1,334 @@
|
||||||
|
"""Training strategy implementations with factory pattern."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
||||||
|
"""Create a frozen reference model from model_fn + full state dict."""
|
||||||
|
ref_model = model_fn()
|
||||||
|
ref_model.load_state_dict(state_dict)
|
||||||
|
ref_model.requires_grad_(False)
|
||||||
|
ref_model.eval()
|
||||||
|
return ref_model
|
||||||
|
|
||||||
|
|
||||||
|
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
|
||||||
|
"""Move batch tensors to specified device with non-blocking transfer."""
|
||||||
|
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def get_logprobs(
|
||||||
|
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||||
|
input_ids: Tensor,
|
||||||
|
mask: Tensor,
|
||||||
|
reduction: str,
|
||||||
|
):
|
||||||
|
"""Compute token-wise log probabilities from model outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The language model
|
||||||
|
input_ids: Input token IDs of shape [batch_size, seq_len]
|
||||||
|
mask: Attention mask of shape [batch_size, seq_len]
|
||||||
|
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Log probabilities with reduction applied over sequence dimension
|
||||||
|
"""
|
||||||
|
allowed_reductions = ["mean", "sum", "none"]
|
||||||
|
if reduction not in allowed_reductions:
|
||||||
|
raise ValueError(
|
||||||
|
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
shifted_input_ids = input_ids[:, 1:]
|
||||||
|
shifted_mask = mask[:, 1:]
|
||||||
|
|
||||||
|
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
|
||||||
|
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||||
|
|
||||||
|
token_logprobs = torch.gather(
|
||||||
|
log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
|
||||||
|
).squeeze(-1)
|
||||||
|
|
||||||
|
if reduction == "mean":
|
||||||
|
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
|
||||||
|
dim=-1
|
||||||
|
).clamp(min=1.0)
|
||||||
|
elif reduction == "sum":
|
||||||
|
return (token_logprobs * shifted_mask).sum(dim=-1)
|
||||||
|
else:
|
||||||
|
return token_logprobs * shifted_mask
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStrategy(ABC):
|
||||||
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.device = device
|
||||||
|
self.executor = kwargs.pop("executor", None)
|
||||||
|
self.model_fn = kwargs.pop("model_fn", None)
|
||||||
|
self.extra_kwargs = kwargs
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
"""Compute loss for the given batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Dictionary containing batch tensors
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Computed loss tensor
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
"""Allow calling strategy directly as a callable."""
|
||||||
|
return self.compute_loss(batch)
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||||
|
"""Factory class for creating training strategy instances.
|
||||||
|
|
||||||
|
Supports decorator-based registration for extensible strategy types.
|
||||||
|
All default strategies (seq, sft, dpo, grpo) are automatically registered.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
@StrategyFactory.register("custom")
|
||||||
|
class CustomStrategy(BaseStrategy):
|
||||||
|
...
|
||||||
|
|
||||||
|
strategy = StrategyFactory.create("custom", model, device)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, strategy_cls: type):
|
||||||
|
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||||
|
if not issubclass(strategy_cls, BaseStrategy):
|
||||||
|
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
||||||
|
"""Create a strategy instance based on training type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||||
|
model: Model instance for the strategy
|
||||||
|
device: Device to run the strategy on
|
||||||
|
**kwargs: Additional arguments passed to strategy constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Strategy instance
|
||||||
|
"""
|
||||||
|
return super().create(train_type, model, device, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available_strategies(cls) -> list:
|
||||||
|
"""Return list of registered strategy names."""
|
||||||
|
return cls.list_registered()
|
||||||
|
|
||||||
|
|
||||||
|
# ============== Strategy Classes ==============
|
||||||
|
# All strategies are registered at class definition time using the decorator
|
||||||
|
|
||||||
|
|
||||||
|
@StrategyFactory.register("seq")
|
||||||
|
class SEQStrategy(BaseStrategy):
|
||||||
|
"""Standard next-token prediction training strategy.
|
||||||
|
|
||||||
|
Computes cross-entropy loss for next token prediction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||||
|
super().__init__(model, device, **kwargs)
|
||||||
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
batch = move_to_device(batch, self.device)
|
||||||
|
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||||
|
logits = self.model(input_ids=input_ids)["logits"]
|
||||||
|
|
||||||
|
loss = F.cross_entropy(
|
||||||
|
input=logits.flatten(0, 1).float(),
|
||||||
|
target=target_ids.flatten(),
|
||||||
|
label_smoothing=self.label_smoothing,
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
@StrategyFactory.register("sft")
|
||||||
|
class SFTStrategy(BaseStrategy):
|
||||||
|
"""Supervised Fine-tuning strategy with loss masking.
|
||||||
|
|
||||||
|
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||||
|
super().__init__(model, device, **kwargs)
|
||||||
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
batch = move_to_device(batch, self.device)
|
||||||
|
input_ids, target_ids, loss_mask = (
|
||||||
|
batch["input_ids"],
|
||||||
|
batch["target_ids"],
|
||||||
|
batch["loss_mask"],
|
||||||
|
)
|
||||||
|
|
||||||
|
ignore_index = -100
|
||||||
|
logits = self.model(input_ids=input_ids)["logits"]
|
||||||
|
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||||
|
|
||||||
|
loss = F.cross_entropy(
|
||||||
|
input=logits.flatten(0, 1).float(),
|
||||||
|
target=target_ids.flatten(),
|
||||||
|
ignore_index=ignore_index,
|
||||||
|
label_smoothing=self.label_smoothing,
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
@StrategyFactory.register("dpo")
|
||||||
|
class DPOStrategy(BaseStrategy):
|
||||||
|
"""Direct Preference Optimization strategy.
|
||||||
|
|
||||||
|
Implements the DPO loss from the paper "Direct Preference Optimization".
|
||||||
|
Uses a reference model to compute KL divergence penalty.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
device: str,
|
||||||
|
beta: float = 0.1,
|
||||||
|
reduction: str = "mean",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(model, device, **kwargs)
|
||||||
|
self.ref_model = create_ref_model(
|
||||||
|
self.model_fn, self.executor.unwrap_model(model)
|
||||||
|
).to(device=self.device)
|
||||||
|
self.beta = beta
|
||||||
|
self.reduction = reduction
|
||||||
|
|
||||||
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
batch = move_to_device(batch, self.device)
|
||||||
|
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
|
||||||
|
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||||
|
|
||||||
|
concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
||||||
|
concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
||||||
|
|
||||||
|
log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
log_ref = get_logprobs(
|
||||||
|
self.ref_model, concat_ids, concat_mask, self.reduction
|
||||||
|
)
|
||||||
|
|
||||||
|
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
|
||||||
|
log_pi_rejected = log_pi[chosen_ids.shape[0] :]
|
||||||
|
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
|
||||||
|
log_ref_rejected = log_ref[chosen_ids.shape[0] :]
|
||||||
|
|
||||||
|
pi_log_ratio = log_pi_chosen - log_pi_rejected
|
||||||
|
ref_log_ratio = log_ref_chosen - log_ref_rejected
|
||||||
|
|
||||||
|
ratio_diff = pi_log_ratio - ref_log_ratio
|
||||||
|
dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
|
||||||
|
|
||||||
|
return dpo_loss
|
||||||
|
|
||||||
|
|
||||||
|
@StrategyFactory.register("grpo")
|
||||||
|
class GRPOStrategy(BaseStrategy):
|
||||||
|
"""Group Relative Policy Optimization strategy.
|
||||||
|
|
||||||
|
On-policy GRPO following DeepSeek-R1: the policy model is updated while
|
||||||
|
a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref),
|
||||||
|
clipped PPO objective. Call ``sync_ref_model()`` after each data-generation round.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
device: str,
|
||||||
|
clip_eps: float = 0.2,
|
||||||
|
kl_coef: float = 0.01,
|
||||||
|
group_size: int = 4,
|
||||||
|
reduction: str = "mean",
|
||||||
|
sync_interval: int = 200,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(model, device, **kwargs)
|
||||||
|
self.ref_model = create_ref_model(
|
||||||
|
self.model_fn, self.executor.unwrap_model(model)
|
||||||
|
).to(device=self.device)
|
||||||
|
self.clip_eps = clip_eps
|
||||||
|
self.kl_coef = kl_coef
|
||||||
|
self.group_size = group_size
|
||||||
|
self.reduction = reduction
|
||||||
|
self.sync_interval = sync_interval
|
||||||
|
self._step = 0
|
||||||
|
|
||||||
|
def sync_ref_model(self):
|
||||||
|
"""Copy current model weights to ref model."""
|
||||||
|
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
|
||||||
|
|
||||||
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
self._step += 1
|
||||||
|
if self._step % self.sync_interval == 0:
|
||||||
|
self.sync_ref_model()
|
||||||
|
|
||||||
|
batch = move_to_device(batch, self.device)
|
||||||
|
prompts = batch["prompts"]
|
||||||
|
responses = batch["responses"]
|
||||||
|
masks = batch["masks"]
|
||||||
|
rewards = batch["rewards"]
|
||||||
|
|
||||||
|
batch_size, group_size, response_len = responses.shape
|
||||||
|
responses_flat = responses.view(-1, response_len)
|
||||||
|
masks_flat = masks.view(-1, response_len)
|
||||||
|
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
|
||||||
|
|
||||||
|
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
||||||
|
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
||||||
|
|
||||||
|
log_probs_policy = get_logprobs(
|
||||||
|
self.model, full_sequences, full_masks, self.reduction
|
||||||
|
)
|
||||||
|
log_probs_policy = log_probs_policy.view(batch_size, group_size)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
log_probs_ref = get_logprobs(
|
||||||
|
self.ref_model, full_sequences, full_masks, self.reduction
|
||||||
|
)
|
||||||
|
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
||||||
|
|
||||||
|
eps = torch.finfo(log_probs_policy.dtype).eps
|
||||||
|
mean = rewards.mean(dim=-1, keepdim=True)
|
||||||
|
std = rewards.std(dim=-1, keepdim=True)
|
||||||
|
advantages = (rewards - mean) / (std + eps)
|
||||||
|
|
||||||
|
ratio = torch.exp(log_probs_policy - log_probs_ref)
|
||||||
|
|
||||||
|
surr1 = ratio * advantages
|
||||||
|
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||||
|
|
||||||
|
policy_loss = -torch.min(surr1, surr2).mean()
|
||||||
|
kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
|
||||||
|
total_loss = policy_loss + kl_penalty
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|
@ -0,0 +1,333 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.parallel import only_on_rank
|
||||||
|
from astrai.parallel.setup import get_current_device, get_rank
|
||||||
|
from astrai.serialization import Checkpoint
|
||||||
|
from astrai.trainer.metric_util import (
|
||||||
|
ctx_get_grad_max,
|
||||||
|
ctx_get_grad_mean,
|
||||||
|
ctx_get_grad_min,
|
||||||
|
ctx_get_grad_nan_num,
|
||||||
|
ctx_get_grad_norm,
|
||||||
|
ctx_get_grad_std,
|
||||||
|
ctx_get_loss,
|
||||||
|
ctx_get_lr,
|
||||||
|
ctx_get_val_loss,
|
||||||
|
)
|
||||||
|
from astrai.trainer.train_context import TrainContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class TrainCallback(Protocol):
|
||||||
|
"""
|
||||||
|
Callback interface for trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_train_begin(self, context: TrainContext):
|
||||||
|
"""Called at the beginning of training."""
|
||||||
|
|
||||||
|
def on_train_end(self, context: TrainContext):
|
||||||
|
"""Called at the end of training."""
|
||||||
|
|
||||||
|
def on_epoch_begin(self, context: TrainContext):
|
||||||
|
"""Called at the beginning of each epoch."""
|
||||||
|
|
||||||
|
def on_epoch_end(self, context: TrainContext):
|
||||||
|
"""Called at the end of each epoch."""
|
||||||
|
|
||||||
|
def on_batch_begin(self, context: TrainContext):
|
||||||
|
"""Called at the beginning of each batch."""
|
||||||
|
|
||||||
|
def on_batch_end(self, context: TrainContext):
|
||||||
|
"""Called at the end of each batch."""
|
||||||
|
|
||||||
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
|
"""Called on every optimizer step (sync step only)."""
|
||||||
|
|
||||||
|
def on_error(self, context: TrainContext):
|
||||||
|
"""Called when an error occurs during training."""
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackFactory(BaseFactory[TrainCallback]):
|
||||||
|
"""Factory for registering and creating training callbacks.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@CallbackFactory.register("my_callback")
|
||||||
|
class MyCallback(TrainCallback):
|
||||||
|
...
|
||||||
|
|
||||||
|
callback = CallbackFactory.create("my_callback", **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("gradient_clipping")
|
||||||
|
class GradientClippingCallback(TrainCallback):
|
||||||
|
"""
|
||||||
|
Gradient clipping callback for trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_grad_norm: float):
|
||||||
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("gradient_checkpointing")
|
||||||
|
class GradientCheckpointingCallback(TrainCallback):
|
||||||
|
"""
|
||||||
|
Activation checkpointing callback — trades compute for memory
|
||||||
|
by recomputing specified module activations during the backward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modules: Module types to apply checkpointing to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, modules: Optional[List[type]] = None):
|
||||||
|
self.modules = tuple(modules) if modules else ()
|
||||||
|
|
||||||
|
def _enable(self, module: nn.Module):
|
||||||
|
if self.modules and isinstance(module, self.modules):
|
||||||
|
fn = module.forward
|
||||||
|
module._original_forward = fn
|
||||||
|
module.forward = lambda *a, **kw: torch_checkpoint(
|
||||||
|
fn, *a, use_reentrant=False, **kw
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _disable(module: nn.Module):
|
||||||
|
if hasattr(module, "_original_forward"):
|
||||||
|
module.forward = module._original_forward
|
||||||
|
del module._original_forward
|
||||||
|
|
||||||
|
def on_train_begin(self, context: TrainContext):
|
||||||
|
context.model.apply(self._enable)
|
||||||
|
logger.info("Gradient checkpointing enabled")
|
||||||
|
|
||||||
|
def on_train_end(self, context: TrainContext):
|
||||||
|
context.model.apply(self._disable)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("checkpoint")
|
||||||
|
class CheckpointCallback(TrainCallback):
|
||||||
|
"""
|
||||||
|
Checkpoint callback for trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
extra_keys = ("optimizer", "scheduler")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
save_dir: str,
|
||||||
|
interval: int,
|
||||||
|
weight_only: bool = False,
|
||||||
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
|
):
|
||||||
|
self.save_dir = save_dir
|
||||||
|
self.interval = interval
|
||||||
|
self.weight_only = weight_only
|
||||||
|
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||||
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
|
state_dict = context.executor.unwrap_model(context.model)
|
||||||
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
|
if get_rank() == 0:
|
||||||
|
save_path = os.path.join(
|
||||||
|
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||||
|
)
|
||||||
|
extra = self.save_extra_fn(context)
|
||||||
|
context.checkpoint = Checkpoint(
|
||||||
|
state_dict=state_dict,
|
||||||
|
epoch=context.epoch,
|
||||||
|
iteration=context.iteration,
|
||||||
|
extra=extra,
|
||||||
|
config=context.model_config,
|
||||||
|
)
|
||||||
|
context.checkpoint.save(save_path)
|
||||||
|
|
||||||
|
def on_batch_end(self, context: TrainContext):
|
||||||
|
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||||
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
|
def on_train_end(self, context: TrainContext):
|
||||||
|
if context.iteration != self.last_ckpt_iter:
|
||||||
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
|
def on_error(self, context: TrainContext):
|
||||||
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_extra(context: TrainContext) -> dict:
|
||||||
|
extra = {}
|
||||||
|
for name in CheckpointCallback.extra_keys:
|
||||||
|
obj = getattr(context, name, None)
|
||||||
|
if obj:
|
||||||
|
extra[name] = obj.state_dict()
|
||||||
|
return extra
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("progress_bar")
|
||||||
|
class ProgressBarCallback(TrainCallback):
|
||||||
|
"""
|
||||||
|
Progress bar callback for trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
||||||
|
):
|
||||||
|
self.num_epoch = num_epoch
|
||||||
|
self.log_interval = log_interval
|
||||||
|
self.file = file
|
||||||
|
self.progress_bar: tqdm = None
|
||||||
|
|
||||||
|
@only_on_rank(0)
|
||||||
|
def on_epoch_begin(self, context: TrainContext):
|
||||||
|
self.progress_bar = tqdm(
|
||||||
|
context.dataloader,
|
||||||
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
|
dynamic_ncols=True,
|
||||||
|
file=self.file or sys.stdout,
|
||||||
|
)
|
||||||
|
|
||||||
|
@only_on_rank(0)
|
||||||
|
def on_batch_end(self, context: TrainContext):
|
||||||
|
postfix = {
|
||||||
|
"loss": f"{context.loss:.4f}",
|
||||||
|
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||||
|
}
|
||||||
|
if context.val_loss > 0:
|
||||||
|
postfix["val_loss"] = f"{context.val_loss:.4f}"
|
||||||
|
self.progress_bar.set_postfix(postfix)
|
||||||
|
self.progress_bar.update(1)
|
||||||
|
|
||||||
|
@only_on_rank(0)
|
||||||
|
def on_epoch_end(self, context: TrainContext):
|
||||||
|
_ = context
|
||||||
|
if self.progress_bar:
|
||||||
|
self.progress_bar.close()
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("metric_logger")
|
||||||
|
class MetricLoggerCallback(TrainCallback):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
log_dir: str,
|
||||||
|
save_interval: int,
|
||||||
|
log_interval: int = 10,
|
||||||
|
metrics: List[str] = None,
|
||||||
|
):
|
||||||
|
self.last_log_iter = 0
|
||||||
|
self.save_interval = save_interval
|
||||||
|
self.log_interval = log_interval
|
||||||
|
self.metrics = metrics or ["loss", "lr"]
|
||||||
|
|
||||||
|
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
|
||||||
|
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.log_cache = []
|
||||||
|
|
||||||
|
self._metric_funcs = {
|
||||||
|
"loss": ctx_get_loss,
|
||||||
|
"lr": ctx_get_lr,
|
||||||
|
"val_loss": ctx_get_val_loss,
|
||||||
|
"grad_norm": ctx_get_grad_norm,
|
||||||
|
"grad_std": ctx_get_grad_std,
|
||||||
|
"grad_max": ctx_get_grad_max,
|
||||||
|
"grad_min": ctx_get_grad_min,
|
||||||
|
"grad_mean": ctx_get_grad_mean,
|
||||||
|
"grad_nan_num": ctx_get_grad_nan_num,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_log_data(self, context: TrainContext):
|
||||||
|
return {
|
||||||
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
|
"epoch": context.epoch,
|
||||||
|
"iter": context.iteration,
|
||||||
|
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
||||||
|
}
|
||||||
|
|
||||||
|
@only_on_rank(0)
|
||||||
|
def _add_log(self, log_data):
|
||||||
|
self.log_cache.append(log_data)
|
||||||
|
|
||||||
|
@only_on_rank(0)
|
||||||
|
def _save_log(self, epoch, iter):
|
||||||
|
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
|
||||||
|
|
||||||
|
with open(log_file, "w") as f:
|
||||||
|
for log in self.log_cache:
|
||||||
|
f.write(json.dumps(log) + "\n")
|
||||||
|
|
||||||
|
def on_batch_end(self, context):
|
||||||
|
if context.iteration % self.log_interval == 0:
|
||||||
|
log_data = self._get_log_data(context)
|
||||||
|
self._add_log(log_data)
|
||||||
|
|
||||||
|
if context.iteration - self.last_log_iter >= self.save_interval:
|
||||||
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
self.last_log_iter = context.iteration
|
||||||
|
|
||||||
|
def on_train_end(self, context):
|
||||||
|
if context.iteration != self.last_log_iter:
|
||||||
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
|
||||||
|
def on_error(self, context):
|
||||||
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("validation")
|
||||||
|
class ValidationCallback(TrainCallback):
|
||||||
|
def _run_validation(self, context: TrainContext):
|
||||||
|
context.model.eval()
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
num_batches = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in context.val_dataloader:
|
||||||
|
loss = context.strategy(batch)
|
||||||
|
total_loss += loss.item()
|
||||||
|
num_batches += 1
|
||||||
|
|
||||||
|
avg_loss = total_loss / max(num_batches, 1)
|
||||||
|
|
||||||
|
if context.world_size > 1 and dist.is_initialized():
|
||||||
|
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
|
||||||
|
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
|
||||||
|
avg_loss = loss_tensor.item()
|
||||||
|
|
||||||
|
context.val_loss = avg_loss
|
||||||
|
context.model.train()
|
||||||
|
|
||||||
|
step_count = context.iteration // context.config.grad_accum_steps
|
||||||
|
logger.info(
|
||||||
|
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
|
if context.val_dataloader is None:
|
||||||
|
return
|
||||||
|
cfg = context.config
|
||||||
|
if cfg.val_step <= 0:
|
||||||
|
return
|
||||||
|
step_count = context.iteration // cfg.grad_accum_steps
|
||||||
|
if step_count % cfg.val_step == 0:
|
||||||
|
self._run_validation(context)
|
||||||
|
|
@ -0,0 +1,170 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from astrai.config.train_config import TrainConfig
|
||||||
|
from astrai.dataset import ResumableDistributedSampler
|
||||||
|
from astrai.model.components.lora import inject_lora
|
||||||
|
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||||
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
|
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||||
|
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
||||||
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainContext:
|
||||||
|
model: nn.Module = field(default=None)
|
||||||
|
strategy: BaseStrategy = field(default=None)
|
||||||
|
dataloader: DataLoader = field(default=None)
|
||||||
|
optimizer: OptimizerProtocol = field(default=None)
|
||||||
|
scheduler: SchedulerProtocol = field(default=None)
|
||||||
|
checkpoint: Checkpoint = field(default=None)
|
||||||
|
config: TrainConfig = field(default=None)
|
||||||
|
model_config: dict = field(default_factory=dict)
|
||||||
|
executor: BaseExecutor = field(default=None)
|
||||||
|
|
||||||
|
epoch: int = field(default=0)
|
||||||
|
iteration: int = field(default=0)
|
||||||
|
loss: float = field(default=0.0)
|
||||||
|
val_dataloader: DataLoader = field(default=None)
|
||||||
|
val_loss: float = field(default=0.0)
|
||||||
|
|
||||||
|
world_size: int = field(default=1)
|
||||||
|
rank: int = field(default=0)
|
||||||
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainContextBuilder:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: TrainConfig,
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self._resume_dir: Optional[str] = None
|
||||||
|
|
||||||
|
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
|
||||||
|
self._resume_dir = resume_dir
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build(self) -> TrainContext:
|
||||||
|
cfg = self.config
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
executor = ExecutorFactory.create(
|
||||||
|
cfg.parallel_mode,
|
||||||
|
grad_accum_steps=cfg.grad_accum_steps,
|
||||||
|
**cfg.executor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = cfg.model_fn()
|
||||||
|
model = model.to(device=device)
|
||||||
|
|
||||||
|
model_config = {}
|
||||||
|
if self._resume_dir:
|
||||||
|
config_path = Path(self._resume_dir) / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
model_config = load_json(config_path)
|
||||||
|
|
||||||
|
if not model_config and hasattr(model, "config"):
|
||||||
|
model_config = model.config.to_dict()
|
||||||
|
|
||||||
|
context = TrainContext(
|
||||||
|
model=model,
|
||||||
|
world_size=get_world_size(),
|
||||||
|
rank=get_rank(),
|
||||||
|
config=cfg,
|
||||||
|
model_config=model_config,
|
||||||
|
executor=executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._resume_dir is not None:
|
||||||
|
resume_path = Path(self._resume_dir)
|
||||||
|
if (resume_path / "meta.json").exists():
|
||||||
|
checkpoint = Checkpoint.load(self._resume_dir)
|
||||||
|
state_dict = checkpoint.state_dict
|
||||||
|
if checkpoint.config:
|
||||||
|
context.model_config = checkpoint.config
|
||||||
|
else:
|
||||||
|
checkpoint = None
|
||||||
|
state_dict = load_model_weights(self._resume_dir)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
if checkpoint is not None:
|
||||||
|
context.epoch = cfg.start_epoch
|
||||||
|
context.iteration = cfg.start_batch
|
||||||
|
context.checkpoint = checkpoint
|
||||||
|
|
||||||
|
if cfg.lora is not None:
|
||||||
|
inject_lora(
|
||||||
|
model,
|
||||||
|
r=cfg.lora.r,
|
||||||
|
alpha=cfg.lora.alpha,
|
||||||
|
target_modules=set(cfg.lora.target_modules),
|
||||||
|
)
|
||||||
|
|
||||||
|
context.optimizer = cfg.optimizer_fn(model)
|
||||||
|
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
|
sampler_offset = context.iteration * cfg.batch_per_device
|
||||||
|
sampler = ResumableDistributedSampler(
|
||||||
|
data_source=cfg.dataset,
|
||||||
|
start_epoch=context.epoch,
|
||||||
|
start_iter=sampler_offset,
|
||||||
|
seed=cfg.random_seed,
|
||||||
|
)
|
||||||
|
context.dataloader = DataLoader(
|
||||||
|
cfg.dataset,
|
||||||
|
batch_size=cfg.batch_per_device,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
pin_memory=cfg.pin_memory,
|
||||||
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.val_dataset is not None:
|
||||||
|
val_sampler = ResumableDistributedSampler(
|
||||||
|
data_source=cfg.val_dataset,
|
||||||
|
start_epoch=0,
|
||||||
|
start_iter=0,
|
||||||
|
seed=cfg.random_seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
context.val_dataloader = DataLoader(
|
||||||
|
cfg.val_dataset,
|
||||||
|
batch_size=cfg.batch_per_device,
|
||||||
|
sampler=val_sampler,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
pin_memory=cfg.pin_memory,
|
||||||
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||||
|
executor.prepare(
|
||||||
|
model,
|
||||||
|
context.optimizer,
|
||||||
|
context.dataloader,
|
||||||
|
context.scheduler,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if context.checkpoint and context.checkpoint.extra:
|
||||||
|
extra = context.checkpoint.extra
|
||||||
|
for name in ("optimizer", "scheduler"):
|
||||||
|
if name in extra:
|
||||||
|
obj = getattr(context, name, None)
|
||||||
|
if obj is not None:
|
||||||
|
obj.load_state_dict(extra[name])
|
||||||
|
|
||||||
|
context.strategy = StrategyFactory.create(
|
||||||
|
model=context.model,
|
||||||
|
train_type=cfg.strategy,
|
||||||
|
device=device,
|
||||||
|
executor=executor,
|
||||||
|
model_fn=cfg.model_fn,
|
||||||
|
**cfg.extra_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from astrai.config import TrainConfig
|
||||||
|
from astrai.parallel.setup import spawn_parallel_fn
|
||||||
|
from astrai.trainer.train_callback import (
|
||||||
|
CallbackFactory,
|
||||||
|
TrainCallback,
|
||||||
|
)
|
||||||
|
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
def __init__(
|
||||||
|
self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
|
||||||
|
):
|
||||||
|
self.train_config = train_config
|
||||||
|
default_callbacks = self._get_default_callbacks()
|
||||||
|
self.callbacks = (
|
||||||
|
default_callbacks + callbacks if callbacks else default_callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
|
cfg = self.train_config
|
||||||
|
callbacks = [
|
||||||
|
CallbackFactory.create(
|
||||||
|
"gradient_checkpointing",
|
||||||
|
modules=cfg.gradient_checkpointing_modules,
|
||||||
|
),
|
||||||
|
CallbackFactory.create(
|
||||||
|
"checkpoint",
|
||||||
|
cfg.ckpt_dir,
|
||||||
|
cfg.ckpt_interval,
|
||||||
|
),
|
||||||
|
CallbackFactory.create(
|
||||||
|
"metric_logger",
|
||||||
|
log_dir=cfg.log_dir,
|
||||||
|
save_interval=cfg.ckpt_interval,
|
||||||
|
log_interval=cfg.log_interval,
|
||||||
|
metrics=cfg.metrics,
|
||||||
|
),
|
||||||
|
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||||
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
|
CallbackFactory.create("validation"),
|
||||||
|
]
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
|
for callback in self.callbacks:
|
||||||
|
method = getattr(callback, method_name, None)
|
||||||
|
if method:
|
||||||
|
method(context)
|
||||||
|
|
||||||
|
def _trainer_loop(self, resume_dir: Optional[str] = None):
|
||||||
|
context = (
|
||||||
|
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
|
||||||
|
)
|
||||||
|
executor = context.executor
|
||||||
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
||||||
|
try:
|
||||||
|
context.model.train()
|
||||||
|
|
||||||
|
for epoch in range(context.epoch, context.config.n_epoch):
|
||||||
|
context.epoch = epoch
|
||||||
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
|
for batch in context.dataloader:
|
||||||
|
self._call_callbacks("on_batch_begin", context)
|
||||||
|
|
||||||
|
with executor.accumulate(context.model):
|
||||||
|
loss = context.strategy(batch)
|
||||||
|
context.loss = loss.item()
|
||||||
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
|
executor.backward(stand_loss)
|
||||||
|
context.iteration += 1
|
||||||
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
|
if executor.sync_gradients:
|
||||||
|
self._call_callbacks("on_optimizer_step", context)
|
||||||
|
context.optimizer.step()
|
||||||
|
context.optimizer.zero_grad()
|
||||||
|
|
||||||
|
if context.scheduler:
|
||||||
|
context.scheduler.step()
|
||||||
|
|
||||||
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Training failed: %s", str(e), exc_info=True)
|
||||||
|
self._call_callbacks("on_error", context)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._call_callbacks("on_train_end", context)
|
||||||
|
|
||||||
|
def train(self, resume_dir: Optional[str] = None):
|
||||||
|
cfg = self.train_config
|
||||||
|
spawn_parallel_fn(
|
||||||
|
self._trainer_loop,
|
||||||
|
backend=cfg.backend,
|
||||||
|
world_size=cfg.nprocs,
|
||||||
|
master_addr=cfg.master_addr,
|
||||||
|
master_port=cfg.master_port,
|
||||||
|
device_type=cfg.device_type,
|
||||||
|
start_method=cfg.start_method,
|
||||||
|
resume_dir=resume_dir,
|
||||||
|
)
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
import os
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
snapshot_download(
|
|
||||||
repo_id="ViperEk/KHAOSZ",
|
|
||||||
local_dir=os.path.join(PROJECT_ROOT, "params"),
|
|
||||||
force_download=True
|
|
||||||
)
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def generate_text():
|
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
query = input(">> ")
|
|
||||||
|
|
||||||
response = model.text_generate(
|
|
||||||
query=query,
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50
|
|
||||||
)
|
|
||||||
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
generate_text()
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def batch_generate():
|
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
|
|
||||||
|
|
||||||
responses = model.batch_generate(
|
|
||||||
queries=inputs,
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50
|
|
||||||
)
|
|
||||||
|
|
||||||
for q, r in zip(inputs, responses):
|
|
||||||
print((q, r))
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
batch_generate()
|
|
||||||
|
|
@ -1,42 +0,0 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz, SemanticTextSplitter, Retriever
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
|
||||||
context_path = os.path.join(PROJECT_ROOT, "README.md")
|
|
||||||
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
spliter = SemanticTextSplitter(model.encode)
|
|
||||||
retriever = Retriever()
|
|
||||||
text = open(context_path, "r", encoding="utf-8").read()
|
|
||||||
|
|
||||||
res = spliter.split(text, threshold=0.8, window_size=1)
|
|
||||||
# print(("\n" + "+"*100 + "\n").join(res))
|
|
||||||
|
|
||||||
res_embs = model.encode(res)
|
|
||||||
for sentence, emb in zip(res, res_embs):
|
|
||||||
retriever.add_vector(sentence, emb)
|
|
||||||
|
|
||||||
retrive_top_k = 5
|
|
||||||
query = "作者设计了一个怎样的模型"
|
|
||||||
emb_query = model.encode(query)
|
|
||||||
retrieved = retriever.retrieve(emb_query, retrive_top_k)
|
|
||||||
|
|
||||||
retrive_response = model.retrieve_generate(
|
|
||||||
retrieved=retrieved,
|
|
||||||
query=query,
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50
|
|
||||||
)
|
|
||||||
|
|
||||||
print("retrieve content:")
|
|
||||||
print("\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]))
|
|
||||||
|
|
||||||
print("\n\nretrive generate:")
|
|
||||||
print(retrive_response)
|
|
||||||
|
|
@ -1,32 +0,0 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def chat():
|
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
history = []
|
|
||||||
while True:
|
|
||||||
query = input(">> ")
|
|
||||||
if query == "!exit":
|
|
||||||
break
|
|
||||||
|
|
||||||
response_size = 0
|
|
||||||
for response, history in model.stream_generate(
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50
|
|
||||||
):
|
|
||||||
print(response[response_size:], end="", flush=True)
|
|
||||||
response_size = len(response)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
chat()
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
services:
|
||||||
|
server:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
user: "${UID:-1000}:${GID:-1000}"
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- ./params:/app/params:ro
|
||||||
|
command: python -m scripts.tools.server --port 8000 --device cuda
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: 1
|
||||||
|
capabilities: [gpu]
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 60s
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
server-cpu:
|
||||||
|
profiles: [cpu]
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
user: "${UID:-1000}:${GID:-1000}"
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- ./params:/app/params:ro
|
||||||
|
command: python -m scripts.tools.server --port 8000 --device cpu
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 120s
|
||||||
|
restart: unless-stopped
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
__version__ = "1.3.2"
|
|
||||||
__author__ = "ViperEkura"
|
|
||||||
|
|
||||||
from khaosz.api import Khaosz
|
|
||||||
from khaosz.config import (
|
|
||||||
ModelConfig,
|
|
||||||
TrainConfig,
|
|
||||||
)
|
|
||||||
from khaosz.model.transformer import Transformer
|
|
||||||
from khaosz.utils.retriever import Retriever
|
|
||||||
from khaosz.utils.splitter import (
|
|
||||||
SemanticTextSplitter,
|
|
||||||
PriorityTextSplitter
|
|
||||||
)
|
|
||||||
from khaosz.data import (
|
|
||||||
DatasetLoader,
|
|
||||||
BpeTokenizer
|
|
||||||
)
|
|
||||||
from khaosz.inference.generator import (
|
|
||||||
TextGenerator,
|
|
||||||
ChatGenerator,
|
|
||||||
StreamGenerator,
|
|
||||||
BatchGenerator,
|
|
||||||
RetrievalGenerator,
|
|
||||||
EmbeddingEncoder
|
|
||||||
)
|
|
||||||
|
|
||||||
from khaosz.trainer import (
|
|
||||||
Trainer,
|
|
||||||
StrategyFactory,
|
|
||||||
SchedulerFactory
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Khaosz",
|
|
||||||
|
|
||||||
"Transformer",
|
|
||||||
|
|
||||||
"Retriever",
|
|
||||||
"SemanticTextSplitter",
|
|
||||||
"PriorityTextSplitter",
|
|
||||||
|
|
||||||
"ModelConfig",
|
|
||||||
"TrainConfig",
|
|
||||||
|
|
||||||
"DatasetLoader",
|
|
||||||
"BpeTokenizer",
|
|
||||||
|
|
||||||
"TextGenerator",
|
|
||||||
"ChatGenerator",
|
|
||||||
"StreamGenerator",
|
|
||||||
"BatchGenerator",
|
|
||||||
"RetrievalGenerator",
|
|
||||||
"EmbeddingEncoder",
|
|
||||||
|
|
||||||
"Trainer",
|
|
||||||
"StrategyFactory",
|
|
||||||
"SchedulerFactory"
|
|
||||||
]
|
|
||||||
113
khaosz/api.py
113
khaosz/api.py
|
|
@ -1,113 +0,0 @@
|
||||||
from torch import Tensor
|
|
||||||
from typing import List, Tuple, Generator, Union
|
|
||||||
|
|
||||||
from khaosz.inference.generator import (
|
|
||||||
TextGenerator,
|
|
||||||
ChatGenerator,
|
|
||||||
StreamGenerator,
|
|
||||||
BatchGenerator,
|
|
||||||
RetrievalGenerator,
|
|
||||||
EmbeddingEncoder
|
|
||||||
)
|
|
||||||
from khaosz.config.param_config import ModelParameter
|
|
||||||
|
|
||||||
|
|
||||||
class Khaosz:
|
|
||||||
def __init__(self, model_dir: str):
|
|
||||||
self.parameter = ModelParameter()
|
|
||||||
self.parameter.load(model_dir)
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
self.parameter.to(*args, **kwargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
history: List[Tuple[str, str]]=None,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> str:
|
|
||||||
generator = ChatGenerator(self.parameter)
|
|
||||||
return generator.generate(
|
|
||||||
query,
|
|
||||||
history=history,
|
|
||||||
temperature=temperature,
|
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
|
||||||
)
|
|
||||||
|
|
||||||
def batch_generate(
|
|
||||||
self,
|
|
||||||
queries: List[str],
|
|
||||||
histories: List[Tuple[str, str]]=None,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> List[str]:
|
|
||||||
generator = BatchGenerator(self.parameter)
|
|
||||||
return generator.generate(
|
|
||||||
queries,
|
|
||||||
histories=histories,
|
|
||||||
temperature=temperature,
|
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def stream_generate(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
history: List[Tuple[str, str]]=None,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
|
|
||||||
stream_generator = StreamGenerator(self.parameter)
|
|
||||||
return stream_generator.generate(
|
|
||||||
query,
|
|
||||||
history=history,
|
|
||||||
temperature=temperature,
|
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
|
||||||
)
|
|
||||||
|
|
||||||
def retrieve_generate(
|
|
||||||
self,
|
|
||||||
retrieved,
|
|
||||||
query: str,
|
|
||||||
history: List[Tuple[str, str]] = None,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> str:
|
|
||||||
generator = RetrievalGenerator(self.parameter)
|
|
||||||
return generator.generate(
|
|
||||||
retrieved,
|
|
||||||
query,
|
|
||||||
history=history,
|
|
||||||
temperature=temperature,
|
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
|
||||||
)
|
|
||||||
|
|
||||||
def text_generate(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> str:
|
|
||||||
generator = TextGenerator(self.parameter)
|
|
||||||
|
|
||||||
return generator.generate(
|
|
||||||
query,
|
|
||||||
temperature=temperature,
|
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
|
||||||
)
|
|
||||||
|
|
||||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
|
||||||
encoder = EmbeddingEncoder(self.parameter)
|
|
||||||
return encoder.encode(sentence)
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
from khaosz.config.model_config import ModelConfig
|
|
||||||
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
|
||||||
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
|
||||||
from khaosz.config.train_config import TrainConfig
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseModelIO",
|
|
||||||
"ModelParameter",
|
|
||||||
"ModelConfig",
|
|
||||||
"TrainConfig",
|
|
||||||
|
|
||||||
"ScheduleConfig",
|
|
||||||
"CosineScheduleConfig",
|
|
||||||
"SGDRScheduleConfig",
|
|
||||||
]
|
|
||||||
|
|
@ -1,43 +0,0 @@
|
||||||
import json
|
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from typing import Optional, Self
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelConfig:
|
|
||||||
# basic config
|
|
||||||
vocab_size: Optional[int] = None
|
|
||||||
dim: Optional[int] = None
|
|
||||||
|
|
||||||
n_layers: Optional[int] = None
|
|
||||||
norm_eps: Optional[float] = None
|
|
||||||
dim_ffn: Optional[int] = None
|
|
||||||
tie_weight: Optional[bool] = None
|
|
||||||
|
|
||||||
# RoPE
|
|
||||||
max_len: Optional[int] = None
|
|
||||||
rope_theta: Optional[float] = None
|
|
||||||
|
|
||||||
# GQA
|
|
||||||
n_heads: Optional[int] = None
|
|
||||||
n_kv_heads: Optional[int] = None
|
|
||||||
use_qk_norm: Optional[bool] = None
|
|
||||||
use_gated_attention: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, config_path: str) -> Self:
|
|
||||||
config = {}
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
config.update(json.load(f))
|
|
||||||
|
|
||||||
for key, value in config.items():
|
|
||||||
if hasattr(self, key):
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def save(self, config_path: str):
|
|
||||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
with open(config_path, 'w') as f:
|
|
||||||
json.dump(config_dict, f, indent=4)
|
|
||||||
|
|
@ -1,80 +0,0 @@
|
||||||
import torch.nn as nn
|
|
||||||
import safetensors.torch as st
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Optional, Self, Union
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
|
||||||
from khaosz.config.model_config import ModelConfig
|
|
||||||
from khaosz.model.transformer import Transformer
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseModelIO:
|
|
||||||
"""Base class for model I/O operations."""
|
|
||||||
|
|
||||||
model: Optional[nn.Module] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Transformer model."}
|
|
||||||
)
|
|
||||||
tokenizer: BpeTokenizer = field(
|
|
||||||
default_factory=BpeTokenizer,
|
|
||||||
metadata={"help": "Tokenizer for the model."}
|
|
||||||
)
|
|
||||||
config: ModelConfig = field(
|
|
||||||
default_factory=ModelConfig,
|
|
||||||
metadata={"help": "Transformer model configuration."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
|
||||||
"""Get standardized file paths for model components."""
|
|
||||||
dir_path = Path(directory)
|
|
||||||
return {
|
|
||||||
"model": dir_path / "model.safetensors",
|
|
||||||
"config": dir_path / "config.json",
|
|
||||||
"tokenizer": dir_path / "tokenizer.json"
|
|
||||||
}
|
|
||||||
|
|
||||||
def save_components(self, save_dir: Union[str, Path]):
|
|
||||||
"""Save core model components."""
|
|
||||||
paths = self._get_file_paths(save_dir)
|
|
||||||
paths["model"].parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
if self.model is not None:
|
|
||||||
st.save_file(self.model.state_dict(), str(paths["model"]))
|
|
||||||
self.config.save(str(paths["config"]))
|
|
||||||
self.tokenizer.save(str(paths["tokenizer"]))
|
|
||||||
|
|
||||||
def load_components(self, load_dir: Union[str, Path]) -> Self:
|
|
||||||
"""Load core model components."""
|
|
||||||
paths = self._get_file_paths(load_dir)
|
|
||||||
|
|
||||||
self.config.load(str(paths["config"]))
|
|
||||||
self.tokenizer.load(str(paths["tokenizer"]))
|
|
||||||
|
|
||||||
if self.model is None:
|
|
||||||
self.model = Transformer(self.config)
|
|
||||||
|
|
||||||
if paths["model"].exists():
|
|
||||||
state_dict = st.load_file(str(paths["model"]))
|
|
||||||
self.model.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs) -> "BaseModelIO":
|
|
||||||
"""Move model to device."""
|
|
||||||
if self.model is not None:
|
|
||||||
self.model.to(*args, **kwargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelParameter(BaseModelIO):
|
|
||||||
"""Container for model parameters with serialization capabilities."""
|
|
||||||
|
|
||||||
def save(self, save_dir: Union[str, Path]):
|
|
||||||
self.save_components(save_dir)
|
|
||||||
|
|
||||||
def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
|
|
||||||
return self.load_components(load_dir)
|
|
||||||
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
from typing import Any, Dict
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ScheduleConfig(ABC):
|
|
||||||
schedule_type: str = field(
|
|
||||||
default="cosine",
|
|
||||||
metadata={
|
|
||||||
"help": "Type of learning rate schedule.",
|
|
||||||
"choices": ["cosine", "sgdr"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
warmup_steps: int = field(
|
|
||||||
default=1000,
|
|
||||||
metadata={"help": "Number of warmup steps."}
|
|
||||||
)
|
|
||||||
min_rate: float = field(
|
|
||||||
default=0.05,
|
|
||||||
metadata={"help": "Minimum learning rate multiplier."}
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
|
||||||
"""Validate configuration parameters."""
|
|
||||||
if self.warmup_steps < 0:
|
|
||||||
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
|
|
||||||
if not 0 <= self.min_rate <= 1:
|
|
||||||
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CosineScheduleConfig(ScheduleConfig):
|
|
||||||
total_steps: int = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Total training steps for cosine schedule."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
self.schedule_type = "cosine"
|
|
||||||
self.validate()
|
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
|
||||||
if self.total_steps is None:
|
|
||||||
raise ValueError("total_steps must be specified for cosine schedule")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"schedule_type": self.schedule_type,
|
|
||||||
"warmup_steps": self.warmup_steps,
|
|
||||||
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
|
||||||
"min_rate": self.min_rate
|
|
||||||
}
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
|
||||||
super().validate()
|
|
||||||
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
|
||||||
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SGDRScheduleConfig(ScheduleConfig):
|
|
||||||
cycle_length: int = field(
|
|
||||||
default=1000,
|
|
||||||
metadata={"help": "Length of the first cycle in steps."}
|
|
||||||
)
|
|
||||||
t_mult: int = field(
|
|
||||||
default=2,
|
|
||||||
metadata={"help": "Multiplier for cycle length growth."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
self.schedule_type = "sgdr"
|
|
||||||
self.validate()
|
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"schedule_type": self.schedule_type,
|
|
||||||
"warmup_steps": self.warmup_steps,
|
|
||||||
"cycle_length": self.cycle_length,
|
|
||||||
"min_rate": self.min_rate,
|
|
||||||
"t_mult": self.t_mult
|
|
||||||
}
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
|
||||||
super().validate()
|
|
||||||
if self.cycle_length <= 0:
|
|
||||||
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
|
|
||||||
if self.t_mult < 1:
|
|
||||||
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
|
||||||
|
|
@ -1,136 +0,0 @@
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Callable, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainConfig:
|
|
||||||
# basic setting
|
|
||||||
model: nn.Module = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Model for training."}
|
|
||||||
)
|
|
||||||
strategy: str = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Training strategy."}
|
|
||||||
)
|
|
||||||
dataset: Dataset = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Dataset for training."}
|
|
||||||
)
|
|
||||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Optimizer factory for training."}
|
|
||||||
)
|
|
||||||
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Scheduler factory for training."}
|
|
||||||
)
|
|
||||||
n_epoch: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "Number of epochs for training."}
|
|
||||||
)
|
|
||||||
batch_size: int = field(
|
|
||||||
default=4,
|
|
||||||
metadata={"help": "Batch size for training."}
|
|
||||||
)
|
|
||||||
accumulation_steps: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "Number of iterations between steps."}
|
|
||||||
)
|
|
||||||
max_grad_norm: float = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={"help": "Maximum gradient norm."}
|
|
||||||
)
|
|
||||||
|
|
||||||
# checkpoint setting
|
|
||||||
start_epoch: int = field(
|
|
||||||
default=0,
|
|
||||||
metadata={"help": "Start epoch for training."}
|
|
||||||
)
|
|
||||||
start_batch: int = field(
|
|
||||||
default=0,
|
|
||||||
metadata={"help": "Start batch iteration for training."}
|
|
||||||
)
|
|
||||||
checkpoint_dir: str = field(
|
|
||||||
default="./checkpoint",
|
|
||||||
metadata={"help": "Checkpoint directory."}
|
|
||||||
)
|
|
||||||
checkpoint_interval: int = field(
|
|
||||||
default=5000,
|
|
||||||
metadata={"help": "Number of iterations between checkpoints."}
|
|
||||||
)
|
|
||||||
|
|
||||||
# dataloader setting
|
|
||||||
random_seed: int = field(
|
|
||||||
default=3407,
|
|
||||||
metadata={"help": "Random seed."}
|
|
||||||
)
|
|
||||||
num_workers: int = field(
|
|
||||||
default=0,
|
|
||||||
metadata={"help": "Number of workers for dataloader."}
|
|
||||||
)
|
|
||||||
prefetch_factor: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Prefetch factor for dataloader."}
|
|
||||||
)
|
|
||||||
pin_memory: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Pin memory for dataloader."}
|
|
||||||
)
|
|
||||||
|
|
||||||
# distributed training
|
|
||||||
nprocs: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "Number of processes for distributed training."}
|
|
||||||
)
|
|
||||||
backend: str = field(
|
|
||||||
default="nccl",
|
|
||||||
metadata={"help": "Distributed training backend."}
|
|
||||||
)
|
|
||||||
master_addr: str = field(
|
|
||||||
default="localhost",
|
|
||||||
metadata={"help": "Master address for distributed training."}
|
|
||||||
)
|
|
||||||
master_port: str = field(
|
|
||||||
default="29500",
|
|
||||||
metadata={"help": "Master port for distributed training."}
|
|
||||||
)
|
|
||||||
parallel_wrapper: Optional[Callable] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Parallel function for training."}
|
|
||||||
)
|
|
||||||
state_dict_fn: Optional[Callable] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Parallel function for state dict saving."}
|
|
||||||
)
|
|
||||||
|
|
||||||
# others
|
|
||||||
device_ids: Optional[List[int]] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Device ids for distributed training."}
|
|
||||||
)
|
|
||||||
device_type: str = field(
|
|
||||||
default="cuda",
|
|
||||||
metadata={"help": "Device type for distributed training."}
|
|
||||||
)
|
|
||||||
extra_kwargs: dict = field(
|
|
||||||
default_factory=dict,
|
|
||||||
metadata={"help": "Other arguments."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self.validate()
|
|
||||||
|
|
||||||
def validate(self):
|
|
||||||
required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
|
|
||||||
|
|
||||||
for field_name in required_fields:
|
|
||||||
if getattr(self, field_name) is None:
|
|
||||||
raise ValueError(f"{field_name} is required.")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,24 +0,0 @@
|
||||||
from khaosz.data.dataset import (
|
|
||||||
BaseDataset,
|
|
||||||
SeqDataset,
|
|
||||||
DpoDataset,
|
|
||||||
SftDataset,
|
|
||||||
PpoDataset,
|
|
||||||
MultiSegmentFetcher,
|
|
||||||
DatasetLoader
|
|
||||||
)
|
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
|
||||||
from khaosz.data.sampler import ResumableDistributedSampler
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseDataset",
|
|
||||||
"SeqDataset",
|
|
||||||
"DpoDataset",
|
|
||||||
"SftDataset",
|
|
||||||
"PpoDataset",
|
|
||||||
"MultiSegmentFetcher",
|
|
||||||
"DatasetLoader",
|
|
||||||
"BpeTokenizer",
|
|
||||||
"ResumableDistributedSampler"
|
|
||||||
]
|
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Any
|
|
||||||
from khaosz.parallel.setup import get_rank
|
|
||||||
|
|
||||||
|
|
||||||
class Checkpoint:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict: Dict[str, Any],
|
|
||||||
epoch: int = 0,
|
|
||||||
iteration: int = 0,
|
|
||||||
):
|
|
||||||
self.state_dict = state_dict
|
|
||||||
self.epoch = epoch
|
|
||||||
self.iteration = iteration
|
|
||||||
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
save_dir: str,
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
save_path = Path(save_dir)
|
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
rank = get_rank()
|
|
||||||
if rank == 0:
|
|
||||||
meta = {
|
|
||||||
"epoch": self.epoch,
|
|
||||||
"iteration": self.iteration,
|
|
||||||
}
|
|
||||||
with open(save_path / "meta.json", "w") as f:
|
|
||||||
json.dump(meta, f, indent=2)
|
|
||||||
|
|
||||||
with open(save_path / f"state_dict.pt", "wb") as f:
|
|
||||||
torch.save(self.state_dict, f)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls,
|
|
||||||
save_dir: str,
|
|
||||||
) -> "Checkpoint":
|
|
||||||
|
|
||||||
rank = get_rank()
|
|
||||||
save_path = Path(save_dir)
|
|
||||||
|
|
||||||
meta = {}
|
|
||||||
if rank == 0:
|
|
||||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
|
|
||||||
if dist.is_initialized():
|
|
||||||
meta_list = [meta]
|
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
|
||||||
meta = meta_list[0]
|
|
||||||
|
|
||||||
with open(save_path / f"state_dict.pt", "rb") as f:
|
|
||||||
state_dict = torch.load(f)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
state_dict=state_dict,
|
|
||||||
epoch=meta["epoch"],
|
|
||||||
iteration=meta["iteration"],
|
|
||||||
)
|
|
||||||
|
|
@ -1,201 +0,0 @@
|
||||||
import torch
|
|
||||||
import bisect
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from khaosz.data.file import load_h5
|
|
||||||
from typing import Callable, List, Dict, Literal, Optional, Union
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
|
||||||
def __init__(self, segments: List[Tensor]):
|
|
||||||
self.segments = segments
|
|
||||||
self.cum_lengths = []
|
|
||||||
|
|
||||||
total = 0
|
|
||||||
for seg in segments:
|
|
||||||
total += torch.numel(seg)
|
|
||||||
self.cum_lengths.append(total)
|
|
||||||
|
|
||||||
self.total_length = total
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self.total_length
|
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
|
||||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
|
||||||
raise ValueError("begin_idx or end_idx out of bounds")
|
|
||||||
if begin_idx >= end_idx:
|
|
||||||
return torch.tensor([], dtype=torch.long)
|
|
||||||
|
|
||||||
# fix the range index bug
|
|
||||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
|
||||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
|
||||||
|
|
||||||
result_segments = []
|
|
||||||
|
|
||||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
|
||||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
|
||||||
start = max(begin_idx - prev_cum, 0)
|
|
||||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
|
||||||
data = self.segments[i][start:end]
|
|
||||||
result_segments.append(data)
|
|
||||||
|
|
||||||
return torch.cat(result_segments, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiSegmentFetcher:
|
|
||||||
def __init__(self, muti_segments: Dict):
|
|
||||||
self.muti_keys = list(muti_segments.keys())
|
|
||||||
self.muti_fetchers = {
|
|
||||||
key: BaseSegmentFetcher(segments)
|
|
||||||
for key, segments in muti_segments.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
|
||||||
return min(len_list)
|
|
||||||
|
|
||||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
|
||||||
fetch_dict = {}
|
|
||||||
keys = [keys] if isinstance(keys, str) else keys
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
fetcher = self.muti_fetchers[key]
|
|
||||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
|
||||||
fetch_dict[key] = fetch_tensor
|
|
||||||
|
|
||||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
|
||||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__()
|
|
||||||
self.segments = {}
|
|
||||||
self.window_size = window_size
|
|
||||||
self.stride = stride
|
|
||||||
self.total_samples = None
|
|
||||||
|
|
||||||
def load(self, load_path: str):
|
|
||||||
self.segments = load_h5(load_path)
|
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
self.total_samples = len(self.fetcher)
|
|
||||||
|
|
||||||
def get_index(self, index: int) -> int:
|
|
||||||
assert self.total_samples > self.window_size
|
|
||||||
|
|
||||||
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
|
||||||
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
|
|
||||||
|
|
||||||
return begin_idx, end_idx
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
assert self.total_samples is not None
|
|
||||||
if self.total_samples <= self.window_size:
|
|
||||||
return 0
|
|
||||||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
|
||||||
|
|
||||||
|
|
||||||
class SeqDataset(BaseDataset):
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__(window_size, stride)
|
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
# fix the range index bug
|
|
||||||
begin_idx, end_idx = self.get_index(index)
|
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y}
|
|
||||||
|
|
||||||
|
|
||||||
class SftDataset(BaseDataset):
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__(window_size, stride)
|
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
begin_idx, end_idx = self.get_index(index)
|
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
|
||||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
|
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
|
||||||
|
|
||||||
|
|
||||||
class DpoDataset(BaseDataset):
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__(window_size, stride)
|
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
|
||||||
begin_idx, end_idx = self.get_index(index)
|
|
||||||
|
|
||||||
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
|
||||||
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
|
||||||
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
|
|
||||||
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
|
|
||||||
|
|
||||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
|
||||||
|
|
||||||
|
|
||||||
class PpoDataset(BaseDataset):
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__(window_size, stride)
|
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
|
||||||
begin_idx, end_idx = self.get_index(index)
|
|
||||||
|
|
||||||
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
|
|
||||||
actions = self._fetch_data(begin_idx, end_idx, "actions"),
|
|
||||||
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs"),
|
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
|
||||||
|
|
||||||
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetLoader:
|
|
||||||
@staticmethod
|
|
||||||
def load(
|
|
||||||
train_type: Literal["seq", "sft", "dpo"],
|
|
||||||
load_path: str,
|
|
||||||
window_size: int,
|
|
||||||
stride: Optional[int] = None,
|
|
||||||
) -> BaseDataset:
|
|
||||||
if stride is None:
|
|
||||||
stride = window_size
|
|
||||||
|
|
||||||
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
|
||||||
"seq": lambda window_size: SeqDataset(window_size, stride),
|
|
||||||
"sft": lambda window_size: SftDataset(window_size, stride),
|
|
||||||
"dpo": lambda window_size: DpoDataset(window_size, stride),
|
|
||||||
}
|
|
||||||
dataset = dataset_router[train_type](window_size)
|
|
||||||
dataset.load(load_path)
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue