Compare commits
350 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
01ce1fb9e3 | |
|
|
14f83cbdac | |
|
|
dbe5891201 | |
|
|
2a65c3314c | |
|
|
1c2ff05a6d | |
|
|
31ae2deeba | |
|
|
69207e2c57 | |
|
|
138c5bcc08 | |
|
|
a923e0a23a | |
|
|
f521a30b22 | |
|
|
d4451f6afb | |
|
|
a3275423a4 | |
|
|
b37c3d000c | |
|
|
6031020e37 | |
|
|
c424dfc293 | |
|
|
3a28e52e98 | |
|
|
e371908b54 | |
|
|
7c99da155c | |
|
|
629e72385b | |
|
|
0a708fff24 | |
|
|
6e150ea6d0 | |
|
|
cb8dcb97ea | |
|
|
2d5dc93b3d | |
|
|
4145d35e3c | |
|
|
34c6c45bd6 | |
|
|
e9def84ce7 | |
|
|
836e02a166 | |
|
|
b558e61f63 | |
|
|
65ab69543b | |
|
|
1d26aa2e93 | |
|
|
a548d4553e | |
|
|
dd1b39f435 | |
|
|
94d6e713e9 | |
|
|
47c37e4876 | |
|
|
737585a32a | |
|
|
a4688021bf | |
|
|
7df6eb9211 | |
|
|
82a3f2626f | |
|
|
7fa69572c0 | |
|
|
3ab4f237e5 | |
|
|
8cbf3f36e2 | |
|
|
0594ce1017 | |
|
|
ff509ff39f | |
|
|
785d65436c | |
|
|
64be81b7b3 | |
|
|
45479b5731 | |
|
|
e0a3337c22 | |
|
|
812238060b | |
|
|
14b0d56197 | |
|
|
6c8533f1d2 | |
|
|
2c2697390d | |
|
|
7621f05d3f | |
|
|
10ebd7211f | |
|
|
42a391f0fb | |
|
|
97c7ac0f4f | |
|
|
8f1b32f2b6 | |
|
|
c241a5dcef | |
|
|
44dab27fdc | |
|
|
a44fd22a99 | |
|
|
8a11a7d444 | |
|
|
1d54491809 | |
|
|
ad9f4d9cf6 | |
|
|
e1638a7ade | |
|
|
f91bfee33e | |
|
|
d7a7f570ed | |
|
|
7dea929788 | |
|
|
026d1fc33d | |
|
|
7242eedbf4 | |
|
|
04c0dc7a47 | |
|
|
48a53121ba | |
|
|
0ba8c70ce1 | |
|
|
3d12a03909 | |
|
|
c169659611 | |
|
|
e12f1a7ee5 | |
|
|
ef25efffa2 | |
|
|
19532440b4 | |
|
|
9096e413c3 | |
|
|
9d5e9fa6c4 | |
|
|
08dde46778 | |
|
|
513f1f7826 | |
|
|
e3382f6bb5 | |
|
|
f0339022c1 | |
|
|
d8da2cf17c | |
|
|
205b40bd28 | |
|
|
18fe6e9339 | |
|
|
2196c34c52 | |
|
|
466c2e1efd | |
|
|
7e26d848ab | |
|
|
ed95ef245c | |
|
|
6d6ef99e66 | |
|
|
a8e2a1ba45 | |
|
|
6269bacfc3 | |
|
|
c0effc9f5b | |
|
|
df0845e916 | |
|
|
7440e9c809 | |
|
|
7d4029c2a4 | |
|
|
0ca6c9e6eb | |
|
|
6e49d27057 | |
|
|
5203b7f53e | |
|
|
5889179c54 | |
|
|
38e18fdfd3 | |
|
|
4753958f92 | |
|
|
73d6cc0f26 | |
|
|
317ed90bac | |
|
|
951df8155c | |
|
|
a58fab8d6e | |
|
|
a3c8296135 | |
|
|
c95ace41aa | |
|
|
3da428e0e4 | |
|
|
133a9de98f | |
|
|
523eacf5fe | |
|
|
cffedaad5e | |
|
|
3583c46b66 | |
|
|
ca4e6b907c | |
|
|
db99d8b254 | |
|
|
b98c9cefdc | |
|
|
283bcaf2ff | |
|
|
bc7c82977e | |
|
|
34a511e36e | |
|
|
d73f52a2f8 | |
|
|
9d96b0431d | |
|
|
f81e2b4a73 | |
|
|
4e324d8f26 | |
|
|
6ed0506491 | |
|
|
30cc2d67a4 | |
|
|
7ddebf2cd9 | |
|
|
78dc2bd41c | |
|
|
44d7a4e959 | |
|
|
c4401512f2 | |
|
|
a6f5ff3b37 | |
|
|
ffff05b2c6 | |
|
|
b89f8436ea | |
|
|
123f25e339 | |
|
|
520de3ebe8 | |
|
|
466c34d7a8 | |
|
|
6831a15424 | |
|
|
0f9e5c5049 | |
|
|
cb0e7f2a80 | |
|
|
296db909aa | |
|
|
a2ae742988 | |
|
|
29beb174a5 | |
|
|
bbeaff4c60 | |
|
|
ab5e207f42 | |
|
|
b0eff02446 | |
|
|
408f0cb513 | |
|
|
64b78ecce3 | |
|
|
f2ffdf60d0 | |
|
|
ace8f6ee68 | |
|
|
a57a16430d | |
|
|
3fee87897d | |
|
|
3f67e53088 | |
|
|
bf7adb35b3 | |
|
|
feaa3fca36 | |
|
|
39766aa1dc | |
|
|
9b22b1651e | |
|
|
e58dbd7c57 | |
|
|
d2fe8afbd1 | |
|
|
23ce4bc3ae | |
|
|
d2b36cc85d | |
|
|
fc278d17ab | |
|
|
ff43a2fab8 | |
|
|
2b26f03bd3 | |
|
|
861d33b1a1 | |
|
|
99b821ebf5 | |
|
|
c94a246c71 | |
|
|
2dc9545d7f | |
|
|
9c31d78a22 | |
|
|
bd9741dc5f | |
|
|
b531232a9b | |
|
|
3346c75584 | |
|
|
aa5e03d7f6 | |
|
|
073baf105c | |
|
|
e97536758f | |
|
|
7861af12e4 | |
|
|
7f0552013a | |
|
|
3535de5cc4 | |
|
|
26989e54aa | |
|
|
70d52935f0 | |
|
|
c0e0e6afd9 | |
|
|
0852b852f8 | |
|
|
3a7d98a950 | |
|
|
c5560740b6 | |
|
|
94c6a015c8 | |
|
|
8b6509b305 | |
|
|
912d7c7f54 | |
|
|
475de51c7d | |
|
|
9f1561afe7 | |
|
|
80c0b20877 | |
|
|
e7721eafc6 | |
|
|
4ead0a20cf | |
|
|
b1527d9575 | |
|
|
2e009cf59a | |
|
|
780b9e1855 | |
|
|
aef7615abd | |
|
|
50488bd659 | |
|
|
eb57e55fca | |
|
|
426af2d75f | |
|
|
345fd2f091 | |
|
|
e1f9901384 | |
|
|
0e7fc623b4 | |
|
|
3e33c14376 | |
|
|
60f4df95bd | |
|
|
c01791ff54 | |
|
|
980299cd54 | |
|
|
3e8f2eba81 | |
|
|
361cdeb296 | |
|
|
50f76cd7c7 | |
|
|
0f518473af | |
|
|
a5574f92e2 | |
|
|
abcedf892e | |
|
|
abc3a06266 | |
|
|
62fba9a298 | |
|
|
e23a5ca426 | |
|
|
e55b57d771 | |
|
|
c4feab96fe | |
|
|
e35cb0d84a | |
|
|
6d6ef6dbb6 | |
|
|
493fe4e84b | |
|
|
82d22c5742 | |
|
|
96744ac2d2 | |
|
|
2331713fde | |
|
|
c74fbf84b7 | |
|
|
5a8c442315 | |
|
|
c7d0448822 | |
|
|
1d43a1785e | |
|
|
5713b55500 | |
|
|
b53e10aac4 | |
|
|
dff58468d6 | |
|
|
8a8d6369bc | |
|
|
80e17418b4 | |
|
|
6089a12cef | |
|
|
b17cc6a6fb | |
|
|
a33d086883 | |
|
|
e9f42ec8b1 | |
|
|
582d4ae9a7 | |
|
|
0ca4871e80 | |
|
|
99ef8fda71 | |
|
|
dbd57e30e5 | |
|
|
a5869d89ba | |
|
|
7a9b9d0659 | |
|
|
75758ead46 | |
|
|
7dfa5cc0ac | |
|
|
9dab96c31f | |
|
|
ff5c8a71f5 | |
|
|
4da70785b5 | |
|
|
d407962ffa | |
|
|
3d8047fa1b | |
|
|
d21682f97a | |
|
|
eba99e1f5e | |
|
|
fd7ee2895a | |
|
|
cfa3cf7daa | |
|
|
7623b1e5fd | |
|
|
573f041c51 | |
|
|
eab7a51bb6 | |
|
|
3ac38a7ebc | |
|
|
831933fb66 | |
|
|
701fb9bf78 | |
|
|
d882f65579 | |
|
|
a30ddca517 | |
|
|
8e975017d3 | |
|
|
fed4d64cea | |
|
|
110efd2a21 | |
|
|
530fb50352 | |
|
|
c86e573195 | |
|
|
0093ba7bb8 | |
|
|
c934210066 | |
|
|
c98b175cd5 | |
|
|
82e65ccc21 | |
|
|
d52685facd | |
|
|
d31137a2db | |
|
|
6270415590 | |
|
|
08c5a52dc8 | |
|
|
ac1fefb363 | |
|
|
8b20982933 | |
|
|
d5cc9f065d | |
|
|
db53cc5001 | |
|
|
3ee84b31a0 | |
|
|
567c55685e | |
|
|
1f5cba889b | |
|
|
019bfe4e05 | |
|
|
36b410384b | |
|
|
09963a3beb | |
|
|
5daf63a7a4 | |
|
|
fb85aaf6a6 | |
|
|
6fb6a15e81 | |
|
|
d9ff662e3a | |
|
|
e12ed0a72b | |
|
|
3bf2468905 | |
|
|
3c7ed84516 | |
|
|
1c3a693d79 | |
|
|
e99ef9d6d8 | |
|
|
4c289e974a | |
|
|
f31bf5a959 | |
|
|
7a21f5d72e | |
|
|
0b45e8666e | |
|
|
6f3386f02c | |
|
|
d25202a329 | |
|
|
254ec934be | |
|
|
7e5ecf3b7d | |
|
|
66a551217e | |
|
|
bdc3f4dc63 | |
|
|
805773c7fe | |
|
|
7ccc4ab9ac | |
|
|
69d9374f51 | |
|
|
b260f5581d | |
|
|
0a754e3341 | |
|
|
144b9598ad | |
|
|
877669b799 | |
|
|
cdb47a62dc | |
|
|
e86328b753 | |
|
|
5d3799b715 | |
|
|
6a3135f401 | |
|
|
12850d403c | |
|
|
bad6243b53 | |
|
|
f2448a5147 | |
|
|
46b2a0f86f | |
|
|
d94fc5a87a | |
|
|
38b2725cd1 | |
|
|
bc5ef72001 | |
|
|
e051005334 | |
|
|
0db046f8d9 | |
|
|
05b012820b | |
|
|
e72e244df6 | |
|
|
98efca7b9d | |
|
|
613edd7a14 | |
|
|
622982364b | |
|
|
b67bc9865d | |
|
|
c51b203fde | |
|
|
8434c19923 | |
|
|
68a15005cb | |
|
|
efbe3de9d3 | |
|
|
12793bc2d3 | |
|
|
0764cb8296 | |
|
|
57cd7b921e | |
|
|
c1bf22b6ec | |
|
|
f9b6331ad7 | |
|
|
183f481692 | |
|
|
ec0c054d26 | |
|
|
4ffa7454f2 | |
|
|
8c9e973179 | |
|
|
fc98d9b7e6 | |
|
|
9d5aa952e0 | |
|
|
2ccd7bd583 | |
|
|
e7d29ca2d5 | |
|
|
465a1a9373 | |
|
|
240ee00221 | |
|
|
6e1a497c04 | |
|
|
85aeec9e55 | |
|
|
9a452dd34e | |
|
|
28b01220b6 |
|
|
@ -0,0 +1,9 @@
|
|||
# Ignore everything
|
||||
*
|
||||
|
||||
# Allow necessary files
|
||||
!astrai/
|
||||
!scripts/
|
||||
!assets/
|
||||
!pyproject.toml
|
||||
!README.md
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# Auto detect text files
|
||||
* text=auto
|
||||
|
||||
# Files that MUST use LF (Unix/Linux execution)
|
||||
*.sh text eol=lf
|
||||
*.py text eol=lf
|
||||
*.md text eol=lf
|
||||
*.yml text eol=lf
|
||||
|
||||
Dockerfile text eol=lf
|
||||
.dockerignore text eol=lf
|
||||
|
||||
.gitignore text eol=lf
|
||||
.gitattributes text eol=lf
|
||||
|
||||
# Windows scripts - use CRLF
|
||||
*.bat text eol=crlf
|
||||
*.cmd text eol=crlf
|
||||
*.ps1 text eol=crlf
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: "[BUG]"
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Description
|
||||
A clear and concise description of what the bug is.
|
||||
## Steps to Reproduce
|
||||
1. ...
|
||||
2. ...
|
||||
3. ...
|
||||
## Expected Behavior
|
||||
What you expected to happen.
|
||||
## Actual Behavior
|
||||
What actually happened.
|
||||
## Environment
|
||||
- Python version:
|
||||
- AstrAI version (or commit hash):
|
||||
- Operating System:
|
||||
- GPU (if applicable):
|
||||
- CUDA/cuDNN version (if applicable):
|
||||
## Additional Context
|
||||
Add any other context, screenshots, or logs here.
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
---
|
||||
name: Custom issue template
|
||||
about: Describe this issue template's purpose here.
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: "[FEAT]"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Description
|
||||
A clear and concise description of the feature you'd like to see.
|
||||
## Problem Statement
|
||||
What problem does this feature solve? Why is it needed?
|
||||
## Proposed Solution
|
||||
Describe the solution you'd like. Include any design ideas, API changes, or implementation details.
|
||||
## Alternatives Considered
|
||||
Describe any alternative solutions or features you've considered.
|
||||
## Additional Context
|
||||
Add any other context, screenshots, or references here.
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
## Description
|
||||
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context.
|
||||
|
||||
Fixes # (issue number)
|
||||
|
||||
## Type of Change
|
||||
Please delete options that are not relevant.
|
||||
|
||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] New feature (non-breaking change which adds functionality)
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
- [ ] Documentation update
|
||||
- [ ] Other (please describe):
|
||||
|
||||
## How Has This Been Tested?
|
||||
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
|
||||
|
||||
## Checklist:
|
||||
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
|
||||
- [ ] I have performed a self-review of my own code
|
||||
- [ ] Code is self-documenting (no unnecessary comments)
|
||||
- [ ] I have made corresponding changes to the documentation
|
||||
- [ ] My changes generate no new warnings
|
||||
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||
- [ ] New and existing unit tests pass locally with my changes
|
||||
- [ ] Any dependent changes have been merged and published in downstream modules
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
name: Build and Push Docker Image
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/${{ github.repository }}
|
||||
tags: |
|
||||
type=ref,event=tag
|
||||
type=raw,value=latest
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install .[dev]
|
||||
|
||||
- name: Check formatting with ruff
|
||||
run: |
|
||||
ruff format --check .
|
||||
|
||||
- name: Check import sorting
|
||||
run: |
|
||||
ruff check . --select I
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
name: Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install .[dev]
|
||||
|
||||
- name: Run tests with pytest
|
||||
run: |
|
||||
python -m pytest tests/ -v
|
||||
|
|
@ -1,13 +1,23 @@
|
|||
# cache
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
# Ignore everything
|
||||
*
|
||||
|
||||
# params
|
||||
params/*
|
||||
# Allow directories to be traversed
|
||||
!*/
|
||||
|
||||
# vscode file
|
||||
.vscode
|
||||
# Allow specific file types and root files
|
||||
!*.py
|
||||
!*.sh
|
||||
|
||||
# build file
|
||||
build
|
||||
*.egg-info
|
||||
# Allow GitHub files
|
||||
!/.github/**
|
||||
|
||||
# Allow root files
|
||||
!/.gitattributes
|
||||
!/.dockerignore
|
||||
!/Dockerfile
|
||||
!/docker-compose.yml
|
||||
!/assets/**
|
||||
!/CONTRIBUTING.md
|
||||
!/LICENSE
|
||||
!/pyproject.toml
|
||||
!/README.md
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
# Contributing to AstrAI
|
||||
|
||||
Thank you for your interest in contributing! This document provides step-by-step guidelines.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
git clone https://github.com/your-username/AstrAI.git
|
||||
cd AstrAI
|
||||
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
|
||||
```
|
||||
|
||||
## Before You Commit
|
||||
|
||||
Run the following checks **in order** — CI will reject if any fail.
|
||||
|
||||
### 1. Format
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
```
|
||||
|
||||
> **Note**: `ruff format` may rename parameters (e.g. `mask` → `attn_mask`).
|
||||
> Always review the diff after formatting.
|
||||
|
||||
### 2. Import sorting
|
||||
|
||||
```bash
|
||||
ruff check . --select I
|
||||
```
|
||||
|
||||
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
|
||||
|
||||
```bash
|
||||
ruff check . --select I --fix .
|
||||
ruff format . # re-format after fix
|
||||
```
|
||||
|
||||
### 3. Run tests
|
||||
|
||||
```bash
|
||||
python -u -m pytest tests/ -v
|
||||
```
|
||||
|
||||
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
|
||||
|
||||
### 4. (Optional) Full pre-commit check
|
||||
|
||||
If you have Git Bash available:
|
||||
|
||||
```bash
|
||||
bash scripts/pre_commit.sh
|
||||
```
|
||||
|
||||
This runs format check, import sort check, and tests in one go.
|
||||
|
||||
## Commit Style
|
||||
|
||||
```
|
||||
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
|
||||
|
||||
- bullet point body (each ~60 chars)
|
||||
```
|
||||
|
||||
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
|
||||
- **Subject line** ends with no period.
|
||||
- **Body** uses bullet points starting with `-`.
|
||||
- No `(scope)` parentheses.
|
||||
|
||||
## Common Issues
|
||||
|
||||
| Problem | Cause | Fix |
|
||||
|---------|-------|-----|
|
||||
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
|
||||
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
|
||||
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
|
||||
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
|
||||
|
||||
## Submitting Changes
|
||||
|
||||
1. Fork the repo.
|
||||
2. Create a feature branch: `git checkout -b feat/my-feature`
|
||||
3. Make changes following the steps above.
|
||||
4. Commit with the commit style above.
|
||||
5. Push: `git push origin feat/my-feature`
|
||||
6. Open a Pull Request against `main`.
|
||||
|
||||
## Code Review
|
||||
|
||||
- All PRs are reviewed. We may request changes.
|
||||
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
|
||||
- Ensure all tests pass.
|
||||
|
||||
## License
|
||||
|
||||
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
|
||||
|
||||
---
|
||||
|
||||
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# AstrAI Dockerfile - Multi-stage Build (Optimized)
|
||||
|
||||
# Build stage - use base image with minimal build tools
|
||||
FROM ubuntu:24.04 AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python 3.12 and minimal build dependencies
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
python3.12 \
|
||||
python3.12-dev \
|
||||
python3.12-venv \
|
||||
gcc \
|
||||
g++ \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create isolated virtual environment
|
||||
RUN python3.12 -m venv --copies /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Copy source code and install (deps read from pyproject.toml)
|
||||
COPY astrai/ ./astrai/
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir --upgrade pip \
|
||||
&& pip install --no-cache-dir . \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu126
|
||||
|
||||
# Production stage
|
||||
FROM ubuntu:24.04 AS production
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python 3.12 runtime and healthcheck dependency
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
python3.12 \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy virtual environment from builder
|
||||
COPY --from=builder /opt/venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Copy application code
|
||||
COPY astrai/ ./astrai/
|
||||
COPY scripts/ ./scripts/
|
||||
COPY assets/ ./assets/
|
||||
COPY pyproject.toml .
|
||||
COPY README.md .
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -m astrai && chown -R astrai:astrai /app
|
||||
USER astrai
|
||||
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1
|
||||
811
LICENSE
811
LICENSE
|
|
@ -1,201 +1,674 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
1. Definitions.
|
||||
Preamble
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
0. Definitions.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
1. Source Code.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
2. Basic Permissions.
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
<program> Copyright (C) <year> <name of author>
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
488
README.md
488
README.md
|
|
@ -1,333 +1,255 @@
|
|||

|
||||
|
||||
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; text-align: center; font-size: 16px; font-weight: bold; margin-top: 50px;">
|
||||
<div align="center">
|
||||
|
||||
<div>
|
||||
<a href="#english" style="text-decoration: none; margin: 0 10px; color: blue;">English</a> |
|
||||
<a href="#chinese" style="text-decoration: none; margin: 0 10px; color: blue;">中文</a>
|
||||
</div>
|
||||
|
||||
<h1 style="margin: 20px 0 0 0; font-size: 2.5em; font-weight: bold;">KHAOSZ </h1>
|
||||
<img src="assets/images/logo.png" width="auto" alt="Logo">
|
||||
<p>
|
||||
<strong>A lightweight Transformer training & inference framework</strong>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<h2 id="english">English Version</h2>
|
||||
<div align="center">
|
||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
|
||||
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
|
||||
</div>
|
||||
<br>
|
||||
|
||||
This is a Chinese-English bilingual Transformer model supporting both languages. It contains model configurations and training workflows, completing training by loading parameters defined in `param_path/config.json`. The training script `train.py` parses command-line arguments, including dataset root directory, number of training epochs, batch size, checkpoint interval, and checkpoint directory.
|
||||
<div align="center">
|
||||
<a href="#english">English</a> •
|
||||
<a href="assets/docs/README-zh-CN.md">中文</a> •
|
||||
<a href="https://github.com/ViperEkura/AstrAI/issues">Issue Tracker</a> •
|
||||
<a href="https://github.com/ViperEkura/AstrAI/discussions">Discussions</a> •
|
||||
<a href="https://huggingface.co/ViperEk/">HuggingFace</a>
|
||||
</div>
|
||||
|
||||
**Model Download Options (Choose One):**
|
||||
<br>
|
||||
|
||||
1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) to access **Files and versions**
|
||||
2. Run `scripts/download.py` to download parameters
|
||||
## 📖 Table of Contents
|
||||
|
||||
**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
|
||||
- [Features](#features)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Documentation](#documentation)
|
||||
- [Contributing](#contributing)
|
||||
- [Community](#community)
|
||||
- [License](#license)
|
||||
|
||||
Training dataset sources are listed in the **Model Card** section of the HuggingFace download link.
|
||||
---
|
||||
|
||||
**License:** Code follows Apache-2.0 protocol. Please credit the source code when used.
|
||||
<a id="english"></a>
|
||||
## English
|
||||
|
||||
- **📊 Device Selection:** Code defaults to CUDA training
|
||||
- **🌐 Performance Optimization:** `dtype=torch.bfloat16` is enabled to accelerate training and reduce memory usage. Ensure hardware supports this feature.
|
||||
- **🤖 Language Support:** Model supports Chinese and English training. The BBPE tokenizer was trained without multilingual text, so OOV (out-of-vocabulary) issues are minimized for these languages but may exist for others.
|
||||
### Features
|
||||
|
||||
### 📌 Training Guide
|
||||
- 🚀 **High Performance**: Optimized for both training and inference with efficient parallelization.
|
||||
- 🔧 **Flexible**: Support for seq/sft/dpo/grpo training, customizable model architectures.
|
||||
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
|
||||
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
|
||||
- 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas.
|
||||
- 🤗 **HuggingFace-Style API**: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
|
||||
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
|
||||
|
||||
To train this Transformer model, follow these steps:
|
||||
### Quick Start
|
||||
|
||||
**(1). Prepare Dataset:**
|
||||
|
||||
Place datasets in the designated root directory. Files should be text documents in Chinese, English, or mixed. Format should align with model input requirements - preferably pre-tokenized token_ids stored as `torch.Tensor` (using `torch.Tensor` saves memory compared to Python lists, which default to 64-bit precision).
|
||||
|
||||
**(2). Install Dependencies:**
|
||||
#### Installation
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install .
|
||||
git clone https://github.com/ViperEkura/AstrAI.git
|
||||
cd AstrAI
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
**(3). Run Training Script:**
|
||||
For development dependencies:
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--train_type=train_type[seq, sft, dpo] \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/param_path \
|
||||
--n_epoch=5 \
|
||||
--batch_size=8 \
|
||||
--max_lr=2e-4 \
|
||||
--checkpoint_interval=10000 \
|
||||
--checkpoint_dir=checkpoints
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
**Parameters Explanation:**
|
||||
- `--train_type`: Training type (seq, sft, dpo)
|
||||
- `--data_root_path`: Root directory of the dataset
|
||||
- `--param_path`: Path to the model training parameters
|
||||
- `--n_epoch`: Total number of training epochs
|
||||
- `--batch_size`: Batch size
|
||||
- `--accumulation_steps`: Number of batches per training step
|
||||
- `--warmup_steps`: Number of warmup steps
|
||||
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
|
||||
- `--checkpoint_interval`: Checkpoint saving interval
|
||||
- `--checkpoint_dir`: Directory to save checkpoints
|
||||
- `--resume_dir`: Resume training from the specified path
|
||||
#### Download Pre-trained Model
|
||||
|
||||
Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.
|
||||
|
||||
### 👉 Usage Guide
|
||||
|
||||
**(1). Chatting with the Model:**
|
||||
|
||||
Open `chat.py` or use streaming/non-streaming interfaces:
|
||||
|
||||
**Streaming Output:**
|
||||
```python
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
|
||||
model_dir = "your_model_parameter_dir"
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
history = []
|
||||
|
||||
while True:
|
||||
query = input(">> ")
|
||||
if query == "!exit":
|
||||
break
|
||||
|
||||
response_size = 0
|
||||
for response, history in model.stream_generate(
|
||||
query=query,
|
||||
history=history,
|
||||
temperature=0.85,
|
||||
top_p=0.95,
|
||||
top_k=50
|
||||
):
|
||||
print(response[response_size:], end="")
|
||||
response_size = len(response)
|
||||
```
|
||||
|
||||
**Non-streaming Output:**
|
||||
```python
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
|
||||
model_dir = "your_model_parameter_dir"
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
history = []
|
||||
|
||||
while True:
|
||||
query = input(">> ")
|
||||
if query == "!exit":
|
||||
break
|
||||
|
||||
response = model.generate(
|
||||
query=query,
|
||||
history=history,
|
||||
temperature=0.85,
|
||||
top_p=0.95,
|
||||
top_k=50
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
**(2) Retrieval-Augmented Generation (RAG):**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
|
||||
model_dir = "your_model_parameter_dir"
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
retrieved_content = model.retrieve_generate(
|
||||
query=query,
|
||||
retrieve_top_k=5,
|
||||
temperature=0.6,
|
||||
top_k=30,
|
||||
top_p=0.95
|
||||
)
|
||||
print(retrieved_content)
|
||||
```
|
||||
|
||||
### 📌 Model Specifications
|
||||
|
||||
This model is based on a 24-layer Transformer with parameters defined in `config.json`, totaling approximately 1.0 billion (1.0B) parameters.
|
||||
|
||||
**Key Design Choices:**
|
||||
- Weight tying between embedding and final linear layers (standard for small models to save parameters)
|
||||
- Embedding layer optimization: Without weight tying, a 10,000-word vocabulary would consume ~102M parameters (0.1B)
|
||||
|
||||
**Limitations:**
|
||||
- May struggle with complex language phenomena due to smaller parameter size
|
||||
- Prone to overfitting on specialized datasets
|
||||
- Limited multilingual capabilities
|
||||
|
||||
**Advantages:**
|
||||
- Runs efficiently on lower-spec hardware
|
||||
- Shorter training time compared to larger models
|
||||
|
||||
**Training Pipeline:**
|
||||
The model has completed pre-training + SFT (Supervised Fine-Tuning) + DPO (Direct Preference Optimization) workflows. All corresponding training code is included in the repository.
|
||||
|
||||
|
||||
<h2 id="chinese">中文版本</h2>
|
||||
这是一个支持中英文双语的 Transformer 模型,能够处理两种语言。模型包含配置文件和训练流程,通过加载 `param_path/config.json` 中定义的参数完成训练。训练脚本 `train.py` 支持命令行参数解析,包括数据集根目录、训练轮数(epochs)、批量大小(batch size)、检查点保存间隔、检查点目录等。
|
||||
|
||||
**模型下载选项(任选其一):**
|
||||
|
||||
1. 访问 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 查看 **Files and versions**
|
||||
2. 运行 `scripts/download.py` 下载模型参数
|
||||
|
||||
**演示视频:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
|
||||
|
||||
训练数据来源请参见 HuggingFace 下载页面中的 **Model Card** 部分。
|
||||
|
||||
**许可证:** 代码遵循 Apache-2.0 协议,使用时请注明出处。
|
||||
|
||||
- **📊 设备选择:** 默认使用 CUDA 进行训练
|
||||
- **🌐 性能优化:** 启用 `dtype=torch.bfloat16` 以加速训练并减少内存占用,请确保硬件支持该特性
|
||||
- **🤖 语言支持:** 模型支持中文和英文训练。由于 BBPE 分词器未使用多语言文本训练,因此中英文的 OOV(未登录词)问题较少,其他语言可能存在 OOV 问题
|
||||
|
||||
|
||||
|
||||
### 📌 训练指南
|
||||
|
||||
要训练该 Transformer 模型,请按照以下步骤操作:
|
||||
|
||||
#### **(1). 准备数据集:**
|
||||
|
||||
将数据集放置在指定的根目录下。文件应为包含中文、英文或混合文本的文本文档。格式应符合模型输入要求——建议使用预分词后的 `token_ids` 并以 `torch.Tensor` 格式保存(使用 `torch.Tensor` 相比 Python 列表更节省内存,列表默认为 64 位精度)。
|
||||
|
||||
#### **(2). 安装依赖:**
|
||||
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install .
|
||||
python scripts/demo/download.py
|
||||
```
|
||||
|
||||
#### **(3). 运行训练脚本:**
|
||||
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
|
||||
|
||||
#### Train a Model
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--train_type=train_type[seq, sft, dpo] \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/param_path \
|
||||
--n_epoch=5 \
|
||||
--batch_size=8 \
|
||||
--max_lr=2e-4 \
|
||||
--checkpoint_interval=10000 \
|
||||
--checkpoint_dir=checkpoints
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--parallel_mode=ddp \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.9 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.05 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
**参数说明:**
|
||||
- `--train_type`: 训练类型(seq, sft, dpo)
|
||||
- `--data_root_path`: 数据集根目录
|
||||
- `--param_path`: 模型训练参数路径
|
||||
- `--n_epoch`: 总训练轮数
|
||||
- `--batch_size`: 批量大小
|
||||
- `--accumulation_steps`: 每个训练步骤的 batch 数量
|
||||
- `--warmup_steps`: 预热步数(warmup steps)
|
||||
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
|
||||
- `--checkpoint_interval`: 检查点保存间隔
|
||||
- `--checkpoint_dir`: 检查点保存目录
|
||||
- `--resume_dir`: 从指定路径恢复训练
|
||||
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||
|
||||
训练日志将保存在 `train_log.txt` 中。检查点将保存在指定目录,用于恢复训练或评估。
|
||||
#### Generate Text
|
||||
|
||||
|
||||
|
||||
### 👉 使用指南
|
||||
|
||||
#### **(1). 与模型对话:**
|
||||
|
||||
打开 `chat.py` 或使用流式/非流式接口:
|
||||
|
||||
**流式输出:**
|
||||
```python
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
|
||||
model_dir = "your_model_parameter_dir"
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
history = []
|
||||
|
||||
while True:
|
||||
query = input(">> ")
|
||||
if query == "!exit":
|
||||
break
|
||||
|
||||
response_size = 0
|
||||
for response, history in model.stream_generate(
|
||||
query=query,
|
||||
history=history,
|
||||
temperature=0.85,
|
||||
top_p=0.95,
|
||||
top_k=50
|
||||
):
|
||||
print(response[response_size:], end="")
|
||||
response_size = len(response)
|
||||
```bash
|
||||
python scripts/tools/generate.py \
|
||||
--param_path /path/to/model \
|
||||
--input_json_file /path/to/input.jsonl \
|
||||
--output_json_file /path/to/output.jsonl
|
||||
```
|
||||
|
||||
**非流式输出:**
|
||||
```python
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
#### Docker
|
||||
|
||||
model_dir = "your_model_parameter_dir"
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
history = []
|
||||
Build and run with Docker (recommended for GPU environments):
|
||||
|
||||
while True:
|
||||
query = input(">> ")
|
||||
if query == "!exit":
|
||||
break
|
||||
|
||||
response = model.generate(
|
||||
query=query,
|
||||
history=history,
|
||||
temperature=0.85,
|
||||
top_p=0.95,
|
||||
top_k=50
|
||||
)
|
||||
print(response)
|
||||
```bash
|
||||
# Build image
|
||||
docker build -t astrai:latest .
|
||||
|
||||
# Run with GPU support
|
||||
docker run --gpus all -it astrai:latest
|
||||
|
||||
# Run with specific GPUs
|
||||
docker run --gpus '"device=0,1"' -it astrai:latest
|
||||
|
||||
# Run inference server
|
||||
docker run --gpus all -p 8000:8000 astrai:latest \
|
||||
python -m scripts.tools.server --port 8000 --device cuda
|
||||
|
||||
# Run with volume mount for data
|
||||
docker run --gpus all -v /path/to/data:/data -it astrai:latest
|
||||
|
||||
# Docker Compose (GPU, default)
|
||||
docker compose up -d
|
||||
|
||||
# Docker Compose (CPU only)
|
||||
docker compose --profile cpu up -d
|
||||
```
|
||||
|
||||
#### **(2). 基于检索的生成(RAG):**
|
||||
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
#### Start HTTP Server
|
||||
|
||||
model_dir = "your_model_parameter_dir"
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
Start the inference server with OpenAI and Anthropic-compatible HTTP API:
|
||||
|
||||
retrieved_content = model.retrieve_generate(
|
||||
query=query,
|
||||
retrieve_top_k=5,
|
||||
temperature=0.6,
|
||||
top_k=30,
|
||||
top_p=0.95
|
||||
)
|
||||
print(retrieved_content)
|
||||
```bash
|
||||
python -m scripts.tools.server --port 8000 --device cuda
|
||||
```
|
||||
|
||||
Make requests:
|
||||
|
||||
```bash
|
||||
# OpenAI-compatible
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 512
|
||||
}'
|
||||
|
||||
### 📌 模型规格说明(重复部分)
|
||||
# OpenAI-compatible streaming
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "Tell a story"}],
|
||||
"stream": true,
|
||||
"max_tokens": 500
|
||||
}'
|
||||
|
||||
该模型基于一个 24 层的 Transformer 架构,参数配置定义在 `config.json` 中,总参数量约为 10 亿(1.0B)。
|
||||
# Anthropic-compatible
|
||||
curl -X POST http://localhost:8000/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "astrai",
|
||||
"system": "You are a helpful assistant.",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 512
|
||||
}'
|
||||
|
||||
**关键设计选择:**
|
||||
- 在嵌入层(embedding)与最终线性层之间进行权重绑定(weight tying),这是小型模型中常见的节省参数量的做法
|
||||
- 嵌入层优化:若不进行权重绑定,一个包含 10,000 个词的词汇表将消耗约 1.02 亿(0.1B)参数
|
||||
# Anthropic-compatible streaming with stop sequences
|
||||
curl -X POST http://localhost:8000/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "astrai",
|
||||
"messages": [{"role": "user", "content": "Write a story"}],
|
||||
"max_tokens": 500,
|
||||
"stream": true,
|
||||
"stop_sequences": ["The end"]
|
||||
}'
|
||||
|
||||
**局限性:**
|
||||
- 由于参数规模较小,可能在处理复杂语言现象时表现受限
|
||||
- 在特定领域的数据集上容易出现过拟合
|
||||
- 多语言能力有限
|
||||
# Health check
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
**优势:**
|
||||
- 可在低配置硬件上高效运行
|
||||
- 相较于大型模型,训练时间更短
|
||||
#### Demo
|
||||
|
||||
**训练流程:**
|
||||
该模型已完成预训练(pre-training)+ 监督微调(SFT, Supervised Fine-Tuning)+ 直接偏好优化(DPO, Direct Preference Optimization)的全流程。所有相关的训练代码均已包含在代码库中。
|
||||
Check out the demos in the `scripts/demo/` folder:
|
||||
|
||||
```bash
|
||||
# Download pre‑processed data (required before running demos)
|
||||
python scripts/demo/download.py
|
||||
|
||||
# Interactive streaming chat
|
||||
python scripts/demo/stream_chat.py
|
||||
|
||||
# Batch generation
|
||||
python scripts/demo/generate_batch.py
|
||||
|
||||
# Auto‑regressive generation
|
||||
python scripts/demo/generate_ar.py
|
||||
```
|
||||
|
||||
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
|
||||
|
||||
### Documentation
|
||||
|
||||
| Document | Description |
|
||||
|----------|-------------|
|
||||
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
|
||||
| [Architecture](./assets/docs/architecture.md) | System architecture, class diagram & design patterns |
|
||||
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
||||
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
||||
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
||||
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
|
||||
|
||||
### Contributing
|
||||
|
||||
We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.
|
||||
|
||||
1. Fork the repository.
|
||||
2. Create a feature branch.
|
||||
3. Commit your changes.
|
||||
4. Open a Pull Request.
|
||||
|
||||
For major changes, please open an issue first to discuss what you would like to change.
|
||||
|
||||
### Community
|
||||
|
||||
- **GitHub Issues**: [Issue Tracker](https://github.com/ViperEkura/AstrAI/issues)
|
||||
- **Discussions**: [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions)
|
||||
- **HuggingFace**: [Model Hub](https://huggingface.co/ViperEk)
|
||||
|
||||
### License
|
||||
|
||||
This project is licensed under the [GPL-3.0 License](LICENSE).
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
<em>A lightweight Transformer framework designed for both high performance and ease of use.</em>
|
||||
</div>
|
||||
|
|
@ -0,0 +1,261 @@
|
|||
<div align="center">
|
||||
|
||||
<img src="../images/logo.png" width="auto" alt="Logo">
|
||||
|
||||
<div>
|
||||
<a href="../../README.md">English</a> •
|
||||
<a href="#chinese">中文</a>
|
||||
</div>
|
||||
|
||||
<p>
|
||||
<strong>轻量级 Transformer 训练与推理框架</strong>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
|
||||
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div align="center">
|
||||
<a href="../../README.md">English</a> •
|
||||
<a href="#chinese">中文</a> •
|
||||
<a href="https://github.com/ViperEkura/AstrAI/issues">问题追踪</a> •
|
||||
<a href="https://github.com/ViperEkura/AstrAI/discussions">讨论区</a> •
|
||||
<a href="https://huggingface.co/ViperEk">HuggingFace</a>
|
||||
</div>
|
||||
<br>
|
||||
|
||||
## 📖 目录
|
||||
|
||||
- [特性](#特性)
|
||||
- [快速开始](#快速开始)
|
||||
- [文档](#文档)
|
||||
- [贡献](#贡献)
|
||||
- [社区](#社区)
|
||||
- [许可证](#许可证)
|
||||
|
||||
---
|
||||
|
||||
<a id="chinese"></a>
|
||||
## 中文
|
||||
|
||||
### 特性
|
||||
|
||||
- 🚀 **高性能**: 训练与推理双向优化,高效并行。
|
||||
- 🔧 **灵活**: 支持 seq/sft/dpo/grpo 多种训练方式,可定制模型架构。
|
||||
- 💡 **易用**: 简洁的 API 与丰富的示例、演示。
|
||||
- 📦 **轻量**: 依赖少,部署简单。
|
||||
- 🔬 **研究友好**: 模块化设计,便于实验新想法。
|
||||
- 🤗 **HuggingFace 风格 API**: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。
|
||||
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。
|
||||
|
||||
### 快速开始
|
||||
|
||||
#### 安装
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ViperEkura/AstrAI.git
|
||||
cd AstrAI
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
安装开发依赖:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
#### 下载预训练模型
|
||||
|
||||
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
|
||||
|
||||
```bash
|
||||
python scripts/demo/download.py
|
||||
```
|
||||
|
||||
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`。
|
||||
|
||||
#### 训练模型
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--parallel_mode=ddp \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.9 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.05 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
完整参数列表见[参数说明](./params.md)。
|
||||
|
||||
#### 文本生成
|
||||
|
||||
```bash
|
||||
python scripts/tools/generate.py \
|
||||
--param_path /path/to/model \
|
||||
--input_json_file /path/to/input.jsonl \
|
||||
--output_json_file /path/to/output.jsonl
|
||||
```
|
||||
|
||||
#### Docker
|
||||
|
||||
使用 Docker 构建和运行(推荐用于 GPU 环境):
|
||||
|
||||
```bash
|
||||
# 构建镜像
|
||||
docker build -t astrai:latest .
|
||||
|
||||
# 启用 GPU 运行
|
||||
docker run --gpus all -it astrai:latest
|
||||
|
||||
# 指定特定 GPU
|
||||
docker run --gpus '"device=0,1"' -it astrai:latest
|
||||
|
||||
# 运行推理服务
|
||||
docker run --gpus all -p 8000:8000 astrai:latest \
|
||||
python -m scripts.tools.server --port 8000 --device cuda
|
||||
|
||||
# 挂载数据卷
|
||||
docker run --gpus all -v /path/to/data:/data -it astrai:latest
|
||||
|
||||
# Docker Compose(GPU,默认)
|
||||
docker compose up -d
|
||||
|
||||
# Docker Compose(仅 CPU)
|
||||
docker compose --profile cpu up -d
|
||||
```
|
||||
|
||||
> **注意**: 必须使用 `--gpus all` 才能启用 CUDA 支持,否则 `torch.cuda.is_available()` 将返回 `False`。
|
||||
|
||||
#### 启动 HTTP 服务
|
||||
|
||||
启动推理服务器,支持 OpenAI 和 Anthropic 兼容的 HTTP API:
|
||||
|
||||
```bash
|
||||
python -m scripts.tools.server --port 8000 --device cuda
|
||||
```
|
||||
|
||||
发起请求:
|
||||
|
||||
```bash
|
||||
# OpenAI 兼容
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "你好"}],
|
||||
"max_tokens": 512
|
||||
}'
|
||||
|
||||
# OpenAI 兼容流式
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "讲个故事"}],
|
||||
"stream": true,
|
||||
"max_tokens": 500
|
||||
}'
|
||||
|
||||
# Anthropic 兼容
|
||||
curl -X POST http://localhost:8000/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "astrai",
|
||||
"system": "你是一个乐于助人的助手。",
|
||||
"messages": [{"role": "user", "content": "你好"}],
|
||||
"max_tokens": 512
|
||||
}'
|
||||
|
||||
# Anthropic 兼容流式并设置停止序列
|
||||
curl -X POST http://localhost:8000/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "astrai",
|
||||
"messages": [{"role": "user", "content": "写个故事"}],
|
||||
"max_tokens": 500,
|
||||
"stream": true,
|
||||
"stop_sequences": ["结束"]
|
||||
}'
|
||||
|
||||
# 健康检查
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
#### 演示
|
||||
|
||||
查看 `scripts/demo/` 文件夹中的演示:
|
||||
|
||||
```bash
|
||||
# 下载预处理数据(运行演示前必需)
|
||||
python scripts/demo/download.py
|
||||
|
||||
# 交互式流式聊天
|
||||
python scripts/demo/stream_chat.py
|
||||
|
||||
# 批量生成
|
||||
python scripts/demo/generate_batch.py
|
||||
|
||||
# 自回归生成
|
||||
python scripts/demo/generate_ar.py
|
||||
```
|
||||
|
||||
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
|
||||
|
||||
### 文档
|
||||
|
||||
| 文档 | 说明 |
|
||||
|------|------|
|
||||
| [参数说明](./params.md) | 训练与推理参数配置 |
|
||||
| [架构文档](./architecture.md) | 系统架构、类图与设计模式 |
|
||||
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
||||
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
||||
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
||||
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
|
||||
|
||||
### 贡献
|
||||
|
||||
我们欢迎贡献!请参阅[贡献指南](../../CONTRIBUTING.md)了解详情。
|
||||
|
||||
1. Fork 本仓库。
|
||||
2. 创建功能分支。
|
||||
3. 提交更改。
|
||||
4. 发起 Pull Request。
|
||||
|
||||
重大更改请先开 issue 讨论。
|
||||
|
||||
### 社区
|
||||
|
||||
- **GitHub Issues**: [问题追踪](https://github.com/ViperEkura/AstrAI/issues)
|
||||
- **Discussions**: [GitHub 讨论区](https://github.com/ViperEkura/AstrAI/discussions)
|
||||
- **HuggingFace**: [模型中心](https://huggingface.co/ViperEk)
|
||||
|
||||
### 许可证
|
||||
|
||||
本项目采用 [GPL-3.0 许可证](../../LICENSE)。
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
<em>专为高性能与易用性设计的轻量级 Transformer 框架。</em>
|
||||
</div>
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,64 @@
|
|||
# Data Flow
|
||||
|
||||
This document describes the data pipeline: from raw text to model input tensors.
|
||||
|
||||
## Overview
|
||||
|
||||
```
|
||||
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
|
||||
```
|
||||
|
||||
## Data Preparation
|
||||
|
||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||
|
||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||
|
||||
```
|
||||
StoreFactory.create("h5") → H5Store
|
||||
StoreFactory.create("bin") → MmapStore
|
||||
```
|
||||
|
||||
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||
|
||||
## Data Keys by Training Type
|
||||
|
||||
| Type | Storage Keys |
|
||||
|------|-------------|
|
||||
| `seq` | `sequence` (→ input_ids, target_ids via offset-by-1) |
|
||||
| `sft` | `sequence`, `loss_mask` |
|
||||
| `dpo` | `chosen`, `rejected`, `chosen_mask`, `rejected_mask` |
|
||||
| `grpo` | `prompts`, `responses`, `masks`, `rewards` |
|
||||
|
||||
## Dataset Architecture
|
||||
|
||||
```
|
||||
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
|
||||
→ BaseDataset.load(load_path, storage_type=None)
|
||||
→ detect_format(load_path)
|
||||
→ StoreFactory.create(storage_type)
|
||||
→ Store.load(load_path)
|
||||
→ H5Store._normalize() / MmapStore._normalize()
|
||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||
→ BaseDataset.__getitem__(idx)
|
||||
→ get_index(idx) → [begin, end)
|
||||
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
|
||||
```
|
||||
|
||||
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
|
||||
|
||||
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
|
||||
|
||||
## Sampler
|
||||
|
||||
`ResumableDistributedSampler` supports checkpoint-aware distributed sampling:
|
||||
|
||||
- Tracks `start_epoch` / `start_iter` for resume
|
||||
- Shuffle via `torch.Generator(seed + epoch)`
|
||||
- Per-replica index slicing for DDP
|
||||
|
||||
## DataLoader
|
||||
|
||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
# Inference
|
||||
|
||||
## KV Cache
|
||||
|
||||
At decode time, only the last query token matters. All previous K/V are cached to avoid recomputation:
|
||||
|
||||
$$
|
||||
o_n = \sum_j \text{softmax}\left(\frac{q_n k_j}{\sqrt{d_k}}\right) v_j
|
||||
$$
|
||||
|
||||
RoPE is applied **before** KV cache write, not after — otherwise position encoding drift occurs.
|
||||
|
||||
## KVCache System
|
||||
|
||||
Six classes (plus two helpers) working together:
|
||||
|
||||
```
|
||||
KVCache (facade)
|
||||
├── PagePool orchestrates page allocation + prefix matching
|
||||
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
|
||||
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
|
||||
├── TaskTable maps task_id → page_table + cached token count
|
||||
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
||||
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
|
||||
```
|
||||
|
||||
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
||||
|
||||
## Continuous Batching
|
||||
|
||||
`InferenceScheduler` runs a daemon thread with a 4-phase loop:
|
||||
|
||||
```
|
||||
1. Cleanup → Remove finished tasks, free KV pages
|
||||
2. Refill → Pop from waiting_queue, task_alloc pages, activate
|
||||
3. Prefill → Group by (prompt_len, start_pos), run full forward
|
||||
4. Decode → Pick largest same-position group, single-token forward
|
||||
```
|
||||
|
||||
## Sampling (Strategy Pattern)
|
||||
|
||||
```
|
||||
BaseSamplingStrategy (ABC)
|
||||
├── TemperatureStrategy
|
||||
├── TopKStrategy
|
||||
├── TopPStrategy
|
||||
└── SamplingPipeline
|
||||
```
|
||||
|
||||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||
`sample()` is a convenience shortcut for one-shot usage.
|
||||
|
||||
## Protocol Handlers (Strategy Pattern)
|
||||
|
||||
```python
|
||||
class ProtocolHandler: # concrete orchestrator
|
||||
def __init__(self, request, engine, builder): ...
|
||||
async def handle(self):
|
||||
prompt, ctx, stops = builder.prepare(request, engine)
|
||||
agen = engine.generate_async(prompt, ...)
|
||||
if stream: self._handle_stream(agen, ctx, stops)
|
||||
else: return await self._handle_non_stream(agen, ctx, stops)
|
||||
```
|
||||
|
||||
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||
|
||||
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
||||
|
||||
Adding a protocol = one builder file, no handler subclassing needed.
|
||||
|
||||
## Engine & GenerateResult
|
||||
|
||||
```
|
||||
InferenceEngine
|
||||
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
||||
├── generate_with_request(req) → same
|
||||
├── generate_async(prompt, ...) → AsyncGenerator
|
||||
├── get_stats() → Dict
|
||||
└── shutdown()
|
||||
```
|
||||
|
||||
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
||||
|
||||
## HTTP API
|
||||
|
||||
```
|
||||
POST /v1/chat/completions OpenAI
|
||||
POST /v1/messages Anthropic
|
||||
GET /health {"status":"ok","model_loaded":true}
|
||||
GET /stats scheduler statistics
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-abc123",
|
||||
"object": "chat.completion",
|
||||
"created": 1717000000,
|
||||
"model": "astrai",
|
||||
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||
}
|
||||
```
|
||||
|
||||
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
|
||||
|
||||
### Anthropic
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
|
||||
```
|
||||
|
||||
Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
||||
|
||||
### GenerationRequest Parameters
|
||||
|
||||
| Param | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||
| `top_k` | int | 50 | Top-k count |
|
||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
|
||||
| `max_tokens` | Optional[int] | None | Max generation length |
|
||||
| `stream` | bool | False | Stream output |
|
||||
|
||||
## Engine API
|
||||
|
||||
```python
|
||||
# Non-streaming
|
||||
engine.generate("Hello", stream=False) # -> str
|
||||
engine.generate(["A", "B"], stream=False) # -> List[str]
|
||||
|
||||
# Streaming
|
||||
engine.generate("Hello", stream=True) # -> Generator[str]
|
||||
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||
|
||||
# Async
|
||||
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
|
||||
print(token)
|
||||
```
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
## 模型介绍
|
||||
|
||||
|
||||
|
||||
### 1. 模型搭建
|
||||
|
||||
本模型采用Transformer架构, 使用GQA(q_head=24, kv_head=4) 机制,相较于传统的MHA可以节省KV cache 的显存占用(但是目前没有做KV cache),通过堆叠24层Transformer实现模型的搭建, 参数量为1.0b。Transformer 是自回归模型, 是通过计算前面所有的token的关系得到下一个token的概率分布
|
||||
|
||||

|
||||
|
||||
什么是自回归模型呢, 在把句子拆分成token之后, 模型会预测下一个token的概率分布。这意味着模型会根据给定的上下文(即已经出现的tokens序列),计算出下一个可能的token及其对应的概率。
|
||||
|
||||
|
||||
|
||||
#### 1. 自回归
|
||||
|
||||
假设我们有一个句子被拆分成如下tokens列表:
|
||||
|
||||
```
|
||||
["你好", "," "今天", "天气"]
|
||||
```
|
||||
|
||||
接下来,模型会基于这个序列预测下一个可能出现的token。这通常以概率分布的形式给出,比如:
|
||||
|
||||
```
|
||||
-> {"token": "不错", "probability": 0.4}
|
||||
-> {"token": "晴朗", "probability": 0.2}
|
||||
-> ......
|
||||
```
|
||||
|
||||
这里,“不错”和“晴朗”是两个可能跟随在“天气”之后的tokens,并且给出了每个token成为下一个token的可能性大小。
|
||||
|
||||
之后,我们通过采样(通过top_k, top_p, temperature参数调整采样后的结果)得到下一个token并且将下一个token加入序列作为输入
|
||||
|
||||
```
|
||||
["你好", "," "今天", "天气", "不错"]
|
||||
```
|
||||
|
||||
之后都是在重复这个流程, 直到遇到控制流程结束的token(<|end_of_seqence|>)模型停止处理(一般模型都会设置控制token, 不然模型会一直输出到显存爆炸)。
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#### 2. 因果掩码
|
||||
|
||||
transformer 中采用注意力机制,输入的形状一般为[bsz, seq_len], 输出为[bsz, seq_len,n_dim], 为了实现预测下一个token, 模型的输入和输出必须错开来一个位置。模型预测的target必须错开一个位置, 在训练的时候我们也采用错开一个位置的方法
|
||||
|
||||
```
|
||||
sequence : [[1, 2, 3, 4, 5, 6]]
|
||||
input_ids: [[1, 2, 3, 4, 5]]
|
||||
target_ids: [[2, 3, 4, 5, 6]]
|
||||
```
|
||||
|
||||
|
||||
|
||||
注意力得分计算的公式为
|
||||
|
||||
|
||||
$$ s_{ij} = softmax(\frac{q_i^Tk_j}{\sqrt{d_k}}) $$
|
||||
$$ s_{ij} := s_{ij} + mask_{ij} $$
|
||||
|
||||
|
||||
其中注意力得分代表了模型对两个token之间相似程度的关注程度
|
||||
|
||||
对于decoder only结构的模型, 为了防止模型从未来的位置偷到信息, 在注意力的计算过程中需要增加掩码,我们需要在注意力得分计算之前应用一个掩码。这个掩码通常是一个下三角矩阵,对于长度为n的序列,它的形状是[n, n]。下面以一个长度为5的序列为例,展示如何创建这样的因果掩码矩阵:
|
||||
|
||||
```
|
||||
[[0, -inf, -inf, -inf, -inf],
|
||||
[0, 0, -inf, -inf, -inf],
|
||||
[0, 0, 0, -inf, -inf],
|
||||
[0, 0, 0, 0, -inf],
|
||||
[0, 0, 0, 0, 0]]
|
||||
```
|
||||
|
||||
在这个矩阵中,0表示可以注意到的位置,而-inf表示应该被掩盖(即不应注意到)的位置。因为这个句子保证了注意力得分中 $j > i$ 的部分通过softmax 之后由`inf` 变成0, 也就是模型不能看到未来的信息
|
||||
|
||||
|
||||
|
||||
#### 3. 旋转位置编码
|
||||
|
||||
旋转位置编码(Rotary Position Embedding, RoPE)是一种为了解决Transformer模型中缺乏对序列位置信息直接建模的问题而设计的位置编码方法。与传统的位置编码(如正弦和余弦函数的位置编码)不同,RoPE通过将位置信息直接嵌入到查询(Query, Q)和键(Key, K)向量中来实现,使得模型能够更自然地处理序列中的相对位置关系。
|
||||
|
||||
|
||||
$$ q_i = R_i W_q x_i $$
|
||||
$$ k_j = R_j W_k x_j $$
|
||||
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
|
||||
|
||||
其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
## kv_cache 实现
|
||||
|
||||
根据注意力的计算公式
|
||||
|
||||
$$
|
||||
\begin{align*}
|
||||
o_i &= \sum_j s_{ij} v_{j} \\
|
||||
s_{ij} &= \text{softmax}\left( \sum_n \frac{q_{i,n} k_{j,n}}{\sqrt{d_k}} \right)
|
||||
\end{align*}
|
||||
$$
|
||||
|
||||
由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $
|
||||
|
||||
$$
|
||||
\begin{align*}
|
||||
o_n &= \sum_j s_{j}v_{j,n} \\
|
||||
s_j &= \text{softmax}\left(\sum_n\frac{q_n k_{j,n}}{\sqrt{d_k}} \right)
|
||||
\end{align*}
|
||||
$$
|
||||
|
||||
如果我们把式子展开
|
||||
|
||||
$$
|
||||
o_n = \sum_j \sum_n \text{softmax}\left(\frac{q_n k_{j,n}}{\sqrt{d_k}}\right)v_{j,n}
|
||||
$$
|
||||
|
||||
以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
# Parameter Documentation
|
||||
|
||||
## Training Parameters
|
||||
|
||||
### Basic Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
||||
| `--data_root_path` | Dataset root directory | required |
|
||||
| `--param_path` | Model parameters or checkpoint path | required |
|
||||
| `--n_epoch` | Total training epochs | 1 |
|
||||
| `--batch_per_device` | Batch size per device | 1 |
|
||||
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||
|
||||
### Learning Rate Scheduling
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
|
||||
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||
|
||||
### Optimizer (AdamW)
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||
|
||||
### Data Loading
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--window_size` | Max input sequence length | model config `max_len` |
|
||||
| `--stride` | Stride for sliding window over sequences | None |
|
||||
| `--random_seed` | Random seed for reproducibility | 3407 |
|
||||
| `--num_workers` | DataLoader worker processes | 4 |
|
||||
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
|
||||
|
||||
### Checkpoint & Resume
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
|
||||
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
||||
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
|
||||
| `--start_batch` | Resume from batch iteration | 0 |
|
||||
|
||||
### Distributed Training
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
||||
| `--device_type` | Device type | cuda |
|
||||
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
||||
|
||||
### Strategy-specific
|
||||
|
||||
| Parameter | Description | Default | Used by |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
|
||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
|
||||
|
||||
### Usage Example
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--parallel_mode=ddp \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.9 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.05 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
> Document Update Time: 2026-05-24
|
||||
|
|
@ -0,0 +1,283 @@
|
|||
# Preprocessing Pipeline
|
||||
|
||||
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
|
||||
|
||||
## Philosophy
|
||||
|
||||
| Component | Responsibility |
|
||||
|-----------|---------------|
|
||||
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
||||
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
||||
|
||||
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### SFT Chat
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"input": {
|
||||
"type": "chat",
|
||||
"messages_key": "messages"
|
||||
},
|
||||
"mask": {
|
||||
"system": "mask",
|
||||
"user": "mask",
|
||||
"assistant": "train"
|
||||
},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {
|
||||
"max_seq_len": 2048,
|
||||
"deduplicate": true
|
||||
},
|
||||
"output": {
|
||||
"domain_key": "source",
|
||||
"storage_format": "bin",
|
||||
"max_tokens_per_shard": 100000000
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
|
||||
|
||||
### Instruction Tuning
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"input": {
|
||||
"type": "instruction",
|
||||
"prompt_key": "instruction",
|
||||
"response_key": "output"
|
||||
},
|
||||
"mask": {
|
||||
"prompt": "mask",
|
||||
"response": "train"
|
||||
},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {
|
||||
"max_seq_len": 2048
|
||||
},
|
||||
"output": {
|
||||
"storage_format": "bin"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Mask splits at the prompt/response field boundary.
|
||||
|
||||
### Pretraining
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"input": {
|
||||
"type": "text",
|
||||
"text_key": "content"
|
||||
},
|
||||
"mask": {},
|
||||
"preprocessing": {
|
||||
"max_seq_len": 2048,
|
||||
"min_chars": 50
|
||||
},
|
||||
"output": {
|
||||
"storage_format": "bin"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
No mask -- train on all tokens.
|
||||
|
||||
### Run
|
||||
|
||||
```bash
|
||||
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
||||
```
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### `input`
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
|
||||
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
|
||||
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
|
||||
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
|
||||
| `text_key` | string | no | `"text"` | JSON key for text field |
|
||||
|
||||
### `mask`
|
||||
|
||||
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
|
||||
|
||||
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
|
||||
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
|
||||
|
||||
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
|
||||
For instruction mode, keys are `"prompt"` and `"response"`.
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `mask` | dict | `{}` | Role/field to action mapping |
|
||||
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
|
||||
|
||||
### `preprocessing`
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
|
||||
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
|
||||
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
|
||||
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
|
||||
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
|
||||
|
||||
### `output`
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
|
||||
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
|
||||
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
|
||||
|
||||
## Mask Algorithm
|
||||
|
||||
### Chat Mode (role-span tracking)
|
||||
|
||||
For each message in the `messages` array:
|
||||
|
||||
1. Prepend BOS token (position 0, always masked)
|
||||
2. Render through the chat template for that single message
|
||||
3. Encode the rendered text, record token span `(start, end, role)`
|
||||
4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries
|
||||
5. Fill `loss_mask` from the mask rules
|
||||
|
||||
**Multi-turn example**:
|
||||
|
||||
```
|
||||
Data:
|
||||
[system: "You are helpful."]
|
||||
[user: "What is 2+2?"]
|
||||
[assistant: "4"]
|
||||
[user: "What is 3+3?"]
|
||||
[assistant: "6"]
|
||||
|
||||
Config:
|
||||
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
|
||||
|
||||
Result:
|
||||
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
|
||||
mask: 0 0 0 1 0 1
|
||||
```
|
||||
|
||||
Both assistant turns are trained. All system and user tokens are masked.
|
||||
|
||||
### Instruction Mode (field boundary)
|
||||
|
||||
Encode the prompt and response fields independently, then split the mask at the field boundary.
|
||||
|
||||
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
|
||||
- `"prompt": "train", "response": "mask"` -- the reverse
|
||||
|
||||
### Text Mode (no mask)
|
||||
|
||||
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
|
||||
|
||||
## Output Layout
|
||||
|
||||
### Single-Shard (`bin`)
|
||||
|
||||
```
|
||||
output_dir/
|
||||
__default__/ # when domain_key is null
|
||||
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
|
||||
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
|
||||
loss_mask.bin # int64 raw bytes
|
||||
wiki/ # when domain_key="source" and item["source"]="wiki"
|
||||
meta.json
|
||||
sequence.bin
|
||||
loss_mask.bin
|
||||
```
|
||||
|
||||
### Multi-Shard (`bin`)
|
||||
|
||||
When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories:
|
||||
|
||||
```
|
||||
output_dir/
|
||||
__default__/
|
||||
shard_0000/
|
||||
meta.json
|
||||
sequence.bin
|
||||
loss_mask.bin
|
||||
shard_0001/
|
||||
meta.json
|
||||
sequence.bin
|
||||
loss_mask.bin
|
||||
```
|
||||
|
||||
`MmapStore` automatically discovers and merges all shards under the domain directory.
|
||||
|
||||
### H5 Output
|
||||
|
||||
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
|
||||
|
||||
```
|
||||
output_dir/
|
||||
__default__/
|
||||
data_0000.h5 # each H5 contains key→dataset groups
|
||||
data_0001.h5
|
||||
wiki/
|
||||
data_0000.h5
|
||||
```
|
||||
|
||||
## Python API Usage
|
||||
|
||||
```python
|
||||
from astrai.preprocessing.pipeline import Pipeline
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
|
||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
||||
Pipeline(
|
||||
config,
|
||||
["data_part1.jsonl", "data_part2.jsonl"],
|
||||
output_dir="output/",
|
||||
tokenizer_path="params"
|
||||
).run()
|
||||
```
|
||||
|
||||
Or from the CLI:
|
||||
|
||||
```bash
|
||||
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
||||
```
|
||||
|
||||
## Extension
|
||||
|
||||
Register a custom builder for new formats:
|
||||
|
||||
```python
|
||||
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
|
||||
|
||||
@MaskBuilderFactory.register("my_format")
|
||||
class MyFormatBuilder(BaseMaskBuilder):
|
||||
def build(self, item: dict, config, tokenizer) -> dict | None:
|
||||
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
|
||||
# Return None to skip this item
|
||||
...
|
||||
```
|
||||
|
||||
Then set `"input": {"type": "my_format"}` in your config.
|
||||
|
||||
## Compared to Old Pipeline
|
||||
|
||||
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|
||||
|---|---|
|
||||
| Configured via constructor arguments | Configured via JSON file |
|
||||
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
|
||||
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
|
||||
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
|
||||
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
# Training
|
||||
|
||||
### Autoregression
|
||||
|
||||
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
||||
|
||||
### Causal Mask
|
||||
|
||||
```
|
||||
sequence : [[1, 2, 3, 4, 5, 6]]
|
||||
input_ids: [[1, 2, 3, 4, 5]]
|
||||
target_ids: [[2, 3, 4, 5, 6]]
|
||||
```
|
||||
|
||||
Lower-triangular mask prevents attending to future positions:
|
||||
|
||||
```
|
||||
[[0, -inf, -inf, -inf, -inf],
|
||||
[0, 0, -inf, -inf, -inf],
|
||||
[0, 0, 0, -inf, -inf],
|
||||
[0, 0, 0, 0, -inf],
|
||||
[0, 0, 0, 0, 0]]
|
||||
```
|
||||
|
||||
### Rotary Position Embedding (RoPE)
|
||||
|
||||
RoPE embeds position into Q/K vectors via complex rotation:
|
||||
|
||||
$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_{i-j} W_k x_j $$
|
||||
|
||||
The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` multiplies Q/K as complex numbers.
|
||||
|
||||
## Training Loop
|
||||
|
||||
Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches.
|
||||
|
||||
```
|
||||
on_train_begin
|
||||
model.train()
|
||||
on_epoch_begin
|
||||
for batch in dataloader:
|
||||
on_batch_begin
|
||||
with executor.accumulate(model):
|
||||
loss = strategy.compute_loss(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / executor.grad_accum_steps
|
||||
executor.backward(stand_loss)
|
||||
context.iteration += 1
|
||||
on_batch_end
|
||||
|
||||
if executor.sync_gradients:
|
||||
on_optimizer_step
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
if scheduler:
|
||||
scheduler.step()
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
||||
### Callback Lifecycle
|
||||
|
||||
| Hook | Fires | Default callback |
|
||||
|------|-------|-----------------|
|
||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
|
||||
| `on_batch_begin` | Every batch | — |
|
||||
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
|
||||
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
|
||||
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
|
||||
|
||||
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||
|
||||
## Strategies
|
||||
|
||||
### SEQ (Pre-training)
|
||||
|
||||
Next-token cross-entropy with optional label smoothing:
|
||||
|
||||
$$
|
||||
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
||||
$$
|
||||
|
||||
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
|
||||
|
||||
### SFT (Supervised Fine-Tuning)
|
||||
|
||||
Masked cross-entropy (`ignore_index=-100`) over response tokens:
|
||||
|
||||
$$
|
||||
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
||||
$$
|
||||
|
||||
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
|
||||
|
||||
### DPO (Direct Preference Optimization)
|
||||
|
||||
Frozen reference model, preference margin via log-ratio:
|
||||
|
||||
$$
|
||||
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
||||
$$
|
||||
|
||||
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
||||
|
||||
### GRPO (Group Relative Policy Optimization)
|
||||
|
||||
On-policy PPO with group-normalized advantages:
|
||||
|
||||
$$
|
||||
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
|
||||
$$
|
||||
|
||||
$$
|
||||
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
||||
$$
|
||||
|
||||
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
|
||||
|
||||
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||
|
||||
## LR Schedulers
|
||||
|
||||
| Type | Class | Description |
|
||||
|------|-------|-------------|
|
||||
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
||||
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
||||
|
||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
||||
|
||||
## Gradient Checkpointing
|
||||
|
||||
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
|
||||
|
||||
```python
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
|
||||
```
|
||||
|
||||
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
|
||||
|
||||
## Checkpoint
|
||||
|
||||
```
|
||||
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
||||
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
||||
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
|
||||
```
|
||||
|
||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
|
||||
|
||||
## TrainContextBuilder (Builder Pattern)
|
||||
|
||||
```python
|
||||
context = (
|
||||
TrainContextBuilder(config)
|
||||
.with_resume_dir(resume_dir)
|
||||
.build()
|
||||
)
|
||||
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
||||
```
|
||||
|
||||
- Loads checkpoint weights if provided
|
||||
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
||||
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
|
||||
|
||||
## Training CLI
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--parallel_mode=ddp \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.9 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.05 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
Full parameter reference at [params.md](params.md).
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 281 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 21 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 11 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 590 KiB |
|
|
@ -0,0 +1,34 @@
|
|||
__version__ = "1.3.7"
|
||||
__author__ = "ViperEkura"
|
||||
|
||||
from astrai.config import (
|
||||
AutoRegressiveLMConfig,
|
||||
EncoderConfig,
|
||||
TrainConfig,
|
||||
)
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.inference import (
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
from astrai.model import AutoModel, AutoRegressiveLM
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
||||
|
||||
__all__ = [
|
||||
"AutoRegressiveLM",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"TrainConfig",
|
||||
"DatasetFactory",
|
||||
"AutoTokenizer",
|
||||
"GenerationRequest",
|
||||
"InferenceEngine",
|
||||
"Trainer",
|
||||
"CallbackFactory",
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory",
|
||||
"BaseFactory",
|
||||
"AutoModel",
|
||||
]
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from astrai.config.model_config import (
|
||||
AutoRegressiveLMConfig,
|
||||
BaseModelConfig,
|
||||
ConfigFactory,
|
||||
EncoderConfig,
|
||||
)
|
||||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
OutputConfig,
|
||||
PipelineConfig,
|
||||
ProcessingConfig,
|
||||
)
|
||||
from astrai.config.train_config import TrainConfig
|
||||
|
||||
__all__ = [
|
||||
"BaseModelConfig",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"ConfigFactory",
|
||||
"TrainConfig",
|
||||
"InputConfig",
|
||||
"OutputConfig",
|
||||
"PipelineConfig",
|
||||
"ProcessingConfig",
|
||||
]
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
import json
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Self, Union, get_type_hints
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = {}
|
||||
for fld in fields(self):
|
||||
v = getattr(self, fld.name)
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
d[fld.name] = v
|
||||
elif v is None:
|
||||
d[fld.name] = None
|
||||
elif isinstance(v, (dict, list, tuple)):
|
||||
try:
|
||||
val = list(v) if isinstance(v, tuple) else v
|
||||
json.dumps(val)
|
||||
d[fld.name] = val
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
elif isinstance(v, BaseConfig):
|
||||
d[fld.name] = v.to_dict()
|
||||
elif hasattr(v, "__dataclass_fields__"):
|
||||
sub = {}
|
||||
for f in fields(v):
|
||||
a = getattr(v, f.name)
|
||||
sub[f.name] = list(a) if isinstance(a, tuple) else a
|
||||
d[fld.name] = sub
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> Self:
|
||||
hints = get_type_hints(cls)
|
||||
inst = cls.__new__(cls)
|
||||
for fld in fields(cls):
|
||||
if fld.name in d:
|
||||
v = d[fld.name]
|
||||
target = cls._unwrap_optional(hints.get(fld.name))
|
||||
if target is not None:
|
||||
try:
|
||||
v = cls._coerce(v, target)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
object.__setattr__(inst, fld.name, v)
|
||||
elif fld.default is not MISSING:
|
||||
object.__setattr__(inst, fld.name, fld.default)
|
||||
elif fld.default_factory is not MISSING:
|
||||
object.__setattr__(inst, fld.name, fld.default_factory())
|
||||
else:
|
||||
object.__setattr__(inst, fld.name, None)
|
||||
return inst
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_optional(tp) -> Optional[type]:
|
||||
if tp is None:
|
||||
return None
|
||||
origin = getattr(tp, "__origin__", None)
|
||||
if origin is not None:
|
||||
args = getattr(tp, "__args__", ())
|
||||
non_none = [a for a in args if a is not type(None)]
|
||||
return non_none[0] if non_none else None
|
||||
return tp
|
||||
|
||||
@staticmethod
|
||||
def _coerce(value: Any, target_type: type) -> Any:
|
||||
if target_type is bool and isinstance(value, bool):
|
||||
return value
|
||||
if (
|
||||
target_type is int
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return int(value)
|
||||
if (
|
||||
target_type is float
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return float(value)
|
||||
if target_type is str and isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
|
||||
return target_type.from_dict(value)
|
||||
raise TypeError
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: Union[str, Path]) -> Self:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return cls.from_dict(json.load(f))
|
||||
|
||||
def to_json(self, path: Union[str, Path]):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Self
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class ConfigFactory(BaseFactory[BaseConfig]):
|
||||
"""Factory that dispatches config classes by ``model_type``."""
|
||||
|
||||
@classmethod
|
||||
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
|
||||
model_type = raw.get("model_type") or "autoregressive_lm"
|
||||
config_cls = cls.get_component_class(model_type)
|
||||
return config_cls.from_dict(raw)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig(BaseConfig):
|
||||
"""Base config with ``model_type`` dispatch and file I/O."""
|
||||
|
||||
model_type: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str) -> Self:
|
||||
with open(config_path, "r") as f:
|
||||
raw: Dict[str, Any] = json.load(f)
|
||||
return cls.from_dict(raw)
|
||||
|
||||
def to_file(self, config_path: str):
|
||||
d = self.to_dict()
|
||||
config_dict = {k: v for k, v in d.items() if v is not None}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ConfigFactory.register("autoregressive_lm")
|
||||
class AutoRegressiveLMConfig(BaseModelConfig):
|
||||
"""Configuration for autoregressive language model."""
|
||||
|
||||
vocab_size: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
n_layers: Optional[int] = None
|
||||
norm_eps: Optional[float] = None
|
||||
dim_ffn: Optional[int] = None
|
||||
tie_weight: Optional[bool] = None
|
||||
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
rope_scaling: Optional[dict] = None
|
||||
|
||||
attn_type: str = "gqa"
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
kv_lora_rank: Optional[int] = None
|
||||
qk_nope_head_dim: Optional[int] = None
|
||||
qk_rope_head_dim: Optional[int] = None
|
||||
|
||||
ffn_type: str = "mlp"
|
||||
n_routed_experts: Optional[int] = None
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_activated_experts: Optional[int] = None
|
||||
topk_method: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ConfigFactory.register("embedding")
|
||||
class EncoderConfig(BaseModelConfig):
|
||||
"""Configuration for embedding encoder model."""
|
||||
|
||||
vocab_size: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
n_layers: Optional[int] = None
|
||||
norm_eps: Optional[float] = None
|
||||
dim_ffn: Optional[int] = None
|
||||
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
rope_scaling: Optional[dict] = None
|
||||
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
pooling_type: Optional[str] = None
|
||||
normalize_embeddings: Optional[bool] = None
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
"""Pipeline configuration for JSONL preprocessing."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputConfig(BaseConfig):
|
||||
sections: Optional[List[Dict]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingConfig(BaseConfig):
|
||||
max_seq_len: int = 2048
|
||||
min_chars: int = 50
|
||||
max_chars: int = 2_000_000
|
||||
max_items: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputConfig(BaseConfig):
|
||||
domain_key: Optional[str] = None
|
||||
storage_format: str = "bin"
|
||||
max_tokens_per_shard: int = 100_000_000
|
||||
dtype: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineConfig(BaseConfig):
|
||||
version: int = 1
|
||||
input: InputConfig = field(default_factory=InputConfig)
|
||||
mask: Dict[str, str] = field(default_factory=dict)
|
||||
mask_default: str = "mask"
|
||||
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
||||
output: OutputConfig = field(default_factory=OutputConfig)
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
from dataclasses import dataclass, field, fields
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
from astrai.model.components.lora import LoRAConfig
|
||||
|
||||
|
||||
def required(**kw):
|
||||
return {"required": True, **kw}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model_fn: Callable[[], nn.Module] = field(
|
||||
default=None, metadata=required(help="Model factory for training.")
|
||||
)
|
||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||
dataset: Dataset = field(
|
||||
default=None, metadata=required(help="Dataset for training.")
|
||||
)
|
||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||
default=None, metadata=required(help="Optimizer factory for training.")
|
||||
)
|
||||
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||
default=None, metadata=required(help="Scheduler factory for training.")
|
||||
)
|
||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||
batch_per_device: int = field(
|
||||
default=4, metadata={"help": "Batch size per device."}
|
||||
)
|
||||
grad_accum_steps: int = field(
|
||||
default=1, metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
gradient_checkpointing_modules: list = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "Module types to enable activation checkpointing for."},
|
||||
)
|
||||
|
||||
# checkpoint setting
|
||||
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||
start_batch: int = field(
|
||||
default=0, metadata={"help": "Start batch iteration for training."}
|
||||
)
|
||||
ckpt_dir: str = field(
|
||||
default="./checkpoint", metadata={"help": "Checkpoint directory."}
|
||||
)
|
||||
ckpt_interval: int = field(
|
||||
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||
)
|
||||
|
||||
# lora setting
|
||||
lora: Optional[LoRAConfig] = field(
|
||||
default=None,
|
||||
metadata={"help": "LoRA config. None means full fine-tuning."},
|
||||
)
|
||||
|
||||
# metric setting
|
||||
log_dir: str = field(
|
||||
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
||||
)
|
||||
log_interval: int = field(
|
||||
default=100,
|
||||
metadata={"help": "Number of batch iterations between metric logs."},
|
||||
)
|
||||
metrics: List[str] = field(
|
||||
default_factory=lambda: ["loss", "lr"],
|
||||
metadata={"help": "Metrics to record during training."},
|
||||
)
|
||||
|
||||
# dataloader setting
|
||||
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
||||
num_workers: int = field(
|
||||
default=0, metadata={"help": "Number of workers for dataloader."}
|
||||
)
|
||||
prefetch_factor: Optional[int] = field(
|
||||
default=None, metadata={"help": "Prefetch factor for dataloader."}
|
||||
)
|
||||
pin_memory: bool = field(
|
||||
default=False, metadata={"help": "Pin memory for dataloader."}
|
||||
)
|
||||
|
||||
# distributed training
|
||||
nprocs: int = field(
|
||||
default=1, metadata={"help": "Number of processes for distributed training."}
|
||||
)
|
||||
backend: str = field(
|
||||
default="nccl", metadata={"help": "Distributed training backend."}
|
||||
)
|
||||
master_addr: str = field(
|
||||
default="localhost",
|
||||
metadata={"help": "Master address for distributed training."},
|
||||
)
|
||||
master_port: str = field(
|
||||
default="29500", metadata={"help": "Master port for distributed training."}
|
||||
)
|
||||
parallel_mode: str = field(
|
||||
default="none",
|
||||
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
|
||||
)
|
||||
start_method: str = field(
|
||||
default="spawn",
|
||||
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
|
||||
)
|
||||
|
||||
# others
|
||||
device_type: str = field(
|
||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||
)
|
||||
val_dataset: Optional[Dataset] = field(
|
||||
default=None, metadata={"help": "Dataset for validation."}
|
||||
)
|
||||
val_step: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||
)
|
||||
|
||||
executor_kwargs: dict = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
||||
)
|
||||
extra_kwargs: dict = field(
|
||||
default_factory=dict, metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
for fld in fields(self):
|
||||
if fld.metadata.get("required") and getattr(self, fld.name) is None:
|
||||
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
from astrai.dataset.dataset import (
|
||||
BaseDataset,
|
||||
DatasetFactory,
|
||||
)
|
||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||
from astrai.dataset.storage import (
|
||||
H5Store,
|
||||
MmapStore,
|
||||
Store,
|
||||
StoreFactory,
|
||||
detect_format,
|
||||
load_bin,
|
||||
load_h5,
|
||||
save_bin,
|
||||
save_h5,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseDataset",
|
||||
"DatasetFactory",
|
||||
"Store",
|
||||
"StoreFactory",
|
||||
"H5Store",
|
||||
"MmapStore",
|
||||
"detect_format",
|
||||
"save_h5",
|
||||
"load_h5",
|
||||
"save_bin",
|
||||
"load_bin",
|
||||
"ResumableDistributedSampler",
|
||||
]
|
||||
|
|
@ -0,0 +1,308 @@
|
|||
"""Dataset implementations with factory pattern for training."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.dataset.storage import (
|
||||
Store,
|
||||
StoreFactory,
|
||||
detect_format,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
"""Abstract base class for all dataset types.
|
||||
|
||||
Implements common functionality for window-based data fetching.
|
||||
Uses a storage abstraction for format-agnostic data loading.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.storage: Optional[Store] = None
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
"""Return required storage keys for this dataset type.
|
||||
|
||||
Subclasses should override to specify expected keys.
|
||||
"""
|
||||
return []
|
||||
|
||||
def _validate_keys(self):
|
||||
if not self.required_keys:
|
||||
return
|
||||
actual_keys = set(self.storage.keys)
|
||||
missing = [k for k in self.required_keys if k not in actual_keys]
|
||||
if missing:
|
||||
raise KeyError(
|
||||
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
|
||||
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
|
||||
f"Missing: {missing}"
|
||||
)
|
||||
|
||||
def load(self, load_path: str, storage_type: Optional[str] = None):
|
||||
"""Load dataset from the given path.
|
||||
|
||||
Auto-detects the storage format if not specified.
|
||||
|
||||
Args:
|
||||
load_path: Path to the data directory or file
|
||||
storage_type: Force a specific storage type ("h5", "bin"),
|
||||
or None for auto-detection
|
||||
|
||||
Raises:
|
||||
KeyError: If the loaded storage is missing required keys.
|
||||
"""
|
||||
if storage_type is None:
|
||||
storage_type = detect_format(load_path)
|
||||
self.storage = StoreFactory.create(storage_type)
|
||||
self._load_path = load_path
|
||||
self.storage.load(load_path)
|
||||
self._validate_keys()
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||
if self.storage is None:
|
||||
return 0
|
||||
return len(self.storage)
|
||||
|
||||
@property
|
||||
def keys(self) -> List[str]:
|
||||
"""Return the available data keys."""
|
||||
if self.storage is None:
|
||||
return []
|
||||
return self.storage.keys
|
||||
|
||||
def get_index(self, index: int) -> tuple:
|
||||
"""Calculate begin and end indices for a sample.
|
||||
|
||||
Args:
|
||||
index: Sample index
|
||||
|
||||
Returns:
|
||||
Tuple of (begin_idx, end_idx)
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise RuntimeError("Dataset not loaded, call load() first")
|
||||
total = len(self.storage)
|
||||
if total <= self.window_size:
|
||||
raise ValueError(
|
||||
f"Data too short: {total} tokens <= window_size {self.window_size}"
|
||||
)
|
||||
|
||||
begin_idx = min(index * self.stride, total - 1 - self.window_size)
|
||||
end_idx = min(begin_idx + self.window_size, total - 1)
|
||||
|
||||
return begin_idx, end_idx
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
"""Get a single sample by index.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self.storage is None:
|
||||
return 0
|
||||
total = len(self.storage)
|
||||
if total <= self.window_size:
|
||||
return 0
|
||||
return (total - 1 - self.window_size) // self.stride + 1
|
||||
|
||||
|
||||
class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||
"""Factory class for creating dataset instances.
|
||||
|
||||
Supports decorator-based registration for extensible dataset types.
|
||||
All default dataset types (seq, sft, dpo, grpo) are registered automatically
|
||||
when their classes are defined with the decorator.
|
||||
|
||||
Example usage:
|
||||
@DatasetFactory.register("custom")
|
||||
class CustomDataset(BaseDataset):
|
||||
...
|
||||
|
||||
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, dataset_cls: type):
|
||||
"""Validate that the dataset class inherits from BaseDataset."""
|
||||
if not issubclass(dataset_cls, BaseDataset):
|
||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
||||
"""Create a dataset instance.
|
||||
|
||||
Args:
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples
|
||||
|
||||
Returns:
|
||||
Dataset instance
|
||||
"""
|
||||
return super().create(train_type, window_size, stride)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
train_type: str,
|
||||
load_path: str,
|
||||
window_size: int,
|
||||
stride: Optional[int] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
) -> "BaseDataset":
|
||||
"""Create and load a dataset in one step.
|
||||
|
||||
Args:
|
||||
train_type: Type of training dataset
|
||||
load_path: Path to the data file
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples (default: same as window_size)
|
||||
storage_type: Storage type ("h5", "bin") or None for auto-detection
|
||||
|
||||
Returns:
|
||||
Loaded dataset instance
|
||||
"""
|
||||
if stride is None:
|
||||
stride = window_size
|
||||
|
||||
dataset = cls.create(train_type, window_size, stride)
|
||||
dataset.load(load_path, storage_type=storage_type)
|
||||
|
||||
return dataset
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered dataset type names."""
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
@DatasetFactory.register("seq")
|
||||
class SEQDataset(BaseDataset):
|
||||
"""Dataset for sequential next-token prediction training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["sequence"]
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||
|
||||
return {"input_ids": x, "target_ids": y}
|
||||
|
||||
|
||||
@DatasetFactory.register("sft")
|
||||
class SFTDataset(BaseDataset):
|
||||
"""Dataset for supervised fine-tuning with loss masking."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["sequence", "loss_mask"]
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||
|
||||
|
||||
@DatasetFactory.register("dpo")
|
||||
class DPODataset(BaseDataset):
|
||||
"""Dataset for Direct Preference Optimization training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return {
|
||||
"chosen": chosen,
|
||||
"rejected": rejected,
|
||||
"chosen_mask": chosen_mask,
|
||||
"rejected_mask": rejected_mask,
|
||||
}
|
||||
|
||||
|
||||
@DatasetFactory.register("grpo")
|
||||
class GRPODataset(BaseDataset):
|
||||
"""Dataset for Group Relative Policy Optimization training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["prompts", "responses", "masks", "rewards"]
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
|
||||
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
|
||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||
|
||||
return {
|
||||
"prompts": prompts,
|
||||
"responses": responses,
|
||||
"masks": masks,
|
||||
"rewards": rewards,
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
|
||||
class ResumableDistributedSampler(Sampler[int]):
|
||||
def __init__(
|
||||
self,
|
||||
data_source: Dataset,
|
||||
start_epoch: int = 0,
|
||||
start_iter: int = 0,
|
||||
seed: int = 42,
|
||||
drop_last: bool = False,
|
||||
shuffle: bool = True,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
):
|
||||
self.epoch = start_epoch
|
||||
self.iter = start_iter
|
||||
self.seed = seed
|
||||
self.num_samples = len(data_source)
|
||||
|
||||
if process_group is not None:
|
||||
# input process group
|
||||
self.rank = dist.get_rank(process_group)
|
||||
self.num_replicas = dist.get_world_size(process_group)
|
||||
|
||||
elif dist.is_available() and dist.is_initialized():
|
||||
# use default process group
|
||||
process_group = dist.group.WORLD
|
||||
self.rank = dist.get_rank()
|
||||
self.num_replicas = dist.get_world_size()
|
||||
|
||||
else:
|
||||
# single process
|
||||
self.rank = 0
|
||||
self.num_replicas = 1
|
||||
|
||||
self.drop_last = drop_last
|
||||
self.shuffle = shuffle
|
||||
|
||||
offset = 0 if drop_last else self.num_replicas - 1
|
||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||
self.iter = self.iter % self.num_samples_per_replica
|
||||
|
||||
self._indices = None
|
||||
|
||||
def _get_indices(self):
|
||||
if self.shuffle:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(self.seed + self.epoch)
|
||||
indices = torch.randperm(self.num_samples, generator=generator).tolist()
|
||||
else:
|
||||
indices = torch.arange(self.num_samples).tolist()
|
||||
|
||||
if not self.drop_last and self.num_samples < self.total_size:
|
||||
padding_size = self.total_size - len(indices)
|
||||
indices += indices[:padding_size]
|
||||
|
||||
local_indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
|
||||
self.iter = self.iter % self.num_samples_per_replica
|
||||
self._indices = local_indices[self.iter :]
|
||||
|
||||
def __iter__(self):
|
||||
if self._indices is None:
|
||||
self._get_indices()
|
||||
|
||||
for i in self._indices:
|
||||
self.iter += 1
|
||||
yield i
|
||||
|
||||
self.epoch += 1
|
||||
self._indices = None
|
||||
|
||||
@property
|
||||
def _remaining(self):
|
||||
remaining = self.num_samples_per_replica - self.iter
|
||||
return max(remaining, 0)
|
||||
|
||||
def __len__(self):
|
||||
return self._remaining
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
"""Storage backends for different data formats.
|
||||
|
||||
Layers:
|
||||
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
|
||||
return Dict[str, List[Tensor]] — format-specific, no state
|
||||
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||
Dict[str, List[Tensor]] per key via _normalize(),
|
||||
fetch() uses bisect across segments — no forced concat
|
||||
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
||||
|
||||
Key properties:
|
||||
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
||||
datasets larger than RAM
|
||||
- Explicit length: _length = min(total elements across keys), set at load,
|
||||
__len__ returns O(1)
|
||||
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
||||
workers share OS page-cache pages
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||
with h5py.File(full_file_path, "w") as f:
|
||||
for key, tensors in tensor_group.items():
|
||||
grp = f.create_group(key)
|
||||
for idx, tensor in enumerate(tensors):
|
||||
arr = tensor.cpu().numpy()
|
||||
grp.create_dataset(f"data_{idx}", data=arr)
|
||||
|
||||
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
|
||||
root_path = Path(file_path)
|
||||
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||
|
||||
for h5_file in h5_files:
|
||||
with h5py.File(h5_file, "r") as f:
|
||||
for key in f.keys():
|
||||
grp = f[key]
|
||||
dsets = []
|
||||
for dset_name in grp.keys():
|
||||
dset = grp[dset_name]
|
||||
tensor = torch.from_numpy(dset[:])
|
||||
if share_memory:
|
||||
tensor = tensor.share_memory_()
|
||||
dsets.append(tensor)
|
||||
|
||||
if tensor_group.get(key) is None:
|
||||
tensor_group[key] = []
|
||||
tensor_group[key].extend(dsets)
|
||||
|
||||
return tensor_group
|
||||
|
||||
|
||||
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
meta = {}
|
||||
for key, tensors in tensor_group.items():
|
||||
cat = torch.cat(tensors, dim=0)
|
||||
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
||||
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
||||
with open(os.path.join(file_path, "meta.json"), "w") as f:
|
||||
json.dump(meta, f)
|
||||
|
||||
|
||||
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||
with open(os.path.join(file_path, "meta.json"), "r") as f:
|
||||
meta = json.load(f)
|
||||
segments: Dict[str, List[Tensor]] = {}
|
||||
for key, info in meta.items():
|
||||
arr = np.memmap(
|
||||
os.path.join(file_path, f"{key}.bin"),
|
||||
dtype=info["dtype"],
|
||||
mode="r+",
|
||||
shape=tuple(info["shape"]),
|
||||
)
|
||||
segments[key] = [torch.from_numpy(arr)]
|
||||
return segments
|
||||
|
||||
|
||||
def detect_format(load_path: str) -> str:
|
||||
"""Auto-detect storage format from files in the directory.
|
||||
|
||||
Args:
|
||||
load_path: Directory or file path
|
||||
|
||||
Returns:
|
||||
Format string ("h5" or "bin")
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If no supported data files are found
|
||||
"""
|
||||
root = Path(load_path)
|
||||
if root.is_file():
|
||||
suffix = root.suffix.lower()
|
||||
if suffix in (".h5", ".hdf5"):
|
||||
return "h5"
|
||||
raise ValueError(f"Unsupported file format: {suffix}")
|
||||
|
||||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||
if h5_files:
|
||||
return "h5"
|
||||
bin_files = list(root.rglob("*.bin"))
|
||||
if bin_files:
|
||||
has_meta = (root / "meta.json").exists() or len(
|
||||
list(root.rglob("meta.json"))
|
||||
) > 0
|
||||
if has_meta:
|
||||
return "bin"
|
||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||
|
||||
|
||||
class Store(ABC):
|
||||
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
|
||||
|
||||
Each key maps to one or more tensor segments (no forced concatenation).
|
||||
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
|
||||
total element count across all keys.
|
||||
|
||||
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
|
||||
via ``_normalize()``.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._data: Dict[str, List[Tensor]] = {}
|
||||
self._cum: Dict[str, List[int]] = {}
|
||||
self._length: int = 0
|
||||
|
||||
@abstractmethod
|
||||
def load(self, path: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def keys(self) -> List[str]:
|
||||
return list(self._data.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
begin: int,
|
||||
end: int,
|
||||
keys: Union[str, List[str]],
|
||||
):
|
||||
if not self._data:
|
||||
raise RuntimeError("Store not loaded")
|
||||
if not (0 <= begin < self._length and 0 <= end <= self._length):
|
||||
raise ValueError(
|
||||
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
|
||||
)
|
||||
if isinstance(keys, str):
|
||||
return self._fetch_key(keys, begin, end)
|
||||
return {k: self._fetch_key(k, begin, end) for k in keys}
|
||||
|
||||
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
|
||||
"""Fetch slice [begin, end) across potentially multiple segments."""
|
||||
segments = self._data[key]
|
||||
cum = self._cum[key]
|
||||
seg_start = bisect.bisect_right(cum, begin)
|
||||
seg_end = bisect.bisect_left(cum, end)
|
||||
|
||||
results = []
|
||||
for i in range(seg_start, seg_end + 1):
|
||||
prev = cum[i - 1] if i > 0 else 0
|
||||
s = max(begin - prev, 0)
|
||||
e = min(end - prev, segments[i].shape[0])
|
||||
results.append(segments[i][s:e])
|
||||
|
||||
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
|
||||
|
||||
def _normalize(self, raw: Dict[str, List[Tensor]]):
|
||||
"""Register segments and pre-compute cumulative lengths.
|
||||
|
||||
Does NOT concatenate — segments are kept as-is to avoid OOM on
|
||||
large datasets. Sets ``self._length`` to the minimum total
|
||||
element count across all keys.
|
||||
"""
|
||||
for key, tensors in raw.items():
|
||||
self._data[key] = tensors
|
||||
cum = []
|
||||
total = 0
|
||||
for t in tensors:
|
||||
total += t.shape[0]
|
||||
cum.append(total)
|
||||
self._cum[key] = cum
|
||||
self._length = (
|
||||
min((cum[-1] if cum else 0) for cum in self._cum.values())
|
||||
if self._cum
|
||||
else 0
|
||||
)
|
||||
|
||||
|
||||
class StoreFactory(BaseFactory["Store"]):
|
||||
"""Factory for creating Store instances by type name.
|
||||
|
||||
Example::
|
||||
|
||||
@StoreFactory.register("custom")
|
||||
class CustomStore(Store):
|
||||
...
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, store_cls: type):
|
||||
if not issubclass(store_cls, Store):
|
||||
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
||||
|
||||
|
||||
@StoreFactory.register("h5")
|
||||
class H5Store(Store):
|
||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||
|
||||
def load(self, path: str):
|
||||
self._normalize(load_h5(path))
|
||||
|
||||
|
||||
@StoreFactory.register("bin")
|
||||
class MmapStore(Store):
|
||||
"""Memory-mapped binary storage backend.
|
||||
|
||||
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
|
||||
No per-process memory duplication — all DataLoader workers share the
|
||||
same OS page-cache pages.
|
||||
|
||||
Format on disk::
|
||||
|
||||
data_root/
|
||||
meta.json # {key: {shape, dtype}, ...}
|
||||
<key>.bin # raw numpy array, one per key
|
||||
"""
|
||||
|
||||
def load(self, path: str):
|
||||
self._mmap_refs = []
|
||||
root = Path(path)
|
||||
all_raw: Dict[str, List[Tensor]] = {}
|
||||
meta_paths = list(root.rglob("meta.json"))
|
||||
for meta_path in meta_paths:
|
||||
raw = load_bin(str(meta_path.parent))
|
||||
for key, tensors in raw.items():
|
||||
if key not in all_raw:
|
||||
all_raw[key] = []
|
||||
all_raw[key].extend(tensors)
|
||||
if not meta_paths:
|
||||
raise FileNotFoundError(f"No meta.json found under {path}")
|
||||
self._normalize(all_raw)
|
||||
for tensors in self._data.values():
|
||||
self._mmap_refs.extend(tensors)
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
"""Base factory class for extensible component registration."""
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Registry:
|
||||
"""Flexible registry for component classes with category and priority support.
|
||||
|
||||
This registry stores component classes with optional metadata (category, priority).
|
||||
It provides methods for registration, retrieval, and listing with filtering.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._entries = {} # name -> (component_cls, category, priority)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
component_cls: Type,
|
||||
category: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
):
|
||||
"""Register a component class with optional category and priority."""
|
||||
if name in self._entries:
|
||||
raise ValueError(f"Component '{name}' is already registered")
|
||||
self._entries[name] = (component_cls, category, priority)
|
||||
|
||||
def get(self, name: str) -> Type:
|
||||
"""Get component class by name."""
|
||||
if name not in self._entries:
|
||||
raise KeyError(f"Component '{name}' not found in registry")
|
||||
return self._entries[name][0]
|
||||
|
||||
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
|
||||
"""Get component class with its metadata."""
|
||||
entry = self._entries.get(name)
|
||||
if entry is None:
|
||||
raise KeyError(f"Component '{name}' not found in registry")
|
||||
return entry
|
||||
|
||||
def contains(self, name: str) -> bool:
|
||||
"""Check if a name is registered."""
|
||||
return name in self._entries
|
||||
|
||||
def list_names(self) -> List[str]:
|
||||
"""Return list of registered component names."""
|
||||
return sorted(self._entries.keys())
|
||||
|
||||
def list_by_category(self, category: str) -> List[str]:
|
||||
"""Return names of components belonging to a specific category."""
|
||||
return sorted(
|
||||
name for name, (_, cat, _) in self._entries.items() if cat == category
|
||||
)
|
||||
|
||||
def list_by_priority(self, reverse: bool = False) -> List[str]:
|
||||
"""Return names sorted by priority (default ascending)."""
|
||||
return sorted(
|
||||
self._entries.keys(),
|
||||
key=lambda name: self._entries[name][2],
|
||||
reverse=reverse,
|
||||
)
|
||||
|
||||
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
|
||||
"""Return raw entries dictionary."""
|
||||
return self._entries.copy()
|
||||
|
||||
|
||||
class BaseFactory(ABC, Generic[T]):
|
||||
"""Generic factory class for component registration and creation.
|
||||
|
||||
This base class provides a decorator-based registration pattern
|
||||
for creating extensible component factories.
|
||||
|
||||
Example usage:
|
||||
class MyFactory(BaseFactory[MyBaseClass]):
|
||||
pass
|
||||
|
||||
@MyFactory.register("custom")
|
||||
class CustomComponent(MyBaseClass):
|
||||
...
|
||||
|
||||
component = MyFactory.create("custom", *args, **kwargs)
|
||||
"""
|
||||
|
||||
_registry: Registry
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls._registry = Registry()
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, name: str, category: Optional[str] = None, priority: int = 0
|
||||
) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Decorator to register a component class with optional category and priority.
|
||||
|
||||
Args:
|
||||
name: Registration name for the component
|
||||
category: Optional category for grouping components
|
||||
priority: Priority for ordering (default 0)
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the component class
|
||||
|
||||
Raises:
|
||||
TypeError: If the decorated class doesn't inherit from the base type
|
||||
"""
|
||||
|
||||
def decorator(component_cls: Type[T]) -> Type[T]:
|
||||
cls._validate_component(component_cls)
|
||||
cls._registry.register(
|
||||
name, component_cls, category=category, priority=priority
|
||||
)
|
||||
return component_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, name: str, *args, **kwargs) -> T:
|
||||
"""Create a component instance by name.
|
||||
|
||||
Filters kwargs to match the component's __init__ signature,
|
||||
so components don't need to declare **kwargs just to absorb
|
||||
parameters meant for other components.
|
||||
|
||||
Args:
|
||||
name: Registered name of the component
|
||||
*args: Positional arguments passed to component constructor
|
||||
**kwargs: Keyword arguments passed to component constructor
|
||||
|
||||
Returns:
|
||||
Component instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the component name is not registered
|
||||
"""
|
||||
if not cls._registry.contains(name):
|
||||
raise ValueError(
|
||||
f"Unknown component: '{name}'. "
|
||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||
)
|
||||
component_cls = cls._registry.get(name)
|
||||
sig = inspect.signature(component_cls.__init__)
|
||||
has_var_kwargs = any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||
)
|
||||
if not has_var_kwargs:
|
||||
valid = {
|
||||
p.name
|
||||
for p in sig.parameters.values()
|
||||
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
|
||||
}
|
||||
kwargs = {k: v for k, v in kwargs.items() if k in valid}
|
||||
return component_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: Type[T]):
|
||||
"""Validate that the component class is valid for this factory.
|
||||
|
||||
Override this method in subclasses to add custom validation.
|
||||
|
||||
Args:
|
||||
component_cls: Component class to validate
|
||||
|
||||
Raises:
|
||||
TypeError: If the component class is invalid
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_component_class(cls, name: str) -> Type[T]:
|
||||
"""Get the registered component class by name without instantiating it.
|
||||
|
||||
Args:
|
||||
name: Registered name of the component
|
||||
|
||||
Returns:
|
||||
The component class itself
|
||||
|
||||
Raises:
|
||||
ValueError: If the component name is not registered
|
||||
"""
|
||||
if not cls._registry.contains(name):
|
||||
raise ValueError(
|
||||
f"Unknown component: '{name}'. "
|
||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||
)
|
||||
return cls._registry.get(name)
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
"""List all registered component names.
|
||||
|
||||
Returns:
|
||||
List of registered component names
|
||||
"""
|
||||
return cls._registry.list_names()
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, name: str) -> bool:
|
||||
"""Check if a component name is registered.
|
||||
|
||||
Args:
|
||||
name: Component name to check
|
||||
|
||||
Returns:
|
||||
True if registered, False otherwise
|
||||
"""
|
||||
return cls._registry.contains(name)
|
||||
|
||||
@classmethod
|
||||
def list_by_category(cls, category: str) -> List[str]:
|
||||
"""List registered component names in a category."""
|
||||
return cls._registry.list_by_category(category)
|
||||
|
||||
@classmethod
|
||||
def list_by_priority(cls, reverse: bool = False) -> List[str]:
|
||||
"""List registered component names sorted by priority."""
|
||||
return cls._registry.list_by_priority(reverse)
|
||||
|
||||
|
||||
__all__ = ["Registry", "BaseFactory"]
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
"""Inference module for continuous batching.
|
||||
|
||||
Layers:
|
||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||
- api/: HTTP orchestration (ProtocolHandler, server)
|
||||
- protocols/: Response builders (OpenAI, Anthropic)
|
||||
- transport/: SSE transport utilities
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
"""
|
||||
|
||||
from astrai.inference.api import (
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
GenContext,
|
||||
MessagesRequest,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
app,
|
||||
run_server,
|
||||
)
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.core import (
|
||||
STOP,
|
||||
Allocator,
|
||||
Executor,
|
||||
InferenceScheduler,
|
||||
KVCache,
|
||||
KvcacheView,
|
||||
PagePool,
|
||||
PrefixCache,
|
||||
Storage,
|
||||
Task,
|
||||
TaskManager,
|
||||
TaskStatus,
|
||||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
||||
from astrai.inference.sample import (
|
||||
BaseSamplingStrategy,
|
||||
SamplingPipeline,
|
||||
TemperatureStrategy,
|
||||
TopKStrategy,
|
||||
TopPStrategy,
|
||||
sample,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"InferenceEngine",
|
||||
"GenerationRequest",
|
||||
"InferenceScheduler",
|
||||
"Executor",
|
||||
"STOP",
|
||||
"Task",
|
||||
"TaskManager",
|
||||
"TaskStatus",
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
"PagePool",
|
||||
"PrefixCache",
|
||||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
"sample",
|
||||
"BaseSamplingStrategy",
|
||||
"TemperatureStrategy",
|
||||
"TopKStrategy",
|
||||
"TopPStrategy",
|
||||
"SamplingPipeline",
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"GenContext",
|
||||
"OpenAIResponseBuilder",
|
||||
"AnthropicResponseBuilder",
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"AnthropicMessage",
|
||||
"MessagesRequest",
|
||||
"app",
|
||||
"run_server",
|
||||
]
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
||||
|
||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||
from astrai.inference.api.server import (
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
MessagesRequest,
|
||||
app,
|
||||
run_server,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"GenContext",
|
||||
"AnthropicMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatMessage",
|
||||
"MessagesRequest",
|
||||
"app",
|
||||
"run_server",
|
||||
]
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
"""Anthropic message completion response builder."""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
GenContext,
|
||||
ResponseBuilder,
|
||||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
class AnthropicResponseBuilder(ResponseBuilder):
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
messages: List[Dict[str, str]] = []
|
||||
system = getattr(request, "system", None)
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
for m in request.messages:
|
||||
text = _extract_text(m.content)
|
||||
if text:
|
||||
messages.append({"role": m.role, "content": text})
|
||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
ctx = GenContext(
|
||||
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
prompt_tokens=0,
|
||||
)
|
||||
stop_sequences = getattr(request, "stop_sequences", None) or []
|
||||
return prompt, ctx, stop_sequences
|
||||
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [],
|
||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||
},
|
||||
},
|
||||
event="message_start",
|
||||
),
|
||||
sse_event(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
event="content_block_start",
|
||||
),
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str) -> str:
|
||||
return sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": token},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
events: List[str] = []
|
||||
if stop.matched:
|
||||
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
||||
unyielded = trimmed[len(stop.yielded) :]
|
||||
if unyielded:
|
||||
events.append(
|
||||
sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": unyielded},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
sse_event(
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
event="content_block_stop",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
sse_event(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {
|
||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||
"stop_sequence": stop.matched,
|
||||
},
|
||||
"usage": {"output_tokens": ctx.completion_tokens},
|
||||
},
|
||||
event="message_delta",
|
||||
)
|
||||
)
|
||||
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
||||
return events
|
||||
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
if stop.matched:
|
||||
content = content[: content.rfind(stop.matched)]
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||
"stop_sequence": stop.matched,
|
||||
"usage": {
|
||||
"input_tokens": ctx.prompt_tokens,
|
||||
"output_tokens": ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
"""OpenAI chat completion response builder."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
GenContext,
|
||||
ResponseBuilder,
|
||||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UNSUPPORTED_PARAMS = (
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
)
|
||||
|
||||
|
||||
class OpenAIResponseBuilder(ResponseBuilder):
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
self._model = request.model
|
||||
|
||||
for param in _UNSUPPORTED_PARAMS:
|
||||
value = getattr(request, param, None)
|
||||
fields = getattr(type(request), "model_fields", {})
|
||||
default = fields[param].default if param in fields else None
|
||||
if value is not None and value != default:
|
||||
logger.warning(
|
||||
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||
param,
|
||||
value,
|
||||
)
|
||||
if value is not None and value != default:
|
||||
logger.warning(
|
||||
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||
param,
|
||||
value,
|
||||
)
|
||||
|
||||
ctx = GenContext(
|
||||
resp_id=self._resp_id,
|
||||
created=int(time.time()),
|
||||
model=self._model,
|
||||
prompt_tokens=0,
|
||||
)
|
||||
stop = request.stop
|
||||
stop_sequences = (
|
||||
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||
)
|
||||
return prompt, ctx, stop_sequences
|
||||
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str) -> str:
|
||||
return sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 0,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
),
|
||||
sse_event(
|
||||
{
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
||||
|
||||
ProtocolHandler orchestrates the async generation loop and delegates
|
||||
protocol-specific formatting to a ResponseBuilder.
|
||||
"""
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
lines: List[str] = []
|
||||
if event:
|
||||
lines.append(f"event: {event}")
|
||||
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
|
||||
lines.append("")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def sse_done() -> str:
|
||||
return "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenContext:
|
||||
"""Per-generation metadata passed to builder format methods."""
|
||||
|
||||
resp_id: str
|
||||
created: int
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopInfo:
|
||||
"""Stop-check result passed to format_stream_end / format_response."""
|
||||
|
||||
matched: Optional[str] = None
|
||||
body: str = ""
|
||||
yielded: str = ""
|
||||
|
||||
|
||||
class StopChecker:
|
||||
"""Scans accumulated text for stop sequence matches."""
|
||||
|
||||
def __init__(self, sequences: List[str]):
|
||||
self._sequences = [s for s in sequences if s]
|
||||
|
||||
def check(self, text: str) -> Optional[str]:
|
||||
for seq in self._sequences:
|
||||
if seq in text:
|
||||
return seq
|
||||
return None
|
||||
|
||||
|
||||
class ResponseBuilder(ABC):
|
||||
"""Interface for protocol-specific response formatting.
|
||||
|
||||
A new protocol requires one concrete builder implementing 5 methods.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
"""SSE events that open the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_chunk(self, token: str) -> str:
|
||||
"""SSE event for a single generated token."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
"""SSE events that close the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
"""JSON response body for non-streaming mode."""
|
||||
|
||||
|
||||
class ProtocolHandler:
|
||||
"""Orchestrates the generation loop, delegates formatting to a builder.
|
||||
|
||||
Usage::
|
||||
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
response = await handler.handle()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
||||
):
|
||||
self.request = request
|
||||
self.engine = engine
|
||||
self.builder = builder
|
||||
|
||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
||||
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
||||
|
||||
agen = self.engine.generate_async(
|
||||
prompt=prompt,
|
||||
max_tokens=self.request.max_tokens,
|
||||
temperature=self.request.temperature,
|
||||
top_p=self.request.top_p,
|
||||
top_k=self.request.top_k,
|
||||
)
|
||||
|
||||
if self.request.stream:
|
||||
return self._handle_stream(agen, ctx, stop_sequences)
|
||||
else:
|
||||
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
||||
|
||||
def _handle_stream(
|
||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||
) -> StreamingResponse:
|
||||
checker = StopChecker(stop_sequences)
|
||||
|
||||
async def event_stream():
|
||||
for event in self.builder.format_stream_start(ctx):
|
||||
yield event
|
||||
|
||||
body = ""
|
||||
yielded = ""
|
||||
matched = None
|
||||
async for token in agen:
|
||||
body += token
|
||||
|
||||
matched = checker.check(body)
|
||||
if matched:
|
||||
break
|
||||
|
||||
ctx.completion_tokens += 1
|
||||
yield self.builder.format_chunk(token)
|
||||
yielded += token
|
||||
|
||||
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||
for event in self.builder.format_stream_end(ctx, stop):
|
||||
yield event
|
||||
yield sse_done()
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
async def _handle_non_stream(
|
||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
checker = StopChecker(stop_sequences)
|
||||
chunks: List[str] = []
|
||||
body = ""
|
||||
matched = None
|
||||
|
||||
async for token in agen:
|
||||
chunks.append(token)
|
||||
body += token
|
||||
|
||||
matched = checker.check(body)
|
||||
if matched:
|
||||
break
|
||||
|
||||
ctx.completion_tokens += 1
|
||||
|
||||
content = "".join(chunks)
|
||||
stop = StopInfo(matched=matched, body=body)
|
||||
return self.builder.format_response(ctx, content, stop)
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
|
||||
|
||||
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.api.protocol import ProtocolHandler
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_project_root = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""OpenAI Chat Completion API request body."""
|
||||
|
||||
model: str = "astrai"
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default=50, ge=1)
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
||||
n: Optional[int] = Field(default=1, ge=1)
|
||||
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||
logit_bias: Optional[Dict[int, float]] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
role: str
|
||||
content: Union[str, List[Dict[str, Any]]]
|
||||
|
||||
|
||||
class MessagesRequest(BaseModel):
|
||||
"""Anthropic Messages API request body."""
|
||||
|
||||
model: str = "astrai"
|
||||
max_tokens: int = Field(default=1024, ge=1)
|
||||
messages: List[AnthropicMessage]
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default=50, ge=1)
|
||||
stream: Optional[bool] = False
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
config = app.state.server_config
|
||||
if not config.get("_test", False):
|
||||
try:
|
||||
app.state.engine = _create_engine(**config)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
if app.state.engine:
|
||||
app.state.engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def _create_engine(
|
||||
param_path: Optional[Path] = None,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
max_batch_size: int = 16,
|
||||
) -> InferenceEngine:
|
||||
if param_path is None:
|
||||
param_path = _project_root / "params"
|
||||
if not param_path.exists():
|
||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
model = AutoModel.from_pretrained(param_path)
|
||||
model.to(device=device, dtype=dtype)
|
||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||
|
||||
engine = InferenceEngine(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||
return engine
|
||||
|
||||
|
||||
def _get_engine() -> InferenceEngine:
|
||||
engine = app.state.engine
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return engine
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": app.state.engine is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
return _get_engine().get_stats()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
engine = _get_engine()
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
@app.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest):
|
||||
engine = _get_engine()
|
||||
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
def run_server(
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
reload: bool = False,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
app.state.server_config = {
|
||||
"device": device,
|
||||
"dtype": dtype,
|
||||
"param_path": param_path,
|
||||
"max_batch_size": max_batch_size,
|
||||
}
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
reload=reload,
|
||||
)
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
"""Inference core: cache, executor, scheduler, task management."""
|
||||
|
||||
from astrai.inference.core.cache import (
|
||||
Allocator,
|
||||
KVCache,
|
||||
KvcacheView,
|
||||
PagePool,
|
||||
PrefixCache,
|
||||
Storage,
|
||||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
from astrai.inference.core.executor import Executor
|
||||
from astrai.inference.core.scheduler import InferenceScheduler
|
||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||
|
||||
__all__ = [
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
"PagePool",
|
||||
"PrefixCache",
|
||||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
"Executor",
|
||||
"InferenceScheduler",
|
||||
"STOP",
|
||||
"Task",
|
||||
"TaskManager",
|
||||
"TaskStatus",
|
||||
]
|
||||
|
|
@ -0,0 +1,368 @@
|
|||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||
start = page_idx * page_size
|
||||
end = min(start + page_size, len(token_ids))
|
||||
h = 0
|
||||
for i in range(start, end):
|
||||
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
|
||||
return h
|
||||
|
||||
|
||||
class Allocator:
|
||||
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
|
||||
|
||||
def __init__(self, n_pages: int):
|
||||
self._free_mask = (1 << n_pages) - 1
|
||||
self._refs: List[int] = [0] * n_pages
|
||||
self._lru: OrderedDict[int, None] = OrderedDict()
|
||||
self.on_evict: Optional[Callable[[int], None]] = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def alloc(self) -> int:
|
||||
with self._lock:
|
||||
if self._free_mask:
|
||||
lsb = self._free_mask & -self._free_mask
|
||||
idx = lsb.bit_length() - 1
|
||||
self._free_mask ^= lsb
|
||||
self._refs[idx] = 1
|
||||
return idx
|
||||
if self._lru:
|
||||
idx, _ = self._lru.popitem(last=False)
|
||||
if self.on_evict:
|
||||
self.on_evict(idx)
|
||||
self._refs[idx] = 1
|
||||
self._free_mask &= ~(1 << idx)
|
||||
return idx
|
||||
return -1
|
||||
|
||||
def free(self, idx: int, keep_cached: bool = False):
|
||||
with self._lock:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
if keep_cached:
|
||||
self._lru[idx] = None
|
||||
else:
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def inc_ref(self, idx: int):
|
||||
with self._lock:
|
||||
self._refs[idx] += 1
|
||||
self._lru.pop(idx, None)
|
||||
|
||||
def ref_count(self, idx: int) -> int:
|
||||
with self._lock:
|
||||
return self._refs[idx]
|
||||
|
||||
def touch(self, idx: int):
|
||||
with self._lock:
|
||||
self._lru.move_to_end(idx)
|
||||
|
||||
|
||||
class PrefixCache:
|
||||
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
||||
|
||||
def __init__(self, page_size: int):
|
||||
self._page_size = page_size
|
||||
self._page_to_hash: Dict[int, int] = {}
|
||||
self._hash_to_page: Dict[int, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def evict(self, idx: int):
|
||||
with self._lock:
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
self._hash_to_page.pop(h, None)
|
||||
|
||||
def has_page(self, idx: int) -> bool:
|
||||
with self._lock:
|
||||
return idx in self._page_to_hash
|
||||
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
with self._lock:
|
||||
full_pages = len(token_ids) // self._page_size
|
||||
hits: List[int] = []
|
||||
for i in range(full_pages):
|
||||
h = page_hash(token_ids, i, self._page_size)
|
||||
p = self._hash_to_page.get(h)
|
||||
if p is None:
|
||||
break
|
||||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||
with self._lock:
|
||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
if old_h is not None:
|
||||
self._hash_to_page.pop(old_h, None)
|
||||
self._page_to_hash[page_idx] = h
|
||||
self._hash_to_page[h] = page_idx
|
||||
|
||||
|
||||
class PagePool:
|
||||
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
|
||||
|
||||
def __init__(self, allocator: Allocator, prefix: PrefixCache):
|
||||
self._alloc = allocator
|
||||
self._prefix = prefix
|
||||
self._alloc.on_evict = prefix.evict
|
||||
|
||||
@property
|
||||
def allocator(self) -> Allocator:
|
||||
return self._alloc
|
||||
|
||||
@property
|
||||
def prefix(self) -> PrefixCache:
|
||||
return self._prefix
|
||||
|
||||
def alloc(self) -> int:
|
||||
return self._alloc.alloc()
|
||||
|
||||
def free(self, idx: int):
|
||||
keep = self._prefix.has_page(idx)
|
||||
self._alloc.free(idx, keep_cached=keep)
|
||||
if not keep:
|
||||
self._prefix.evict(idx)
|
||||
|
||||
def inc_ref(self, idx: int):
|
||||
self._alloc.inc_ref(idx)
|
||||
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
hits = self._prefix.lookup(token_ids)
|
||||
for p in hits:
|
||||
self._alloc.touch(p)
|
||||
return hits
|
||||
|
||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||
|
||||
|
||||
class TaskTable:
|
||||
"""Maps task_ids to page tables and cached token counts."""
|
||||
|
||||
def __init__(self, page_size: int):
|
||||
self._page_size = page_size
|
||||
self._pages: Dict[str, List[int]] = {}
|
||||
self._cached: Dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set(self, task_id: str, page_table: List[int], cached: int):
|
||||
with self._lock:
|
||||
self._pages[task_id] = page_table
|
||||
self._cached[task_id] = cached
|
||||
|
||||
def get(self, task_id: str) -> List[int]:
|
||||
with self._lock:
|
||||
return self._pages.get(task_id, [])
|
||||
|
||||
def get_cached(self, task_id: str) -> int:
|
||||
with self._lock:
|
||||
return self._cached.get(task_id, 0)
|
||||
|
||||
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
||||
with self._lock:
|
||||
pages = self._pages.pop(task_id, [])
|
||||
cached = self._cached.pop(task_id, 0)
|
||||
return pages, cached
|
||||
|
||||
def get_ref(self, task_id: str) -> List[int]:
|
||||
with self._lock:
|
||||
return self._pages.setdefault(task_id, [])
|
||||
|
||||
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
with self._lock:
|
||||
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||
max_pages = max((len(s) for s in states), default=0)
|
||||
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
class Storage:
|
||||
"""KV-cache tensor storage with paged write/gather."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
n_pages: int,
|
||||
page_size: int,
|
||||
n_kv_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self.k_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.v_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
layer_id: int,
|
||||
page_table: Tensor,
|
||||
start_pos: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
):
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
page_size = self.page_size
|
||||
written = 0
|
||||
first_page = start_pos // page_size
|
||||
last_page = (start_pos + seq_len - 1) // page_size
|
||||
for pi in range(first_page, last_page + 1):
|
||||
phys_pages = page_table[:, pi]
|
||||
page_start = pi * page_size
|
||||
write_start = max(page_start, start_pos)
|
||||
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||
offset = write_start - page_start
|
||||
chunk = write_end - write_start
|
||||
valid = phys_pages >= 0
|
||||
if not valid.all():
|
||||
if valid.any():
|
||||
valid_pages = phys_pages[valid]
|
||||
self.k_cache[layer_id, valid_pages, offset : offset + chunk] = k[
|
||||
valid, written : written + chunk
|
||||
]
|
||||
self.v_cache[layer_id, valid_pages, offset : offset + chunk] = v[
|
||||
valid, written : written + chunk
|
||||
]
|
||||
written += chunk
|
||||
continue
|
||||
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||
:, written : written + chunk
|
||||
]
|
||||
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
|
||||
:, written : written + chunk
|
||||
]
|
||||
written += chunk
|
||||
|
||||
def gather(
|
||||
self, layer_id: int, page_table: Tensor, total_len: int
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
safe = page_table.clamp(min=0)
|
||||
k = self.k_cache[layer_id, safe]
|
||||
v = self.v_cache[layer_id, safe]
|
||||
k = k.flatten(1, 2)
|
||||
v = v.flatten(1, 2)
|
||||
if (page_table < 0).any():
|
||||
invalid = (
|
||||
(page_table < 0)
|
||||
.unsqueeze(-1)
|
||||
.expand(-1, -1, self.page_size)
|
||||
.flatten(1, 2)
|
||||
)
|
||||
invalid = invalid[:, :, None, None].expand_as(k)
|
||||
k = k.masked_fill(invalid, 0.0)
|
||||
v = v.masked_fill(invalid, 0.0)
|
||||
k = k[:, :total_len]
|
||||
v = v[:, :total_len]
|
||||
return k, v
|
||||
|
||||
|
||||
class KvcacheView:
|
||||
"""Bundles Storage + page_table + total_len for attention layers."""
|
||||
|
||||
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
|
||||
self._storage = storage
|
||||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor):
|
||||
start_pos = self._total_len - k.size(1)
|
||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
return self._storage.gather(layer_id, self._page_table, self._total_len)
|
||||
|
||||
|
||||
class KVCache:
|
||||
"""Facade: page management + KV-cache I/O for continuous batching."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
n_pages: int,
|
||||
page_size: int,
|
||||
n_kv_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
|
||||
self._table = TaskTable(page_size)
|
||||
self._storage = Storage(
|
||||
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
|
||||
)
|
||||
|
||||
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||
hits = self._pool.lookup(prompt_ids)
|
||||
cached = len(hits) * self.page_size
|
||||
for p in hits:
|
||||
self._pool.inc_ref(p)
|
||||
|
||||
remaining = len(prompt_ids) - cached
|
||||
n_new = (
|
||||
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||
)
|
||||
new_pages: List[int] = []
|
||||
if n_new > 0:
|
||||
for _ in range(n_new):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for hp in hits:
|
||||
self._pool.free(hp)
|
||||
for np in new_pages:
|
||||
self._pool.free(np)
|
||||
return False
|
||||
new_pages.append(p)
|
||||
|
||||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str):
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self._pool.free(idx)
|
||||
|
||||
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||
page_table = self._table.get(task_id)
|
||||
needed = (pos + 1 + self.page_size - 1) // self.page_size
|
||||
while len(page_table) < needed:
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
return False
|
||||
page_table.append(p)
|
||||
return True
|
||||
|
||||
def task_cached(self, task_id: str) -> int:
|
||||
return self._table.get_cached(task_id)
|
||||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
):
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
self._pool.record(page_table[i], prompt_ids, i)
|
||||
|
||||
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
return self._table.table_tensor(task_ids, device)
|
||||
|
||||
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
|
||||
return KvcacheView(self._storage, page_table, total_len)
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.inference.core.cache import KVCache
|
||||
from astrai.inference.core.task import Task
|
||||
from astrai.inference.sample import sample
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Executor:
|
||||
"""Model forward passes for prefill and decode phases."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
page_cache: KVCache,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.page_cache = page_cache
|
||||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
|
||||
if start_pos >= prompt_len:
|
||||
return
|
||||
|
||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||
batch_sz = len(tasks)
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[t.prompt_ids[start_pos:prompt_len] for t in tasks],
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
task_ids = [t.task_id for t in tasks]
|
||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
position_ids=torch.arange(
|
||||
start_pos, prompt_len, dtype=torch.long, device=self.device
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_sz, -1),
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||
)
|
||||
|
||||
def execute_decode(self, tasks: List[Task]) -> List[int]:
|
||||
if not tasks:
|
||||
return []
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
position_ids = torch.tensor(
|
||||
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
|
||||
)
|
||||
total_len = position_ids.max().item() + 1
|
||||
|
||||
task_ids = [t.task_id for t in tasks]
|
||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||
|
||||
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
||||
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
|
||||
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(
|
||||
input_ids.unsqueeze(1),
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||
position_ids=position_ids.unsqueeze(1),
|
||||
)
|
||||
logits = outputs["logits"][:, -1, :]
|
||||
|
||||
return sample(
|
||||
logits,
|
||||
temperature=temperatures,
|
||||
top_k=top_ks,
|
||||
top_p=top_ps,
|
||||
).tolist()
|
||||
|
|
@ -0,0 +1,212 @@
|
|||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.inference.core.cache import KVCache
|
||||
from astrai.inference.core.executor import Executor
|
||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InferenceScheduler:
|
||||
"""Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
max_batch_size: int = 16,
|
||||
max_seq_len: Optional[int] = None,
|
||||
max_prompt_len: int = 2048,
|
||||
page_size: int = 64,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
config = model.config
|
||||
|
||||
if max_seq_len is not None:
|
||||
self.max_seq_len = max_seq_len
|
||||
elif config.max_len is not None:
|
||||
self.max_seq_len = config.max_len
|
||||
else:
|
||||
raise ValueError(
|
||||
"max_seq_len must be provided either as argument "
|
||||
"or in model config (config.max_len)"
|
||||
)
|
||||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
n_pages = (
|
||||
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||
) // page_size
|
||||
|
||||
self._page_cache = KVCache(
|
||||
config.n_layers,
|
||||
n_pages,
|
||||
page_size,
|
||||
config.n_kv_heads,
|
||||
config.dim // config.n_heads,
|
||||
self.device,
|
||||
self.dtype,
|
||||
)
|
||||
|
||||
self._task_mgr = TaskManager(
|
||||
tokenizer=tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=self.max_seq_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
)
|
||||
|
||||
self._executor = Executor(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
page_cache=self._page_cache,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self._running = False
|
||||
self._fatal_error: Optional[Exception] = None
|
||||
|
||||
def add_task(self, prompt: str, **kwargs) -> str:
|
||||
return self._task_mgr.add_task(prompt, **kwargs)
|
||||
|
||||
def remove_task(self, task_id: str):
|
||||
for task in self._task_mgr.remove_task(task_id):
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self._task_mgr.get_stats()
|
||||
|
||||
def _run_generation_loop(self):
|
||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||
try:
|
||||
while self._running:
|
||||
finished = self._task_mgr.remove_finished_tasks(stop_ids)
|
||||
for task in finished:
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
||||
active = self._task_mgr.get_active_tasks()
|
||||
available = self._task_mgr.max_batch_size - len(active)
|
||||
if available > 0:
|
||||
candidates = self._task_mgr.pull_candidates(available)
|
||||
failed = []
|
||||
for task in candidates:
|
||||
if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
|
||||
self._task_mgr.activate(task)
|
||||
else:
|
||||
failed.append(task)
|
||||
if failed:
|
||||
self._task_mgr.return_to_waiting(failed)
|
||||
|
||||
if not self._task_mgr.has_work():
|
||||
self._task_mgr.wait_for_tasks(timeout=1.0)
|
||||
continue
|
||||
|
||||
to_prefill = [
|
||||
t
|
||||
for t in self._task_mgr.get_active_tasks()
|
||||
if t.output_tokens == 0
|
||||
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
||||
]
|
||||
if to_prefill:
|
||||
for t in to_prefill:
|
||||
t.input_tokens = len(t.prompt_ids)
|
||||
|
||||
groups: Dict[Tuple[int, int], List[Task]] = {}
|
||||
for t in to_prefill:
|
||||
key = (
|
||||
len(t.prompt_ids),
|
||||
self._page_cache.task_cached(t.task_id),
|
||||
)
|
||||
groups.setdefault(key, []).append(t)
|
||||
|
||||
for (prompt_len, start_pos), group in groups.items():
|
||||
self._executor.execute_prefill(group, prompt_len, start_pos)
|
||||
start_logical_page = start_pos // self._page_cache.page_size
|
||||
for t in group:
|
||||
self._page_cache.task_record_hashes(
|
||||
t.task_id,
|
||||
t.prompt_ids,
|
||||
start_logical_page=start_logical_page,
|
||||
)
|
||||
|
||||
pos_groups: Dict[int, List[Task]] = {}
|
||||
for t in self._task_mgr.get_active_tasks():
|
||||
pos_groups.setdefault(t.next_pos, []).append(t)
|
||||
|
||||
if pos_groups:
|
||||
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
|
||||
group = sorted(pos_groups[best_key], key=lambda t: t.task_id)
|
||||
|
||||
valid: List[Task] = []
|
||||
for t in group:
|
||||
if self._page_cache.task_extend(t.task_id, t.next_pos):
|
||||
valid.append(t)
|
||||
else:
|
||||
t.status = TaskStatus.ABORTED
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
||||
if valid:
|
||||
next_tokens = self._executor.execute_decode(valid)
|
||||
|
||||
for t, ntok in zip(valid, next_tokens):
|
||||
t.output_ids.append(ntok)
|
||||
t.output_tokens += 1
|
||||
pos = t.input_tokens + t.output_tokens
|
||||
extend_ok = self._page_cache.task_extend(t.task_id, pos)
|
||||
if t.stream_callback:
|
||||
t.stream_callback(
|
||||
self._task_mgr.tokenizer.decode([ntok])
|
||||
)
|
||||
if not extend_ok:
|
||||
t.status = TaskStatus.ABORTED
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
||||
for t in valid:
|
||||
if t.is_finished(stop_ids):
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
||||
except Exception as e:
|
||||
self._fatal_error = e
|
||||
self._running = False
|
||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
for task in self._task_mgr.get_waiting_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._task_mgr.clear_queues()
|
||||
|
||||
def start(self):
|
||||
if not self._running:
|
||||
self._running = True
|
||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||
t.start()
|
||||
self._loop_thread = t
|
||||
|
||||
def stop(self):
|
||||
self._running = False
|
||||
self._task_mgr.wake()
|
||||
if hasattr(self, "_loop_thread"):
|
||||
self._loop_thread.join(timeout=2.0)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
for task in self._task_mgr.get_waiting_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._task_mgr.clear_queues()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
@ -0,0 +1,209 @@
|
|||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Deque, Dict, List, Optional
|
||||
|
||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STOP = object()
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Task lifecycle states."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
FINISHED = "finished"
|
||||
ABORTED = "aborted"
|
||||
|
||||
|
||||
class Task:
|
||||
"""Single generation request: prompt, sampling params, output state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_id: str,
|
||||
prompt_ids: List[int],
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
):
|
||||
self.task_id = task_id
|
||||
self.prompt_ids = prompt_ids
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
|
||||
self.status = TaskStatus.PENDING
|
||||
self.output_ids: List[int] = []
|
||||
self.input_tokens: int = 0
|
||||
self.output_tokens: int = 0
|
||||
self.arrival_time = time.time()
|
||||
self.finish_time: Optional[float] = None
|
||||
self.stream_callback = stream_callback
|
||||
|
||||
@property
|
||||
def next_pos(self) -> int:
|
||||
return self.input_tokens + len(self.output_ids)
|
||||
|
||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
|
||||
return True
|
||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
max_batch_size: int = 16,
|
||||
max_seq_len: int = 8192,
|
||||
max_prompt_len: int = 512,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_prompt_len = max_prompt_len
|
||||
|
||||
self.waiting_queue: Deque[Task] = deque()
|
||||
self.active_tasks: List[Task] = []
|
||||
|
||||
self._task_event = threading.Event()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._total_tasks = 0
|
||||
self._total_tokens = 0
|
||||
|
||||
def add_task(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> str:
|
||||
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
prompt_ids = self.tokenizer.encode(prompt)
|
||||
if len(prompt_ids) > self.max_prompt_len:
|
||||
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
||||
|
||||
if len(prompt_ids) >= self.max_seq_len:
|
||||
if stream_callback:
|
||||
stream_callback(STOP)
|
||||
return task_id
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = self.max_seq_len - len(prompt_ids)
|
||||
else:
|
||||
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
||||
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
prompt_ids=prompt_ids,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=stream_callback,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self.waiting_queue.append(task)
|
||||
self._total_tasks += 1
|
||||
|
||||
self._task_event.set()
|
||||
return task_id
|
||||
|
||||
def remove_task(self, task_id: str) -> List[Task]:
|
||||
with self._lock:
|
||||
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
||||
self.waiting_queue = deque(
|
||||
t for t in self.waiting_queue if t.task_id != task_id
|
||||
)
|
||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||
return removed_active
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"total_tasks": self._total_tasks,
|
||||
"total_tokens": self._total_tokens,
|
||||
"active_tasks": len(self.active_tasks),
|
||||
"waiting_queue": len(self.waiting_queue),
|
||||
}
|
||||
|
||||
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
|
||||
with self._lock:
|
||||
finished = []
|
||||
for task in self.active_tasks:
|
||||
if task.status == TaskStatus.ABORTED:
|
||||
task.finish_time = time.time()
|
||||
finished.append(task)
|
||||
elif task.is_finished(stop_ids):
|
||||
task.status = TaskStatus.FINISHED
|
||||
task.finish_time = time.time()
|
||||
finished.append(task)
|
||||
self._total_tokens += task.output_tokens
|
||||
|
||||
self.active_tasks = [
|
||||
t
|
||||
for t in self.active_tasks
|
||||
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
|
||||
]
|
||||
return finished
|
||||
|
||||
def pull_candidates(self, n: int) -> List[Task]:
|
||||
to_add: List[Task] = []
|
||||
with self._lock:
|
||||
take = min(n, len(self.waiting_queue))
|
||||
for _ in range(take):
|
||||
to_add.append(self.waiting_queue.popleft())
|
||||
return to_add
|
||||
|
||||
def activate(self, task: Task):
|
||||
task.status = TaskStatus.RUNNING
|
||||
with self._lock:
|
||||
self.active_tasks.append(task)
|
||||
|
||||
def return_to_waiting(self, tasks: List[Task]):
|
||||
with self._lock:
|
||||
for task in reversed(tasks):
|
||||
self.waiting_queue.appendleft(task)
|
||||
|
||||
def has_work(self) -> bool:
|
||||
return bool(self.active_tasks or self.waiting_queue)
|
||||
|
||||
def wait_for_tasks(self, timeout: float = 1.0):
|
||||
with self._lock:
|
||||
if self.waiting_queue or self.active_tasks:
|
||||
return
|
||||
self._task_event.clear()
|
||||
self._task_event.wait(timeout=timeout)
|
||||
|
||||
def get_active_tasks(self) -> List[Task]:
|
||||
with self._lock:
|
||||
return list(self.active_tasks)
|
||||
|
||||
def get_waiting_tasks(self) -> List[Task]:
|
||||
with self._lock:
|
||||
return list(self.waiting_queue)
|
||||
|
||||
def clear_queues(self):
|
||||
with self._lock:
|
||||
self.waiting_queue.clear()
|
||||
self.active_tasks.clear()
|
||||
|
||||
def wake(self):
|
||||
self._task_event.set()
|
||||
|
|
@ -0,0 +1,288 @@
|
|||
"""Unified inference engine for continuous batching."""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import threading
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.inference.core.scheduler import InferenceScheduler
|
||||
from astrai.inference.core.task import STOP
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
class GenerateResult:
|
||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||
|
||||
def __init__(self, count: int = 1):
|
||||
self._cond = threading.Condition()
|
||||
self._event = threading.Event()
|
||||
self.tokens: List[Tuple[int, str]] = []
|
||||
self.results: List[str] = [""] * count
|
||||
self._done: List[bool] = [False] * count
|
||||
self._completed = 0
|
||||
self._total = count
|
||||
|
||||
def append(self, token: str, idx: int = 0):
|
||||
with self._cond:
|
||||
self.tokens.append((idx, token))
|
||||
if token is not STOP:
|
||||
self.results[idx] += token
|
||||
else:
|
||||
if not self._done[idx]:
|
||||
self._done[idx] = True
|
||||
self._completed += 1
|
||||
self._cond.notify_all()
|
||||
self._event.set()
|
||||
|
||||
def pop_all(self) -> List[Tuple[int, str]]:
|
||||
with self._cond:
|
||||
out = self.tokens.copy()
|
||||
self.tokens.clear()
|
||||
if not out:
|
||||
self._event.clear()
|
||||
return out
|
||||
|
||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
def wait_completion(self, timeout: float = 300.0):
|
||||
with self._cond:
|
||||
if not self._cond.wait_for(
|
||||
lambda: self._completed >= self._total, timeout=timeout
|
||||
):
|
||||
raise TimeoutError(
|
||||
f"Generation timeout after {timeout}s "
|
||||
f"({self._completed}/{self._total} completed)"
|
||||
)
|
||||
|
||||
def get_results(self) -> List[str]:
|
||||
with self._cond:
|
||||
return self.results.copy()
|
||||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
top_k: int = 50,
|
||||
top_p: float = 1.0,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature > 0):
|
||||
raise ValueError("temperature must be a positive number")
|
||||
|
||||
self.messages = messages
|
||||
self.top_k = top_k
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""Unified inference engine backed by continuous-batching scheduler."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
tokenizer: AutoTokenizer,
|
||||
max_batch_size: int = 1,
|
||||
max_seq_len: Optional[int] = None,
|
||||
max_prompt_len: int = 2048,
|
||||
page_size: int = 128,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.scheduler = InferenceScheduler(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
self.scheduler.start()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.shutdown()
|
||||
return False
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
stream: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> Union[Generator, str, List[str]]:
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
if stream:
|
||||
return self._generate_streaming(
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
else:
|
||||
return self._generate_non_streaming(
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
||||
def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
sync_gen = self._generate_streaming(
|
||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
||||
async def _agen():
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
token = await loop.run_in_executor(None, self._next_token, sync_gen)
|
||||
if token is None:
|
||||
break
|
||||
yield token
|
||||
|
||||
return _agen()
|
||||
|
||||
@staticmethod
|
||||
def _next_token(gen: Generator) -> Optional[str]:
|
||||
try:
|
||||
return next(gen)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def generate_with_request(
|
||||
self, request: GenerationRequest
|
||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||
return self.generate(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
)
|
||||
|
||||
def _submit_tasks(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: Optional[int],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Tuple[GenerateResult, List[str]]:
|
||||
n = len(prompts)
|
||||
result = GenerateResult(count=n)
|
||||
task_ids = []
|
||||
for i, p in enumerate(prompts):
|
||||
cb = self._make_callback(result, i)
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=p,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=cb,
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
return result, task_ids
|
||||
|
||||
@staticmethod
|
||||
def _make_callback(result: GenerateResult, idx: int):
|
||||
def cb(token):
|
||||
result.append(token, idx)
|
||||
|
||||
return cb
|
||||
|
||||
def _generate_streaming(
|
||||
self,
|
||||
prompts: List[str],
|
||||
is_batch: bool,
|
||||
max_tokens: Optional[int],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Generator:
|
||||
result, task_ids = self._submit_tasks(
|
||||
prompts, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
n = len(prompts)
|
||||
remaining = n
|
||||
finished = [False] * n
|
||||
|
||||
def gen():
|
||||
nonlocal remaining
|
||||
try:
|
||||
while remaining > 0:
|
||||
items = result.pop_all()
|
||||
for idx, token in items:
|
||||
if token is STOP:
|
||||
if not finished[idx]:
|
||||
finished[idx] = True
|
||||
remaining -= 1
|
||||
else:
|
||||
yield (idx, token) if is_batch else token
|
||||
if remaining > 0:
|
||||
result.wait(timeout=0.05)
|
||||
finally:
|
||||
for tid in task_ids:
|
||||
self.scheduler.remove_task(tid)
|
||||
|
||||
return gen()
|
||||
|
||||
def _generate_non_streaming(
|
||||
self,
|
||||
prompts: List[str],
|
||||
is_batch: bool,
|
||||
max_tokens: Optional[int],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Union[str, List[str]]:
|
||||
result, task_ids = self._submit_tasks(
|
||||
prompts, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
||||
try:
|
||||
result.wait_completion()
|
||||
except TimeoutError:
|
||||
for tid in task_ids:
|
||||
self.scheduler.remove_task(tid)
|
||||
raise
|
||||
|
||||
for tid in task_ids:
|
||||
self.scheduler.remove_task(tid)
|
||||
|
||||
res = result.get_results()
|
||||
return res if is_batch else res[0]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.scheduler.get_stats()
|
||||
|
||||
def shutdown(self):
|
||||
self.scheduler.stop()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
|
@ -0,0 +1,190 @@
|
|||
"""Composable sampling strategies for logit transformation.
|
||||
|
||||
Implements the Strategy pattern: each sampling technique
|
||||
(temperature, top-k, top-p) is a pluggable strategy that
|
||||
can be composed into a pipeline.
|
||||
|
||||
All strategies accept both scalar and per-sample tensor
|
||||
parameters, so a single pipeline works for any batch size.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class BaseSamplingStrategy(ABC):
|
||||
"""Abstract base for a logit transformation strategy."""
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||
"""Applies the strategy to logits.
|
||||
|
||||
Args:
|
||||
logits: Raw logits tensor (batch, vocab_size).
|
||||
filter_value: Value assigned to filtered-out positions.
|
||||
|
||||
Returns:
|
||||
Transformed logits tensor.
|
||||
"""
|
||||
|
||||
|
||||
class TemperatureStrategy(BaseSamplingStrategy):
|
||||
"""Divides logits by temperature to control randomness.
|
||||
|
||||
Args:
|
||||
temperature: Scalar or ``[batch]`` tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: Union[float, Tensor] = 1.0):
|
||||
self.temperature = temperature
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
t = self.temperature
|
||||
if isinstance(t, Tensor):
|
||||
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||
t = torch.clamp(t, min=1e-8)
|
||||
if (t != 1.0).any():
|
||||
logits = logits / t
|
||||
elif t != 1.0:
|
||||
logits = logits / max(t, 1e-8)
|
||||
return logits
|
||||
|
||||
|
||||
class TopKStrategy(BaseSamplingStrategy):
|
||||
"""Keeps only the top-k logits, setting the rest to filter_value.
|
||||
|
||||
Args:
|
||||
top_k: Scalar or ``[batch]`` tensor (0 disables).
|
||||
"""
|
||||
|
||||
def __init__(self, top_k: Union[int, Tensor] = 0):
|
||||
self.top_k = top_k
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
tk = self.top_k
|
||||
if isinstance(tk, Tensor):
|
||||
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
|
||||
max_k = int(tk.max().item())
|
||||
if max_k <= 0:
|
||||
return logits
|
||||
max_k = min(max_k, logits.size(-1))
|
||||
values, _ = torch.topk(logits, max_k, dim=-1)
|
||||
per_row_k = tk.clamp(max=max_k)
|
||||
thresholds = torch.full_like(logits[..., -1:], -float("inf"))
|
||||
positive = per_row_k > 0
|
||||
if positive.any():
|
||||
row_idx = torch.arange(logits.size(0), device=logits.device)[positive]
|
||||
thresholds[positive] = values[
|
||||
row_idx, per_row_k[positive] - 1
|
||||
].unsqueeze(-1)
|
||||
logits[logits < thresholds] = filter_value
|
||||
return logits
|
||||
if tk > 0:
|
||||
k = min(tk, logits.size(-1))
|
||||
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
||||
logits[logits < thresholds] = filter_value
|
||||
return logits
|
||||
|
||||
|
||||
class TopPStrategy(BaseSamplingStrategy):
|
||||
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
|
||||
cumulative probability exceeds top_p.
|
||||
|
||||
Args:
|
||||
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
|
||||
"""
|
||||
|
||||
def __init__(self, top_p: Union[float, Tensor] = 1.0):
|
||||
self.top_p = top_p
|
||||
|
||||
def _apply(self, logits, top_p, filter_value):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
remove = cum_probs > top_p
|
||||
remove[..., 1:] = remove[..., :-1].clone()
|
||||
remove[..., 0] = False
|
||||
mask = torch.zeros_like(logits, dtype=torch.bool)
|
||||
mask.scatter_(1, sorted_indices, remove)
|
||||
logits[mask] = filter_value
|
||||
return logits
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
tp = self.top_p
|
||||
if isinstance(tp, Tensor):
|
||||
tp = tp.to(logits.device, non_blocking=True)
|
||||
if (tp < 1.0).any():
|
||||
logits = self._apply(logits, tp.view(-1, 1), filter_value)
|
||||
elif tp < 1.0:
|
||||
logits = self._apply(logits, tp, filter_value)
|
||||
return logits
|
||||
|
||||
|
||||
class SamplingPipeline(BaseSamplingStrategy):
|
||||
"""Composes multiple sampling strategies into a single transformation.
|
||||
|
||||
Strategies are applied sequentially in the order they are provided,
|
||||
matching the original temperature -> top-k -> top-p ordering.
|
||||
|
||||
Usage::
|
||||
|
||||
pipeline = SamplingPipeline([
|
||||
TemperatureStrategy(0.8),
|
||||
TopKStrategy(50),
|
||||
TopPStrategy(0.95),
|
||||
])
|
||||
logits = pipeline.apply(logits)
|
||||
token = pipeline.sample(logits) # softmax + multinomial
|
||||
"""
|
||||
|
||||
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
||||
self.strategies = strategies
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
for strategy in self.strategies:
|
||||
logits = strategy.apply(logits, filter_value)
|
||||
return logits
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||
"""Apply strategies then sample (softmax + multinomial).
|
||||
|
||||
Args:
|
||||
logits: Raw logits ``[batch, vocab_size]``.
|
||||
|
||||
Returns:
|
||||
Sampled token IDs ``[batch]``.
|
||||
"""
|
||||
return torch.multinomial(
|
||||
torch.softmax(self.apply(logits, filter_value), dim=-1),
|
||||
num_samples=1,
|
||||
).squeeze(-1)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample(
|
||||
logits: Tensor,
|
||||
temperature: Union[float, Tensor] = 1.0,
|
||||
top_k: Union[int, Tensor] = 0,
|
||||
top_p: Union[float, Tensor] = 1.0,
|
||||
filter_value: float = -float("inf"),
|
||||
) -> Tensor:
|
||||
"""Apply sampling strategies then sample (softmax + multinomial).
|
||||
|
||||
Shortcut for ``SamplingPipeline(...).sample(logits)``.
|
||||
|
||||
Args:
|
||||
logits: Raw logits ``[batch, vocab_size]``.
|
||||
|
||||
Returns:
|
||||
Sampled token IDs ``[batch]``.
|
||||
"""
|
||||
return SamplingPipeline(
|
||||
[
|
||||
TemperatureStrategy(temperature),
|
||||
TopKStrategy(top_k),
|
||||
TopPStrategy(top_p),
|
||||
]
|
||||
).sample(logits, filter_value)
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.attention import GQA
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.lora import (
|
||||
LoRAConfig,
|
||||
inject_lora,
|
||||
load_lora,
|
||||
merge_lora,
|
||||
save_lora,
|
||||
)
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.encoder import EmbeddingEncoder
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
__all__ = [
|
||||
# Modules
|
||||
"Linear",
|
||||
"RMSNorm",
|
||||
"MLP",
|
||||
"GQA",
|
||||
"DecoderBlock",
|
||||
# Models
|
||||
"AutoRegressiveLM",
|
||||
"EmbeddingEncoder",
|
||||
"AutoModel",
|
||||
# LoRA
|
||||
"LoRAConfig",
|
||||
"inject_lora",
|
||||
"merge_lora",
|
||||
"save_lora",
|
||||
"load_lora",
|
||||
]
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
AutoModel base class for model loading and saving.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Self, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.serialization import load_model_config, load_model_weights, save_model
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _disable_random_init(enable: bool = True):
|
||||
if not enable:
|
||||
yield
|
||||
return
|
||||
|
||||
names = (
|
||||
"xavier_normal_",
|
||||
"xavier_uniform_",
|
||||
"kaiming_normal_",
|
||||
"kaiming_uniform_",
|
||||
"zeros_",
|
||||
"ones_",
|
||||
"constant_",
|
||||
"normal_",
|
||||
"uniform_",
|
||||
)
|
||||
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
||||
for n in orig:
|
||||
setattr(nn.init, n, lambda *a, **kw: None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for n, fn in orig.items():
|
||||
setattr(nn.init, n, fn)
|
||||
|
||||
|
||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||
"""
|
||||
Autoregressive language model base class.
|
||||
Provides model loading/saving, registration, and generation.
|
||||
"""
|
||||
|
||||
def __init__(self, config: BaseModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
path: Union[str, Path],
|
||||
disable_random_init: bool = True,
|
||||
strict: bool = True,
|
||||
) -> nn.Module:
|
||||
|
||||
model_path = Path(path)
|
||||
|
||||
config_path = model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
raw = load_model_config(str(model_path))
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
|
||||
actual_cls = AutoModel.get_component_class(model_type)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
model = actual_cls(config)
|
||||
|
||||
weights_path = model_path / "model.safetensors"
|
||||
if weights_path.exists():
|
||||
state_dict = load_model_weights(str(model_path))
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
return model
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, Path],
|
||||
):
|
||||
save_model(
|
||||
config=self.config.to_dict(),
|
||||
state_dict=self.state_dict(),
|
||||
save_directory=str(save_directory),
|
||||
)
|
||||
|
||||
def to(self, *args, **kwargs) -> Self:
|
||||
"""Move model to device/dtype."""
|
||||
return super().to(*args, **kwargs)
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from astrai.model.components.attention import GQA, MLA, repeat_kv
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import (
|
||||
RotaryEmbedding,
|
||||
apply_rotary_emb,
|
||||
get_rotary_emb,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Linear",
|
||||
"RMSNorm",
|
||||
"MLP",
|
||||
"Embedding",
|
||||
"GQA",
|
||||
"MLA",
|
||||
"DecoderBlock",
|
||||
"RotaryEmbedding",
|
||||
"apply_rotary_emb",
|
||||
"get_rotary_emb",
|
||||
"repeat_kv",
|
||||
]
|
||||
|
|
@ -0,0 +1,212 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import apply_rotary_emb
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
bs, slen, n_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x[:, :, :, None, :]
|
||||
.expand(bs, slen, n_heads, n_rep, head_dim)
|
||||
.reshape(bs, slen, n_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
class AttnFactory(BaseFactory[nn.Module]):
|
||||
@classmethod
|
||||
def create(cls, attn_type: str, **kwargs) -> nn.Module:
|
||||
return super().create(attn_type, **kwargs)
|
||||
|
||||
|
||||
@AttnFactory.register("gqa")
|
||||
class GQA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
use_qk_norm: bool,
|
||||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % n_heads == 0
|
||||
assert n_heads % n_kv_heads == 0
|
||||
|
||||
self.head_dim = dim // n_heads
|
||||
self.layer_id = layer_id
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.n_rep = n_heads // n_kv_heads
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.use_gated_attention = use_gated_attention
|
||||
|
||||
self.q_proj = Linear(dim, n_heads * self.head_dim)
|
||||
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||
self.o_proj = Linear(dim, dim)
|
||||
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||
|
||||
if self.use_gated_attention:
|
||||
self.gate = Linear(dim, dim)
|
||||
|
||||
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attn_mask: Tensor = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
) -> Tensor:
|
||||
is_causal = attn_mask is None
|
||||
|
||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
||||
|
||||
if self.use_qk_norm:
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
sdqa_out = (
|
||||
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
if self.use_gated_attention:
|
||||
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||
|
||||
out = self.o_proj(sdqa_out)
|
||||
return out
|
||||
|
||||
|
||||
@AttnFactory.register("mla")
|
||||
class MLA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
norm_eps: float,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.layer_id = layer_id
|
||||
self.n_rep = n_heads // n_kv_heads
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.use_gated_attention = use_gated_attention
|
||||
|
||||
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||
|
||||
self.kv_b_proj = Linear(
|
||||
kv_lora_rank,
|
||||
n_kv_heads * (2 * self.head_dim),
|
||||
)
|
||||
|
||||
self.o_proj = Linear(dim, dim, bias=False)
|
||||
|
||||
if use_gated_attention:
|
||||
self.gate = Linear(dim, dim, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attn_mask: Tensor = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = attn_mask is None
|
||||
|
||||
q = self.q_proj(x)
|
||||
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||
|
||||
kv_compressed = self.kv_a_proj(x)
|
||||
kv_compressed = self.kv_norm(kv_compressed)
|
||||
|
||||
kv = self.kv_b_proj(kv_compressed)
|
||||
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
|
||||
|
||||
k_nope, k_rope, v = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
|
||||
)
|
||||
|
||||
q_nope, q_rope = (
|
||||
q[..., : self.qk_nope_head_dim],
|
||||
q[..., self.qk_nope_head_dim :],
|
||||
)
|
||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||
|
||||
q = torch.cat([q_nope, q_rope], dim=-1)
|
||||
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||
|
||||
if self.use_qk_norm:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
attn_out = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask, is_causal=is_causal
|
||||
)
|
||||
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||
|
||||
if self.use_gated_attention:
|
||||
attn_out = attn_out * F.sigmoid(self.gate(x))
|
||||
|
||||
out = self.o_proj(attn_out)
|
||||
return out
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.components.attention import AttnFactory
|
||||
from astrai.model.components.mlp import FFNFactory
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
dim_ffn: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
attn_type: str = "gqa",
|
||||
ffn_type: str = "mlp",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = AttnFactory.create(
|
||||
attn_type,
|
||||
dim=dim,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
use_qk_norm=use_qk_norm,
|
||||
norm_eps=norm_eps,
|
||||
use_gated_attention=use_gated_attention,
|
||||
layer_id=layer_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.input_norm = RMSNorm(dim, norm_eps)
|
||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
) -> Tensor:
|
||||
attn_output = self.attention(
|
||||
self.input_norm(x),
|
||||
rotary_emb,
|
||||
attention_mask,
|
||||
paged_cache,
|
||||
)
|
||||
x = attn_output + x
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
|
||||
return x
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.normal_(self.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
|
||||
if self.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / (fan_in**0.5)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.serialization import (
|
||||
load_json,
|
||||
load_safetensors,
|
||||
save_json,
|
||||
save_safetensors,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
|
||||
TARGET_MODULES_FFN = {"up", "gate", "down"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAConfig:
|
||||
r: int = 16
|
||||
alpha: int = 32
|
||||
target_modules: tuple = ("q_proj", "v_proj")
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
|
||||
super().__init__()
|
||||
self.register_parameter("weight", base.weight)
|
||||
self.weight.requires_grad_(False)
|
||||
self.bias = base.bias
|
||||
if self.bias is not None:
|
||||
self.bias.requires_grad_(False)
|
||||
|
||||
self.r = r
|
||||
self.scaling = alpha / r
|
||||
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
|
||||
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
|
||||
self._merged = False
|
||||
|
||||
def forward(self, x):
|
||||
out = F.linear(x, self.weight, self.bias)
|
||||
if not self._merged:
|
||||
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
|
||||
return out
|
||||
|
||||
def merge(self):
|
||||
if self._merged:
|
||||
return
|
||||
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
||||
self._merged = True
|
||||
del self.lora_A
|
||||
del self.lora_B
|
||||
|
||||
|
||||
def _collect_lora_info(model: nn.Module) -> dict:
|
||||
names = {}
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, Linear):
|
||||
_, _, child = n.rpartition(".")
|
||||
names.setdefault(child, []).append(n)
|
||||
return names
|
||||
|
||||
|
||||
def _get_lora_count(model: nn.Module) -> int:
|
||||
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
|
||||
|
||||
|
||||
def inject_lora(
|
||||
model: nn.Module,
|
||||
r: int = 16,
|
||||
alpha: int = 32,
|
||||
target_modules: Optional[Set[str]] = None,
|
||||
) -> LoRAConfig:
|
||||
if target_modules is None:
|
||||
target_modules = TARGET_MODULES_ATTN
|
||||
|
||||
available = _collect_lora_info(model)
|
||||
injected = 0
|
||||
|
||||
for name, module in list(model.named_modules()):
|
||||
if not isinstance(module, Linear):
|
||||
continue
|
||||
parent_name, _, child_name = name.rpartition(".")
|
||||
if child_name not in target_modules:
|
||||
continue
|
||||
parent = model.get_submodule(parent_name) if parent_name else model
|
||||
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
|
||||
injected += 1
|
||||
|
||||
if injected == 0:
|
||||
logger.warning(
|
||||
"No LoRA layers injected. Available Linear child names: %s. "
|
||||
"target_modules: %s. Check model type and target_modules.",
|
||||
sorted(available),
|
||||
sorted(target_modules),
|
||||
)
|
||||
else:
|
||||
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
|
||||
|
||||
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
|
||||
|
||||
|
||||
def merge_lora(model: nn.Module):
|
||||
n = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
module.merge()
|
||||
n += 1
|
||||
if n == 0:
|
||||
logger.warning("No LoRA layers to merge.")
|
||||
else:
|
||||
logger.info("Merged %d LoRA layers", n)
|
||||
|
||||
|
||||
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
||||
lora_sd = {
|
||||
k: v
|
||||
for k, v in model.state_dict().items()
|
||||
if k.endswith((".lora_A", ".lora_B"))
|
||||
}
|
||||
if not lora_sd:
|
||||
raise RuntimeError(
|
||||
"No LoRA parameters found in model. "
|
||||
"The model may not have been injected or was already merged."
|
||||
)
|
||||
|
||||
path = Path(save_dir)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
save_safetensors(lora_sd, path / "adapter_model.safetensors")
|
||||
save_json(asdict(config), path / "adapter_config.json")
|
||||
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||
|
||||
|
||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||
path = Path(load_dir)
|
||||
raw = load_json(path / "adapter_config.json")
|
||||
config = LoRAConfig(
|
||||
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
||||
)
|
||||
|
||||
existing = _get_lora_count(model)
|
||||
if existing > 0:
|
||||
logger.warning(
|
||||
"Model already has %d LoRA layers. Skipping injection, "
|
||||
"loading weights onto existing layers only.",
|
||||
existing,
|
||||
)
|
||||
else:
|
||||
inject_lora(
|
||||
model,
|
||||
r=config.r,
|
||||
alpha=config.alpha,
|
||||
target_modules=set(config.target_modules),
|
||||
)
|
||||
|
||||
weights = load_safetensors(path / "adapter_model.safetensors")
|
||||
try:
|
||||
missing, unexpected = model.load_state_dict(weights, strict=False)
|
||||
except RuntimeError as e:
|
||||
msg = str(e)
|
||||
if "size mismatch" in msg:
|
||||
raise RuntimeError(
|
||||
f"LoRA weight shapes do not match the model. "
|
||||
f"The adapter config (r={config.r}) may not match the injected layers. "
|
||||
f"Original error: {msg}"
|
||||
) from e
|
||||
raise
|
||||
|
||||
injected = _get_lora_count(model)
|
||||
if injected == 0:
|
||||
raise RuntimeError(
|
||||
"No LoRA layers found after loading. "
|
||||
"Inject LoRA before calling load_lora, or check the adapter config."
|
||||
)
|
||||
|
||||
if missing:
|
||||
lora_missing = [k for k in missing if "lora" in k]
|
||||
if lora_missing:
|
||||
raise RuntimeError(
|
||||
f"LoRA weight keys not found in model: {lora_missing}. "
|
||||
f"The adapter config (r={config.r}) may not match the model."
|
||||
)
|
||||
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
|
||||
if unexpected:
|
||||
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
|
||||
|
||||
logger.info("LoRA adapter loaded from %s", load_dir)
|
||||
return config
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.model.components.linear import Linear
|
||||
|
||||
|
||||
class FFNFactory(BaseFactory[nn.Module]):
|
||||
@classmethod
|
||||
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
|
||||
return super().create(ffn_type, dim, dim_ffn, **kwargs)
|
||||
|
||||
|
||||
@FFNFactory.register("mlp")
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim: int, dim_ffn: int):
|
||||
super().__init__()
|
||||
self.up = Linear(dim, dim_ffn)
|
||||
self.gate = Linear(dim, dim_ffn)
|
||||
self.down = Linear(dim_ffn, dim)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
out = self.down(gated)
|
||||
return out
|
||||
|
||||
|
||||
@FFNFactory.register("moe")
|
||||
class DeepSeekMoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_ffn: int,
|
||||
n_routed_experts: int,
|
||||
n_shared_experts: int = 1,
|
||||
n_activated_experts: int = 2,
|
||||
topk_method: str = "greedy",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_activated_experts = n_activated_experts
|
||||
self.topk_method = topk_method
|
||||
|
||||
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||
|
||||
self.shared_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
|
||||
)
|
||||
self.routed_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
bsz, seq_len, dim = x.shape
|
||||
x_flat = x.view(-1, dim)
|
||||
|
||||
shared_out = self._shared_forward(x_flat)
|
||||
routed_out = self._routed_forward(x_flat)
|
||||
|
||||
out = (shared_out + routed_out).view(bsz, seq_len, dim)
|
||||
return out
|
||||
|
||||
def _shared_forward(self, x: Tensor) -> Tensor:
|
||||
if self.n_shared_experts == 0:
|
||||
return torch.zeros_like(x)
|
||||
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
|
||||
|
||||
def _routed_forward(self, x: Tensor) -> Tensor:
|
||||
N, D = x.shape
|
||||
K = self.n_activated_experts
|
||||
|
||||
router_logits = self.router(x)
|
||||
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
|
||||
|
||||
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
|
||||
for expert_idx in range(self.n_routed_experts):
|
||||
expert_mask = topk_indices == expert_idx
|
||||
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
|
||||
if token_idx.numel() == 0:
|
||||
continue
|
||||
expert_input = x[token_idx]
|
||||
expert_output = self.routed_experts[expert_idx](expert_input)
|
||||
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
|
||||
output.index_add_(0, token_idx, expert_output * weights)
|
||||
|
||||
return output
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.normalized_shape = (dim,)
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def get_rotary_emb(
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tensor:
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||
freqs = torch.outer(t, theta).float()
|
||||
cos = torch.cos(freqs)
|
||||
sin = torch.sin(freqs)
|
||||
return torch.complex(cos, sin)
|
||||
|
||||
|
||||
def ntk_base(base: float, dim: int, factor: float) -> float:
|
||||
return base * (factor ** (dim / (dim - 2)))
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||
x_complex = torch.view_as_complex(x_)
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_rotated = x_complex * freqs_cis
|
||||
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||
return x_out.to(dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
rope_scaling: Optional[Dict] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
if rope_scaling is not None:
|
||||
scaling_type = rope_scaling.get("type", "ntk")
|
||||
factor = rope_scaling.get("factor", 1.0)
|
||||
if scaling_type == "ntk":
|
||||
self.base = ntk_base(base, dim, factor)
|
||||
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||
freqs_cis = torch.view_as_real(rotary_emb)
|
||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||
if position_ids is None:
|
||||
position_ids = (
|
||||
torch.arange(x.size(1), device=x.device)
|
||||
.unsqueeze(0)
|
||||
.expand(x.size(0), -1)
|
||||
)
|
||||
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||
return torch.view_as_complex(position_freq_cis)
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
from typing import Any, Mapping, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import EncoderConfig
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import RotaryEmbedding
|
||||
from astrai.model.transformer import process_attention_mask
|
||||
|
||||
|
||||
@AutoModel.register("embedding")
|
||||
class EmbeddingEncoder(AutoModel):
|
||||
def __init__(self, config: EncoderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
rope_dim = config.dim // config.n_heads
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||
)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DecoderBlock(
|
||||
config.dim,
|
||||
config.n_heads,
|
||||
config.dim_ffn,
|
||||
config.n_kv_heads,
|
||||
config.norm_eps,
|
||||
config.use_qk_norm,
|
||||
config.use_gated_attention,
|
||||
layer_id,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
|
||||
self.pooling_type = config.pooling_type or "mean"
|
||||
self.normalize_embeddings = config.normalize_embeddings or False
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if hasattr(module, "reset_parameters"):
|
||||
module.reset_parameters()
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||
state_dict = dict(state_dict)
|
||||
state_dict.pop("lm_head.weight", None)
|
||||
return super().load_state_dict(state_dict, strict=strict, assign=assign)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
B, S = input_ids.shape
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
|
||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
|
||||
if self.pooling_type == "cls":
|
||||
pooled = hidden_states[:, 0]
|
||||
elif self.pooling_type == "last":
|
||||
if input_mask is not None:
|
||||
lengths = input_mask.sum(dim=1) - 1
|
||||
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
|
||||
else:
|
||||
pooled = hidden_states[:, -1]
|
||||
else:
|
||||
if input_mask is not None:
|
||||
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
||||
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
|
||||
min=1.0
|
||||
)
|
||||
else:
|
||||
pooled = hidden_states.mean(dim=1)
|
||||
|
||||
if self.normalize_embeddings:
|
||||
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
|
||||
|
||||
return pooled
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import RotaryEmbedding
|
||||
|
||||
|
||||
def process_attention_mask(
|
||||
input_tensor: Tensor,
|
||||
position_ids: Optional[Tensor],
|
||||
input_mask: Optional[Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
) -> Optional[Tensor]:
|
||||
if position_ids is None:
|
||||
return None
|
||||
if input_mask is not None and input_mask.dim() > 2:
|
||||
return input_mask
|
||||
|
||||
device = input_tensor.device
|
||||
dtype = input_tensor.dtype
|
||||
B, S = input_tensor.size()[:2]
|
||||
T = position_ids.max().item() + 1
|
||||
|
||||
if input_mask is None:
|
||||
if position_ids.min().item() == 0 and is_causal:
|
||||
return None
|
||||
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
||||
else:
|
||||
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
||||
|
||||
attend = pad.view(B, 1, T).expand(B, S, T).clone()
|
||||
if is_causal:
|
||||
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||
|
||||
return torch.full(
|
||||
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
||||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||
|
||||
|
||||
@AutoModel.register("autoregressive_lm")
|
||||
class AutoRegressiveLM(AutoModel):
|
||||
"""Autoregressive language model with paged KV cache."""
|
||||
|
||||
def __init__(self, config: AutoRegressiveLMConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
rope_dim = (
|
||||
config.qk_rope_head_dim
|
||||
if config.attn_type == "mla"
|
||||
else config.dim // config.n_heads
|
||||
)
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||
)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DecoderBlock(
|
||||
config.dim,
|
||||
config.n_heads,
|
||||
config.dim_ffn,
|
||||
config.n_kv_heads,
|
||||
config.norm_eps,
|
||||
config.use_qk_norm,
|
||||
config.use_gated_attention,
|
||||
layer_id,
|
||||
attn_type=config.attn_type,
|
||||
ffn_type=config.ffn_type,
|
||||
n_routed_experts=config.n_routed_experts,
|
||||
n_shared_experts=config.n_shared_experts,
|
||||
n_activated_experts=config.n_activated_experts,
|
||||
topk_method=config.topk_method,
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||
|
||||
if self.config.tie_weight is True:
|
||||
self.lm_head.weight = self.embed_tokens.weight
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if hasattr(module, "reset_parameters"):
|
||||
module.reset_parameters()
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||
lm_head_key = "lm_head.weight"
|
||||
embed_key = "embed_tokens.weight"
|
||||
|
||||
state_dict = dict(state_dict)
|
||||
|
||||
if self.config.tie_weight is True:
|
||||
# same tensor for embed and lm_head
|
||||
if embed_key in state_dict:
|
||||
state_dict[lm_head_key] = state_dict[embed_key]
|
||||
else:
|
||||
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||
# clone to avoid sharing gradients
|
||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||
|
||||
return super().load_state_dict(state_dict, strict, assign)
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
state_dict = super().state_dict(
|
||||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||
)
|
||||
|
||||
if self.config.tie_weight is True:
|
||||
lm_head_key = prefix + "lm_head.weight"
|
||||
if lm_head_key in state_dict:
|
||||
del state_dict[lm_head_key]
|
||||
|
||||
return state_dict
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
) -> Dict[str, Tensor]:
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
return {"logits": logits, "hidden_states": hidden_states}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
from astrai.parallel.executor import (
|
||||
AccumOptimizer,
|
||||
AccumScheduler,
|
||||
BaseExecutor,
|
||||
DDPExecutor,
|
||||
ExecutorFactory,
|
||||
FSDPExecutor,
|
||||
GradientState,
|
||||
NoneExecutor,
|
||||
)
|
||||
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
||||
from astrai.parallel.setup import (
|
||||
get_current_device,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
only_on_rank,
|
||||
setup_parallel,
|
||||
spawn_parallel_fn,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_world_size",
|
||||
"get_rank",
|
||||
"get_current_device",
|
||||
"only_on_rank",
|
||||
"setup_parallel",
|
||||
"spawn_parallel_fn",
|
||||
"RowParallelLinear",
|
||||
"ColumnParallelLinear",
|
||||
"ExecutorFactory",
|
||||
"BaseExecutor",
|
||||
"GradientState",
|
||||
"AccumOptimizer",
|
||||
"AccumScheduler",
|
||||
"NoneExecutor",
|
||||
"DDPExecutor",
|
||||
"FSDPExecutor",
|
||||
]
|
||||
|
|
@ -0,0 +1,271 @@
|
|||
"""Unified training executor — parallel strategy + gradient accumulation."""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel.setup import get_rank, get_world_size
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GradientState:
|
||||
def __init__(self, grad_accum_steps: int = 1):
|
||||
self.num_steps = max(grad_accum_steps, 1)
|
||||
self._step: int = 0
|
||||
self._sync_gradients: bool = True
|
||||
|
||||
@property
|
||||
def sync_gradients(self) -> bool:
|
||||
return self._sync_gradients
|
||||
|
||||
def _do_sync(self):
|
||||
self._step += 1
|
||||
self._sync_gradients = self._step % self.num_steps == 0
|
||||
|
||||
|
||||
class AccumOptimizer:
|
||||
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
|
||||
self.optimizer = optimizer
|
||||
self.gradient_state = gradient_state
|
||||
|
||||
def step(self, closure=None):
|
||||
if self.gradient_state.sync_gradients:
|
||||
self.optimizer.step(closure)
|
||||
|
||||
def zero_grad(self):
|
||||
if self.gradient_state.sync_gradients:
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
@property
|
||||
def param_groups(self):
|
||||
return self.optimizer.param_groups
|
||||
|
||||
def state_dict(self):
|
||||
return self.optimizer.state_dict()
|
||||
|
||||
def load_state_dict(self, d):
|
||||
self.optimizer.load_state_dict(d)
|
||||
|
||||
|
||||
class AccumScheduler:
|
||||
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
|
||||
self.scheduler = scheduler
|
||||
self.gradient_state = gradient_state
|
||||
|
||||
def step(self):
|
||||
if self.gradient_state.sync_gradients:
|
||||
self.scheduler.step()
|
||||
|
||||
def state_dict(self):
|
||||
return self.scheduler.state_dict()
|
||||
|
||||
def load_state_dict(self, d):
|
||||
self.scheduler.load_state_dict(d)
|
||||
|
||||
def get_last_lr(self):
|
||||
return self.scheduler.get_last_lr()
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
def __init__(self, grad_accum_steps: int = 1):
|
||||
self.gradient_state = GradientState(grad_accum_steps)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
dataloader: Optional[DataLoader] = None,
|
||||
scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[
|
||||
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
|
||||
]:
|
||||
model = self._prepare_model(model)
|
||||
if optimizer is not None:
|
||||
optimizer = AccumOptimizer(optimizer, self.gradient_state)
|
||||
if scheduler is not None:
|
||||
scheduler = AccumScheduler(scheduler, self.gradient_state)
|
||||
return model, optimizer, dataloader, scheduler
|
||||
|
||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||
return model
|
||||
|
||||
def _no_sync(self, model: nn.Module):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
@contextmanager
|
||||
def accumulate(self, model: nn.Module):
|
||||
self.gradient_state._do_sync()
|
||||
if not self.gradient_state.sync_gradients:
|
||||
with self._no_sync(model):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss.backward()
|
||||
|
||||
def unwrap_model(self, model: nn.Module):
|
||||
return model.state_dict()
|
||||
|
||||
@property
|
||||
def use_distributed(self) -> bool:
|
||||
return get_world_size() > 1
|
||||
|
||||
@property
|
||||
def sync_gradients(self) -> bool:
|
||||
return self.gradient_state.sync_gradients
|
||||
|
||||
@property
|
||||
def grad_accum_steps(self) -> int:
|
||||
return self.gradient_state.num_steps
|
||||
|
||||
|
||||
class ExecutorFactory(BaseFactory[BaseExecutor]):
|
||||
pass
|
||||
|
||||
|
||||
@ExecutorFactory.register("none")
|
||||
class NoneExecutor(BaseExecutor):
|
||||
pass
|
||||
|
||||
|
||||
@ExecutorFactory.register("ddp")
|
||||
class DDPExecutor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
grad_accum_steps: int = 1,
|
||||
dim: int = 0,
|
||||
broadcast_buffers: bool = True,
|
||||
init_sync: bool = True,
|
||||
process_group=None,
|
||||
bucket_cap_mb: int = 25,
|
||||
find_unused_parameters: bool = False,
|
||||
check_reduction: bool = False,
|
||||
gradient_as_bucket_view: bool = False,
|
||||
static_graph: bool = False,
|
||||
delay_all_reduce_named_params=None,
|
||||
param_to_hook_all_reduce=None,
|
||||
mixed_precision=None,
|
||||
device_mesh=None,
|
||||
):
|
||||
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||
self._ddp_kwargs = dict(
|
||||
dim=dim,
|
||||
broadcast_buffers=broadcast_buffers,
|
||||
init_sync=init_sync,
|
||||
process_group=process_group,
|
||||
bucket_cap_mb=bucket_cap_mb,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
check_reduction=check_reduction,
|
||||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
static_graph=static_graph,
|
||||
delay_all_reduce_named_params=delay_all_reduce_named_params,
|
||||
param_to_hook_all_reduce=param_to_hook_all_reduce,
|
||||
mixed_precision=mixed_precision,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
|
||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||
if not self.use_distributed:
|
||||
logger.warning("DDP backend selected but world_size=1, model not wrapped")
|
||||
return model
|
||||
local_rank = get_rank()
|
||||
model = DDP(
|
||||
model,
|
||||
device_ids=[local_rank],
|
||||
output_device=local_rank,
|
||||
**self._ddp_kwargs,
|
||||
)
|
||||
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
|
||||
return model
|
||||
|
||||
def _no_sync(self, model: nn.Module):
|
||||
if isinstance(model, DDP):
|
||||
return model.no_sync()
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def unwrap_model(self, model: nn.Module):
|
||||
if isinstance(model, DDP):
|
||||
return model.module.state_dict()
|
||||
return model.state_dict()
|
||||
|
||||
|
||||
@ExecutorFactory.register("fsdp")
|
||||
class FSDPExecutor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
grad_accum_steps: int = 1,
|
||||
process_group=None,
|
||||
sharding_strategy=None,
|
||||
cpu_offload=None,
|
||||
auto_wrap_policy=None,
|
||||
backward_prefetch=None,
|
||||
mixed_precision=None,
|
||||
ignored_modules=None,
|
||||
param_init_fn=None,
|
||||
sync_module_states: bool = False,
|
||||
forward_prefetch: bool = False,
|
||||
limit_all_gathers: bool = True,
|
||||
ignored_states=None,
|
||||
device_mesh=None,
|
||||
):
|
||||
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||
self._fsdp_kwargs = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
process_group=process_group,
|
||||
sharding_strategy=sharding_strategy,
|
||||
cpu_offload=cpu_offload,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
backward_prefetch=backward_prefetch,
|
||||
mixed_precision=mixed_precision,
|
||||
ignored_modules=ignored_modules,
|
||||
param_init_fn=param_init_fn,
|
||||
sync_module_states=sync_module_states,
|
||||
forward_prefetch=forward_prefetch,
|
||||
limit_all_gathers=limit_all_gathers,
|
||||
use_orig_params=True,
|
||||
ignored_states=ignored_states,
|
||||
device_mesh=device_mesh,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
self._original_model: Optional[nn.Module] = None
|
||||
|
||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||
if not self.use_distributed:
|
||||
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
|
||||
return model
|
||||
self._original_model = model
|
||||
device_id = torch.device("cuda", get_rank())
|
||||
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
|
||||
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
|
||||
return model
|
||||
|
||||
def _no_sync(self, model: nn.Module):
|
||||
if isinstance(model, FSDP):
|
||||
return model.no_sync()
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def unwrap_model(self, model: nn.Module):
|
||||
if isinstance(model, FSDP) and self.use_distributed:
|
||||
with FSDP.state_dict_type(
|
||||
model,
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
||||
):
|
||||
return model.state_dict()
|
||||
|
||||
return model.state_dict()
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ParallelModel(nn.Module):
|
||||
def __init__(self, process_group: dist.ProcessGroup):
|
||||
super().__init__()
|
||||
self.process_group = process_group
|
||||
self.rank = dist.get_rank(self.process_group)
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
|
||||
class RowParallelLinear(ParallelModel):
|
||||
def __init__(
|
||||
self,
|
||||
process_group: dist.ProcessGroup,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
reduce_results: bool = True,
|
||||
):
|
||||
super().__init__(process_group)
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.in_features_per_rank = in_features // self.world_size
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
if in_features % self.world_size != 0:
|
||||
raise ValueError(
|
||||
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
|
||||
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
output = F.linear(input, self.weight)
|
||||
|
||||
if self.reduce_results:
|
||||
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
|
||||
|
||||
if self.bias is not None:
|
||||
output += self.bias
|
||||
|
||||
return output
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||
full_weight = state_dict.get("weight")
|
||||
full_bias = state_dict.get("bias")
|
||||
|
||||
start_idx = self.rank * self.in_features_per_rank
|
||||
end_idx = start_idx + self.in_features_per_rank
|
||||
weight_slice = full_weight[:, start_idx:end_idx]
|
||||
self.weight.data.copy_(weight_slice)
|
||||
|
||||
if self.bias is not None:
|
||||
self.bias.data.copy_(full_bias)
|
||||
|
||||
|
||||
class ColumnParallelLinear(ParallelModel):
|
||||
def __init__(
|
||||
self,
|
||||
process_group: dist.ProcessGroup,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
gather_results: bool = True,
|
||||
):
|
||||
super().__init__(process_group)
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.out_features_per_rank = out_features // self.world_size
|
||||
self.gather_results = gather_results
|
||||
|
||||
if out_features % self.world_size != 0:
|
||||
raise ValueError(
|
||||
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(self.out_features_per_rank, self.in_features)
|
||||
)
|
||||
self.bias = (
|
||||
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
output = F.linear(input, self.weight, self.bias)
|
||||
|
||||
if self.gather_results:
|
||||
output_list = [torch.empty_like(output) for _ in range(self.world_size)]
|
||||
dist.all_gather(output_list, output, group=self.process_group)
|
||||
output = torch.cat(output_list, dim=-1)
|
||||
|
||||
return output
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||
full_weight = state_dict.get("weight")
|
||||
full_bias = state_dict.get("bias")
|
||||
|
||||
start_idx = self.rank * self.out_features_per_rank
|
||||
end_idx = start_idx + self.out_features_per_rank
|
||||
weight_slice = full_weight[start_idx:end_idx, :]
|
||||
self.weight.data.copy_(weight_slice)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_slice = full_bias[start_idx:end_idx]
|
||||
self.bias.data.copy_(bias_slice)
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
def get_current_device():
|
||||
return os.environ["LOCAL_DEVICE"]
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
return dist.get_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def setup_parallel(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
device_type: str = "cuda",
|
||||
):
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
yield dist.group.WORLD
|
||||
return
|
||||
|
||||
if world_size <= 1:
|
||||
yield None
|
||||
return
|
||||
|
||||
device_id = torch.device(device_type, rank)
|
||||
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = master_port
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
dist.init_process_group(
|
||||
rank=rank, world_size=world_size, backend=backend, device_id=device_id
|
||||
)
|
||||
|
||||
try:
|
||||
if backend == "nccl" and torch.cuda.is_available():
|
||||
torch.cuda.set_device(device_id)
|
||||
elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
torch.xpu.set_device(device_id)
|
||||
|
||||
yield dist.group.WORLD
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def only_on_rank(rank, sync=False):
|
||||
"""
|
||||
decorator to run a function only on a specific rank.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
ret_args = None
|
||||
if get_rank() == rank:
|
||||
ret_args = func(*args, **kwargs)
|
||||
|
||||
if sync and dist.is_available() and dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
return ret_args
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def wrapper_spawn_func(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
master_addr: str,
|
||||
master_port: str,
|
||||
device_type: str,
|
||||
func: Callable,
|
||||
kwargs: dict,
|
||||
):
|
||||
try:
|
||||
with setup_parallel(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
device_type=device_type,
|
||||
):
|
||||
func(**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in rank {rank}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def spawn_parallel_fn(
|
||||
func: Callable,
|
||||
world_size: int,
|
||||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
device_type: str = "cuda",
|
||||
start_method: str = "spawn",
|
||||
**kwargs,
|
||||
):
|
||||
# clear environment variables
|
||||
for key in [
|
||||
"MASTER_ADDR",
|
||||
"MASTER_PORT",
|
||||
"RANK",
|
||||
"WORLD_SIZE",
|
||||
"LOCAL_RANK",
|
||||
"LOCAL_DEVICE",
|
||||
]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
if world_size == 1:
|
||||
device_id = torch.device(device_type, 0)
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
func(**kwargs)
|
||||
return
|
||||
|
||||
wrapper_spawn_func_args = (
|
||||
world_size,
|
||||
backend,
|
||||
master_addr,
|
||||
master_port,
|
||||
device_type,
|
||||
func,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
mp.start_processes(
|
||||
wrapper_spawn_func,
|
||||
args=wrapper_spawn_func_args,
|
||||
nprocs=world_size,
|
||||
start_method=start_method,
|
||||
join=True,
|
||||
)
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from astrai.preprocessing.builder import (
|
||||
BaseMaskBuilder,
|
||||
MaskBuilderFactory,
|
||||
SectionedMaskBuilder,
|
||||
)
|
||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||
|
||||
__all__ = [
|
||||
"BaseMaskBuilder",
|
||||
"MaskBuilderFactory",
|
||||
"SectionedMaskBuilder",
|
||||
"Pipeline",
|
||||
"filter_by_length",
|
||||
]
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
"""Mask building strategies for preprocessing pipeline.
|
||||
|
||||
The single :class:`SectionedMaskBuilder` handles all input formats
|
||||
via declarative ``input.sections`` config.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseMaskBuilder(ABC):
|
||||
"""Convert a JSONL item into token ids and optional loss_mask."""
|
||||
|
||||
@abstractmethod
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
||||
|
||||
Returns ``None`` to skip the item entirely.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, BaseMaskBuilder):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||||
)
|
||||
|
||||
|
||||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||||
if not domain_key:
|
||||
return "__default__"
|
||||
val = item.get(domain_key, "__default__")
|
||||
return val if isinstance(val, str) else "__default__"
|
||||
|
||||
|
||||
def _resolve_action(action: str, role: str, config) -> str:
|
||||
"""Resolve action to "train" or "mask".
|
||||
|
||||
- ``"train"`` / ``"mask"`` → literal
|
||||
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
||||
"""
|
||||
if action == "$role":
|
||||
return config.mask.get(role, config.mask_default)
|
||||
return action
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("sectioned")
|
||||
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||
"""Config-driven builder: iterates over ``input.sections`` in order.
|
||||
|
||||
Each section specifies a JSONL field + mask action.
|
||||
|
||||
Section spec::
|
||||
|
||||
{
|
||||
"field": "messages", # JSONL key
|
||||
"action": "$role", # "train" | "mask" | "$role"
|
||||
"template": true, # apply chat_template per message (optional)
|
||||
"add_special_tokens": false # override encode flag (optional)
|
||||
}
|
||||
|
||||
Example configs::
|
||||
|
||||
# Chat
|
||||
{"input": {"sections": [
|
||||
{"field": "messages", "action": "$role", "template": true}
|
||||
]}}
|
||||
|
||||
# Instruction
|
||||
{"input": {"sections": [
|
||||
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
||||
{"field": "response", "action": "train"}
|
||||
]}}
|
||||
|
||||
# Text
|
||||
{"input": {"sections": [
|
||||
{"field": "text", "action": "train"}
|
||||
]}}
|
||||
"""
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
sections = config.input.sections
|
||||
if not sections:
|
||||
return None
|
||||
|
||||
all_ids: list[int] = []
|
||||
loss_mask: list[int] = []
|
||||
|
||||
has_template = any(s.get("template") for s in sections)
|
||||
is_text_config = not has_template and all(
|
||||
s["action"] == "train" for s in sections
|
||||
)
|
||||
|
||||
if has_template and tokenizer.bos_token_id is not None:
|
||||
all_ids.append(tokenizer.bos_token_id)
|
||||
loss_mask.append(0)
|
||||
|
||||
first_section = True
|
||||
for sec in sections:
|
||||
field = sec["field"]
|
||||
action = sec["action"]
|
||||
use_template = sec.get("template", False)
|
||||
add_special = sec.get(
|
||||
"add_special_tokens", not use_template and first_section
|
||||
)
|
||||
|
||||
if use_template:
|
||||
messages = item.get(field)
|
||||
if not isinstance(messages, list) or not messages:
|
||||
continue
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
act = _resolve_action(action, role, config)
|
||||
rendered = tokenizer.apply_chat_template(
|
||||
[msg], tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||||
all_ids.extend(ids)
|
||||
val = 1 if act == "train" else 0
|
||||
loss_mask.extend([val] * len(ids))
|
||||
else:
|
||||
text = str(item.get(field, ""))
|
||||
if not text.strip():
|
||||
continue
|
||||
if is_text_config:
|
||||
pp = config.preprocessing
|
||||
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
||||
continue
|
||||
if len(text) > pp.max_chars:
|
||||
continue
|
||||
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
||||
all_ids.extend(ids)
|
||||
val = 1 if action == "train" else 0
|
||||
loss_mask.extend([val] * len(ids))
|
||||
|
||||
first_section = False
|
||||
|
||||
max_len = config.preprocessing.max_seq_len
|
||||
all_ids = all_ids[:max_len]
|
||||
loss_mask = loss_mask[: len(all_ids)]
|
||||
|
||||
if not all_ids:
|
||||
return None
|
||||
|
||||
if has_template and len(all_ids) <= 1:
|
||||
return None
|
||||
|
||||
result: dict = {
|
||||
"sequence": all_ids,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
if not all(m == 1 for m in loss_mask):
|
||||
result["loss_mask"] = loss_mask
|
||||
return result
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
"""Config-driven JSONL preprocessing pipeline.
|
||||
|
||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||
sharding and flush to ``.h5`` / ``.bin`` storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.preprocessing.builder import SectionedMaskBuilder
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||
"bool": torch.bool,
|
||||
"uint8": torch.uint8,
|
||||
"int8": torch.int8,
|
||||
"int16": torch.int16,
|
||||
"int32": torch.int32,
|
||||
"int64": torch.int64,
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
}
|
||||
|
||||
|
||||
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
||||
return min_len <= len(text) <= max_len
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||
|
||||
Usage::
|
||||
|
||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
||||
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PipelineConfig,
|
||||
input_paths: list[str],
|
||||
output_dir: str,
|
||||
tokenizer_path: str,
|
||||
):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.config = config
|
||||
self.paths = input_paths
|
||||
self.output_dir = output_dir
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
self.mask_builder = SectionedMaskBuilder()
|
||||
|
||||
def transform(self, item: dict) -> Optional[dict]:
|
||||
return self.mask_builder.build(item, self.config, self._tokenizer)
|
||||
|
||||
def run(self):
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
||||
domains: dict = defaultdict(lambda: defaultdict(list))
|
||||
total_tokens = 0
|
||||
shard_idx: dict[str, int] = defaultdict(int)
|
||||
count = 0
|
||||
|
||||
pp = self.config.preprocessing
|
||||
|
||||
for item in tqdm.tqdm(
|
||||
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
|
||||
):
|
||||
if pp.max_items and count >= pp.max_items:
|
||||
break
|
||||
|
||||
result = self.transform(item)
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
ids = result.pop("sequence")
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
domain = result.pop("domain", "__default__")
|
||||
result["sequence"] = ids
|
||||
|
||||
bucket = domains[domain]
|
||||
for key in list(bucket.keys()):
|
||||
if key not in result:
|
||||
bucket[key].append([1] * len(ids))
|
||||
for key, val in result.items():
|
||||
bucket[key].append(val)
|
||||
|
||||
count += 1
|
||||
total_tokens += len(ids)
|
||||
|
||||
if total_tokens >= self.config.output.max_tokens_per_shard:
|
||||
self._flush(domains, shard_idx)
|
||||
domains.clear()
|
||||
total_tokens = 0
|
||||
|
||||
if total_tokens > 0:
|
||||
self._flush(domains, shard_idx)
|
||||
|
||||
print(f"Done. {count} documents tokenized.")
|
||||
|
||||
def _iter_items(self):
|
||||
for path in self.paths:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
yield json.loads(line)
|
||||
|
||||
def _flush(self, domains, shard_idx):
|
||||
for domain, keys in domains.items():
|
||||
idx = shard_idx[domain]
|
||||
tensors = {}
|
||||
for key, ids_list in keys.items():
|
||||
dt = _STR_TO_DTYPE.get(
|
||||
self.config.output.dtype.get(key, "int32"), torch.int32
|
||||
)
|
||||
tensors[key] = [
|
||||
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||
]
|
||||
chunk_dir = os.path.join(self.output_dir, domain)
|
||||
fmt = self.config.output.storage_format
|
||||
if fmt == "bin":
|
||||
save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
|
||||
else:
|
||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||
shard_idx[domain] = idx + 1
|
||||
tqdm.tqdm.write(
|
||||
f" saved {domain}/shard_{idx:04d} "
|
||||
f"({tensors['sequence'][0].numel():,} tokens)"
|
||||
)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class OptimizerProtocol(Protocol):
|
||||
def step(self, closure=None): ...
|
||||
def zero_grad(self): ...
|
||||
@property
|
||||
def param_groups(self) -> Any: ...
|
||||
def state_dict(self) -> dict: ...
|
||||
def load_state_dict(self, d: dict): ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SchedulerProtocol(Protocol):
|
||||
def step(self): ...
|
||||
def state_dict(self) -> dict: ...
|
||||
def load_state_dict(self, d: dict): ...
|
||||
def get_last_lr(self): ...
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
import io
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from astrai.parallel.setup import get_rank
|
||||
|
||||
_META_FILE = "meta.json"
|
||||
_CONFIG_FILE = "config.json"
|
||||
_WEIGHTS_FILE = "model.safetensors"
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, path: Union[str, Path]):
|
||||
st.save_file(state_dict, str(path))
|
||||
|
||||
|
||||
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return st.load_file(str(path))
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
state_dict = st.load_file(str(path))
|
||||
else:
|
||||
state_dict = {}
|
||||
tmp = [state_dict]
|
||||
dist.broadcast_object_list(tmp, src=0)
|
||||
return tmp[0]
|
||||
|
||||
|
||||
def save_json(data: dict, path: Union[str, Path]):
|
||||
with open(str(path), "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
with open(str(path), "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
with open(str(path), "r") as f:
|
||||
data = json.load(f)
|
||||
else:
|
||||
data = {}
|
||||
tmp = [data]
|
||||
dist.broadcast_object_list(tmp, src=0)
|
||||
return tmp[0]
|
||||
|
||||
|
||||
def save_torch(obj: Any, path: Union[str, Path]):
|
||||
torch.save(obj, str(path))
|
||||
|
||||
|
||||
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||
|
||||
path = Path(path)
|
||||
rank = get_rank()
|
||||
|
||||
if rank == 0:
|
||||
with open(path, "rb") as f:
|
||||
raw = f.read()
|
||||
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
||||
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
|
||||
else:
|
||||
num_bytes = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
dist.broadcast(num_bytes, src=0)
|
||||
|
||||
if rank != 0:
|
||||
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
|
||||
|
||||
dist.broadcast(data_tensor, src=0)
|
||||
|
||||
buf = io.BytesIO(data_tensor.numpy().tobytes())
|
||||
return torch.load(buf, map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str):
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
save_json(config, save_path / _CONFIG_FILE)
|
||||
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
||||
|
||||
|
||||
def load_model_config(save_directory: str) -> dict:
|
||||
return load_json(Path(save_directory) / _CONFIG_FILE)
|
||||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
|
||||
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
|
||||
|
||||
|
||||
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||
path = Path(path)
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return load_safetensors(path)
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
state_dict = load_safetensors(path)
|
||||
specs = [
|
||||
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
||||
for k in sorted(state_dict)
|
||||
]
|
||||
else:
|
||||
state_dict = {}
|
||||
specs = []
|
||||
|
||||
specs_list = [specs]
|
||||
dist.broadcast_object_list(specs_list, src=0)
|
||||
specs = specs_list[0]
|
||||
|
||||
for key, shape, dtype_name in specs:
|
||||
dtype = getattr(torch, dtype_name)
|
||||
if rank != 0:
|
||||
tensor = torch.empty(shape, dtype=dtype, device="cpu")
|
||||
else:
|
||||
tensor = state_dict[key].contiguous().cpu()
|
||||
dist.broadcast(tensor, src=0)
|
||||
if rank != 0:
|
||||
state_dict[key] = tensor
|
||||
return state_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
state_dict: Dict[str, Any] = field(default_factory=dict)
|
||||
epoch: int = 0
|
||||
iteration: int = 0
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
config: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def save(self, save_dir: str):
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if get_rank() != 0:
|
||||
return
|
||||
|
||||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
**self.meta,
|
||||
}
|
||||
save_json(meta, save_path / _META_FILE)
|
||||
save_json(self.config, save_path / _CONFIG_FILE)
|
||||
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
||||
for key, value in self.extra.items():
|
||||
save_torch(value, save_path / f"{key}.pt")
|
||||
|
||||
@classmethod
|
||||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||
save_path = Path(save_dir)
|
||||
|
||||
meta = load_json(save_path / _META_FILE, broadcast)
|
||||
config = load_json(save_path / _CONFIG_FILE, broadcast)
|
||||
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
|
||||
|
||||
extra = {}
|
||||
for f in sorted(save_path.iterdir()):
|
||||
if f.suffix == ".pt":
|
||||
extra[f.stem] = load_torch(f, broadcast=broadcast)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
epoch=meta.get("epoch", 0),
|
||||
iteration=meta.get("iteration", 0),
|
||||
extra=extra,
|
||||
config=config,
|
||||
)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from astrai.tokenize.chat_template import ChatTemplate, MessageType
|
||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||
|
||||
__all__ = [
|
||||
"AutoTokenizer",
|
||||
"ChatTemplate",
|
||||
"MessageType",
|
||||
]
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
type MessageType = Dict[str, Any]
|
||||
|
||||
|
||||
class ChatTemplate:
|
||||
"""A chat template with Jinja2 rendering support.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for the template.
|
||||
template_str: Jinja2 template string.
|
||||
description: Optional description.
|
||||
default_variables: Optional dictionary of default variable values.
|
||||
special_tokens: Optional dictionary mapping token names to their string values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "",
|
||||
template_str: str = "",
|
||||
description: str = "",
|
||||
default_variables: Optional[Dict[str, Any]] = None,
|
||||
special_tokens: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.template_str = template_str
|
||||
self.description = description
|
||||
self.default_variables = default_variables or {}
|
||||
self.special_tokens = special_tokens or {}
|
||||
self._compiled: Template = Template(template_str)
|
||||
|
||||
@classmethod
|
||||
def from_string(
|
||||
cls,
|
||||
template_str: str,
|
||||
description: str = "",
|
||||
default_variables: Optional[Dict[str, Any]] = None,
|
||||
special_tokens: Optional[Dict[str, str]] = None,
|
||||
) -> "ChatTemplate":
|
||||
"""Create a ChatTemplate instance directly from a template string."""
|
||||
return cls(
|
||||
name="",
|
||||
template_str=template_str,
|
||||
description=description,
|
||||
default_variables=default_variables,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
def render(
|
||||
self,
|
||||
messages: List[MessageType],
|
||||
system_prompt: Optional[str] = None,
|
||||
**extra_variables: Any,
|
||||
) -> str:
|
||||
"""Render the template with given messages and variables.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
system_prompt: Optional system prompt string.
|
||||
**extra_variables: Additional variables to pass to the template.
|
||||
These override default_variables and special_tokens.
|
||||
|
||||
Returns:
|
||||
Rendered prompt string.
|
||||
"""
|
||||
# Merge default variables, special tokens, and extra variables
|
||||
variables = {**self.default_variables, **self.special_tokens, **extra_variables}
|
||||
variables["messages"] = messages
|
||||
if system_prompt is not None:
|
||||
variables["system_prompt"] = system_prompt
|
||||
|
||||
return self._compiled.render(**variables)
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
"""
|
||||
Tokenizer module with implementation and auto-loading support.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
from astrai.tokenize.chat_template import ChatTemplate
|
||||
|
||||
|
||||
class AutoTokenizer:
|
||||
"""Base tokenizer class with automatic loading support"""
|
||||
|
||||
TOKENIZER_CLASSES = {} # Registry for auto-loading
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: Optional[Union[str, Path]] = None,
|
||||
special_token_map: Optional[Dict[str, str]] = None,
|
||||
chat_template: Optional[str] = None,
|
||||
):
|
||||
self._tokenizer: Tokenizer = None
|
||||
self._chat_template: Optional[ChatTemplate] = None
|
||||
self._special_token_map: Optional[Dict] = special_token_map or {}
|
||||
|
||||
if chat_template:
|
||||
self.set_chat_template(chat_template)
|
||||
|
||||
if path:
|
||||
self.load(path)
|
||||
|
||||
def load(self, path: Union[str, Path]):
|
||||
"""Load tokenizer from directory."""
|
||||
path = Path(path)
|
||||
tokenizer_file = path / "tokenizer.json"
|
||||
config_file = path / "tokenizer_config.json"
|
||||
self._tokenizer = Tokenizer.from_file(str(tokenizer_file))
|
||||
|
||||
if config_file.exists():
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
if "special_tokens" in config:
|
||||
self._special_token_map.update(config["special_tokens"])
|
||||
|
||||
# Load chat template from config
|
||||
if "chat_template" in config:
|
||||
self.set_chat_template(config["chat_template"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
|
||||
"""Load tokenizer from pretrained directory.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If tokenizer.json is missing.
|
||||
RuntimeError: If tokenizer failed to initialize.
|
||||
"""
|
||||
path = Path(path)
|
||||
tokenizer_file = path / "tokenizer.json"
|
||||
if not tokenizer_file.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Tokenizer file not found: {tokenizer_file}. "
|
||||
"A valid tokenizer.json is required."
|
||||
)
|
||||
instance = cls(path)
|
||||
if instance._tokenizer is None:
|
||||
raise RuntimeError(
|
||||
f"Failed to load tokenizer from {path}. "
|
||||
"The tokenizer.json may be corrupted or incompatible."
|
||||
)
|
||||
return instance
|
||||
|
||||
def save_pretrained(self, save_path: str):
|
||||
"""
|
||||
Save tokenizer to pretrained directory.
|
||||
|
||||
Args:
|
||||
save_path: Path to save the tokenizer
|
||||
"""
|
||||
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError(
|
||||
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||
)
|
||||
|
||||
save_path = Path(save_path)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save tokenizer
|
||||
self._tokenizer.save(str(save_path / "tokenizer.json"))
|
||||
|
||||
# Save tokenizer config
|
||||
config = {}
|
||||
if self._special_token_map is not None:
|
||||
config["special_tokens"] = self._special_token_map
|
||||
if self._chat_template is not None:
|
||||
config["chat_template"] = self._chat_template.template_str
|
||||
|
||||
with open(save_path / "tokenizer_config.json", "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
@classmethod
|
||||
def register_tokenizer(cls, name: str, tokenizer_class: type):
|
||||
"""
|
||||
Register a new tokenizer class.
|
||||
|
||||
Args:
|
||||
name: Name to register the tokenizer class under
|
||||
tokenizer_class: The tokenizer class to register
|
||||
"""
|
||||
cls.TOKENIZER_CLASSES[name] = tokenizer_class
|
||||
|
||||
def encode(
|
||||
self,
|
||||
tokens: Union[str, List[str]],
|
||||
out_ids: bool = True,
|
||||
is_pretokenized: bool = False,
|
||||
add_special_tokens: bool = True,
|
||||
) -> List:
|
||||
"""Encode text to tokens or token IDs."""
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError(
|
||||
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||
)
|
||||
|
||||
if isinstance(tokens, str):
|
||||
encoded = self._tokenizer.encode(
|
||||
tokens,
|
||||
is_pretokenized=is_pretokenized,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
return encoded.ids if out_ids else encoded.tokens
|
||||
else:
|
||||
encoded_list = self._tokenizer.encode_batch(
|
||||
tokens,
|
||||
is_pretokenized=is_pretokenized,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
return [
|
||||
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
|
||||
]
|
||||
|
||||
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
|
||||
"""Decode token IDs to text."""
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError(
|
||||
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||
)
|
||||
|
||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self._tokenizer is None:
|
||||
return 0
|
||||
return self._tokenizer.get_vocab_size()
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
"""
|
||||
Dynamically intercept special token attribute access.
|
||||
Supports three forms:
|
||||
- tokenizer.bos_token → returns string
|
||||
- tokenizer.bos_token_id → returns corresponding integer ID
|
||||
- tokenizer.stop_ids → returns list of corresponding integer IDs for all special tokens
|
||||
"""
|
||||
# Handle stop_ids - return IDs for all special tokens
|
||||
if key == "stop_ids":
|
||||
stop_ids = []
|
||||
|
||||
if self._tokenizer is None:
|
||||
return stop_ids
|
||||
|
||||
for val in self._special_token_map.values():
|
||||
token_id = self._tokenizer.token_to_id(val)
|
||||
if token_id is not None:
|
||||
stop_ids.append(token_id)
|
||||
|
||||
return stop_ids
|
||||
|
||||
# Handle _id suffix (e.g., bos_token_id -> bos_token)
|
||||
if key.endswith("_id"):
|
||||
base_attr = key[:-3] # Remove "_id"
|
||||
token_str = self._special_token_map.get(base_attr)
|
||||
if token_str is None:
|
||||
return None
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
|
||||
return self._tokenizer.token_to_id(token_str)
|
||||
|
||||
# Handle regular string attributes
|
||||
if key in self._special_token_map:
|
||||
return self._special_token_map.get(key)
|
||||
|
||||
# Other attributes trigger default AttributeError
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return len(self)
|
||||
|
||||
def set_chat_template(self, template: Union[str, ChatTemplate]):
|
||||
"""
|
||||
Set the chat template for the tokenizer.
|
||||
|
||||
Args:
|
||||
template: Either a template name (str) registered in the global registry,
|
||||
or a ChatTemplate instance, or a Jinja2 template string.
|
||||
|
||||
Raises:
|
||||
KeyError: If template name is not registered.
|
||||
"""
|
||||
if isinstance(template, str):
|
||||
self._chat_template = ChatTemplate.from_string(template)
|
||||
elif isinstance(template, ChatTemplate):
|
||||
self._chat_template = template
|
||||
else:
|
||||
raise ValueError("Invalid template type, must be str or ChatTemplate.")
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
system_prompt: Optional[str] = None,
|
||||
tokenize: bool = True,
|
||||
add_generation_prompt: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[str, List[int]]:
|
||||
"""
|
||||
Apply the chat template to messages and optionally tokenize the result.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
system_prompt: Optional system prompt string (auto-converted to first message).
|
||||
tokenize: Whether to return token IDs (True) or raw string (False).
|
||||
add_generation_prompt: Whether to add the generation prompt (default: True).
|
||||
**kwargs: Additional variables to pass to the template.
|
||||
|
||||
Returns:
|
||||
Either the rendered string or list of token IDs.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If chat template is not set.
|
||||
"""
|
||||
if self._chat_template is None:
|
||||
raise RuntimeError(
|
||||
"Chat template not set. Use set_chat_template() to set a template first."
|
||||
)
|
||||
|
||||
# Auto-convert system_prompt to first message if provided
|
||||
if system_prompt:
|
||||
messages = [{"role": "system", "content": system_prompt}] + list(messages)
|
||||
|
||||
# Render the template
|
||||
rendered = self._chat_template.render(
|
||||
messages=messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if tokenize:
|
||||
return self.encode(rendered)
|
||||
|
||||
return rendered
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
from astrai.trainer.optim import Muon
|
||||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
from astrai.trainer.train_callback import (
|
||||
CallbackFactory,
|
||||
TrainCallback,
|
||||
)
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
__all__ = [
|
||||
# Main trainer
|
||||
"Trainer",
|
||||
# Optimizer
|
||||
"Muon",
|
||||
# Strategy factory
|
||||
"StrategyFactory",
|
||||
"BaseStrategy",
|
||||
# Scheduler factory
|
||||
"SchedulerFactory",
|
||||
"BaseScheduler",
|
||||
# Callback factory
|
||||
"TrainCallback",
|
||||
"CallbackFactory",
|
||||
]
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
from typing import Any, Callable, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _grad_stat(
|
||||
model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any
|
||||
) -> dict:
|
||||
results = {}
|
||||
for name, param in model.named_parameters():
|
||||
results[name] = default
|
||||
if param.grad is not None:
|
||||
results[name] = fn(param.grad.data)
|
||||
return results
|
||||
|
||||
|
||||
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||
return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0)
|
||||
|
||||
|
||||
def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||
return _grad_stat(model, lambda g: g.std().item(), 0.0)
|
||||
|
||||
|
||||
def grad_max(model: nn.Module) -> Dict[str, float]:
|
||||
return _grad_stat(model, lambda g: g.max().item(), -float("inf"))
|
||||
|
||||
|
||||
def grad_min(model: nn.Module) -> Dict[str, float]:
|
||||
return _grad_stat(model, lambda g: g.min().item(), float("inf"))
|
||||
|
||||
|
||||
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
||||
return _grad_stat(model, lambda g: g.mean().item(), 0.0)
|
||||
|
||||
|
||||
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||
return _grad_stat(model, lambda g: g.isnan().sum().item(), 0)
|
||||
|
||||
|
||||
def ctx_get_loss(ctx):
|
||||
return ctx.loss
|
||||
|
||||
|
||||
def ctx_get_lr(ctx):
|
||||
return ctx.optimizer.param_groups[-1]["lr"]
|
||||
|
||||
|
||||
def ctx_get_val_loss(ctx):
|
||||
return ctx.val_loss
|
||||
|
||||
|
||||
def ctx_get_grad_norm(ctx):
|
||||
return grad_norm(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_std(ctx):
|
||||
return grad_std(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_max(ctx):
|
||||
return grad_max(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_min(ctx):
|
||||
return grad_min(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_mean(ctx):
|
||||
return grad_mean(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_nan_num(ctx):
|
||||
return grad_nan_num(ctx.model)
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||
assert G.ndim == 2
|
||||
X = G
|
||||
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||
X = X / (X.norm() + 1e-7) * scale
|
||||
if steps == 0:
|
||||
return X
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = A @ X
|
||||
X = a * X + b * B + c * (A @ B)
|
||||
return X
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr: float = 2e-3,
|
||||
momentum: float = 0.95,
|
||||
weight_decay: float = 0.0,
|
||||
nesterov: bool = True,
|
||||
ns_steps: int = 5,
|
||||
adamw_lr: float = None,
|
||||
adamw_betas: tuple = (0.9, 0.95),
|
||||
adamw_eps: float = 1e-8,
|
||||
adamw_wd: float = 0.0,
|
||||
):
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
nesterov=nesterov,
|
||||
ns_steps=ns_steps,
|
||||
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
|
||||
adamw_betas=adamw_betas,
|
||||
adamw_eps=adamw_eps,
|
||||
adamw_wd=adamw_wd,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_2d, params_1d = [], []
|
||||
grads_2d, grads_1d = [], []
|
||||
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError("Muon does not support sparse gradients")
|
||||
if p.ndim >= 2:
|
||||
params_2d.append(p)
|
||||
grads_2d.append(p.grad)
|
||||
else:
|
||||
params_1d.append(p)
|
||||
grads_1d.append(p.grad)
|
||||
|
||||
if params_2d:
|
||||
self._muon_update_foreach(params_2d, grads_2d, group)
|
||||
if params_1d:
|
||||
self._adamw_update_foreach(params_1d, grads_1d, group)
|
||||
|
||||
return loss
|
||||
|
||||
def _muon_update_foreach(self, params_2d, grads_2d, group):
|
||||
lr = group["lr"]
|
||||
momentum = group["momentum"]
|
||||
wd = group["weight_decay"]
|
||||
nesterov = group["nesterov"]
|
||||
ns_steps = group["ns_steps"]
|
||||
|
||||
if wd != 0:
|
||||
torch._foreach_mul_(params_2d, 1 - lr * wd)
|
||||
|
||||
if nesterov:
|
||||
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
|
||||
|
||||
bufs = []
|
||||
for p, grad in zip(params_2d, grads_2d):
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||
bufs.append(state["momentum_buffer"])
|
||||
|
||||
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
|
||||
|
||||
for p, buf in zip(params_2d, bufs):
|
||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||
p.add_(update, alpha=-lr * scale)
|
||||
|
||||
def _adamw_update_foreach(self, params_1d, grads_1d, group):
|
||||
lr = group["adamw_lr"]
|
||||
betas = group["adamw_betas"]
|
||||
eps = group["adamw_eps"]
|
||||
wd = group["adamw_wd"]
|
||||
|
||||
steps: list[int] = []
|
||||
exp_avgs, exp_avg_sqs = [], []
|
||||
has_state = []
|
||||
for p in params_1d:
|
||||
state = self.state[p]
|
||||
if not state:
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||
has_state.append(False)
|
||||
else:
|
||||
has_state.append(True)
|
||||
state["step"] += 1
|
||||
steps.append(state["step"])
|
||||
exp_avgs.append(state["exp_avg"])
|
||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||
|
||||
beta1, beta2 = betas
|
||||
|
||||
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
|
||||
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
|
||||
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
|
||||
|
||||
bias_correction1 = [1 - beta1**s for s in steps]
|
||||
bias_correction2 = [1 - beta2**s for s in steps]
|
||||
|
||||
if wd != 0:
|
||||
torch._foreach_mul_(params_1d, 1 - lr * wd)
|
||||
|
||||
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
|
||||
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
|
||||
denom = torch._foreach_sqrt(denom)
|
||||
torch._foreach_add_(denom, eps)
|
||||
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
"""Learning rate scheduler implementations with factory pattern."""
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseScheduler(LRScheduler, ABC):
|
||||
"""Base scheduler class for all other schedulers."""
|
||||
|
||||
def __init__(self, optimizer, last_epoch: int = -1):
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
@abstractmethod
|
||||
def get_lr(self) -> List[float]:
|
||||
"""Calculate the current learning rate."""
|
||||
raise NotImplementedError
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
return super().state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Any]):
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||
"""Factory class for creating learning rate schedulers.
|
||||
|
||||
Supports decorator-based registration for extensible scheduler types.
|
||||
Also supports creation from ScheduleConfig objects.
|
||||
|
||||
Example usage:
|
||||
@SchedulerFactory.register("custom")
|
||||
class CustomScheduler(BaseScheduler):
|
||||
...
|
||||
|
||||
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||
if not issubclass(scheduler_cls, BaseScheduler):
|
||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, optimizer, schedule_type: str = "none", **kwargs
|
||||
) -> "BaseScheduler":
|
||||
"""Create a scheduler instance by type name.
|
||||
|
||||
Args:
|
||||
optimizer: PyTorch optimizer
|
||||
schedule_type: Type of scheduler ("cosine", "sgdr")
|
||||
**kwargs: Arguments passed to the scheduler constructor
|
||||
|
||||
Returns:
|
||||
Scheduler instance
|
||||
"""
|
||||
return super().create(schedule_type, optimizer, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered scheduler type names."""
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
# ----------- Scheduler implementations -----------
|
||||
|
||||
|
||||
@SchedulerFactory.register("cosine")
|
||||
class CosineScheduler(BaseScheduler):
|
||||
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
warmup_steps: int,
|
||||
lr_decay_steps: int,
|
||||
min_rate: float = 0.05,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
self.warmup_steps = warmup_steps
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.min_rate = min_rate
|
||||
self.total_steps = warmup_steps + lr_decay_steps
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self) -> List[float]:
|
||||
# warmup
|
||||
if self.last_epoch < self.warmup_steps:
|
||||
warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
|
||||
return [base_lr * warmup_factor for base_lr in self.base_lrs]
|
||||
|
||||
# cosine decay
|
||||
decay_progress = (self.last_epoch - self.warmup_steps) / self.lr_decay_steps
|
||||
decay_progress = min(decay_progress, 1.0)
|
||||
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
|
||||
decay_factor = max(self.min_rate, cosine_decay)
|
||||
return [base_lr * decay_factor for base_lr in self.base_lrs]
|
||||
|
||||
def state_dict(self):
|
||||
state = super().state_dict()
|
||||
state.update(
|
||||
{
|
||||
"warmup_steps": self.warmup_steps,
|
||||
"lr_decay_steps": self.lr_decay_steps,
|
||||
"min_rate": self.min_rate,
|
||||
"total_steps": self.total_steps,
|
||||
}
|
||||
)
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||
self.lr_decay_steps = state_dict.pop("lr_decay_steps")
|
||||
self.min_rate = state_dict.pop("min_rate")
|
||||
self.total_steps = state_dict.pop("total_steps")
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
@SchedulerFactory.register("sgdr")
|
||||
class SGDRScheduler(BaseScheduler):
|
||||
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
warmup_steps: int,
|
||||
cycle_length: int,
|
||||
min_rate: float = 0.05,
|
||||
t_mult: int = 2,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
self.warmup_steps = warmup_steps
|
||||
self.cycle_length = cycle_length
|
||||
self.min_rate = min_rate
|
||||
self.t_mult = t_mult
|
||||
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
# warmup
|
||||
if self.last_epoch < self.warmup_steps:
|
||||
warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
|
||||
return [base_lr * warmup_factor for base_lr in self.base_lrs]
|
||||
|
||||
# SGDR
|
||||
steps_since_warmup = self.last_epoch - self.warmup_steps
|
||||
|
||||
# 1. Calculate current cycle and position within cycle
|
||||
current_cycle_length = self.cycle_length
|
||||
total_cycles_length = 0
|
||||
cycle_num = 0
|
||||
|
||||
while total_cycles_length + current_cycle_length <= steps_since_warmup:
|
||||
total_cycles_length += current_cycle_length
|
||||
current_cycle_length *= self.t_mult
|
||||
cycle_num += 1
|
||||
|
||||
steps_in_cycle = steps_since_warmup - total_cycles_length
|
||||
|
||||
# 2. Cosine annealing within the current cycle
|
||||
cosine_factor = 0.5 * (
|
||||
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
|
||||
)
|
||||
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
|
||||
|
||||
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the scheduler as a dict."""
|
||||
state = super().state_dict()
|
||||
state.update(
|
||||
{
|
||||
"warmup_steps": self.warmup_steps,
|
||||
"cycle_length": self.cycle_length,
|
||||
"min_rate": self.min_rate,
|
||||
"t_mult": self.t_mult,
|
||||
}
|
||||
)
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the scheduler's state."""
|
||||
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||
self.cycle_length = state_dict.pop("cycle_length")
|
||||
self.min_rate = state_dict.pop("min_rate")
|
||||
self.t_mult = state_dict.pop("t_mult")
|
||||
super().load_state_dict(state_dict)
|
||||
|
|
@ -0,0 +1,334 @@
|
|||
"""Training strategy implementations with factory pattern."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
||||
"""Create a frozen reference model from model_fn + full state dict."""
|
||||
ref_model = model_fn()
|
||||
ref_model.load_state_dict(state_dict)
|
||||
ref_model.requires_grad_(False)
|
||||
ref_model.eval()
|
||||
return ref_model
|
||||
|
||||
|
||||
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
|
||||
"""Move batch tensors to specified device with non-blocking transfer."""
|
||||
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||
|
||||
|
||||
def get_logprobs(
|
||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||
input_ids: Tensor,
|
||||
mask: Tensor,
|
||||
reduction: str,
|
||||
):
|
||||
"""Compute token-wise log probabilities from model outputs.
|
||||
|
||||
Args:
|
||||
model: The language model
|
||||
input_ids: Input token IDs of shape [batch_size, seq_len]
|
||||
mask: Attention mask of shape [batch_size, seq_len]
|
||||
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
|
||||
|
||||
Returns:
|
||||
Log probabilities with reduction applied over sequence dimension
|
||||
"""
|
||||
allowed_reductions = ["mean", "sum", "none"]
|
||||
if reduction not in allowed_reductions:
|
||||
raise ValueError(
|
||||
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
|
||||
)
|
||||
|
||||
shifted_input_ids = input_ids[:, 1:]
|
||||
shifted_mask = mask[:, 1:]
|
||||
|
||||
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
|
||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
|
||||
token_logprobs = torch.gather(
|
||||
log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
if reduction == "mean":
|
||||
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
|
||||
dim=-1
|
||||
).clamp(min=1.0)
|
||||
elif reduction == "sum":
|
||||
return (token_logprobs * shifted_mask).sum(dim=-1)
|
||||
else:
|
||||
return token_logprobs * shifted_mask
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""Abstract base class for training strategies."""
|
||||
|
||||
def __init__(
|
||||
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
|
||||
):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.executor = kwargs.pop("executor", None)
|
||||
self.model_fn = kwargs.pop("model_fn", None)
|
||||
self.extra_kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""Compute loss for the given batch.
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing batch tensors
|
||||
|
||||
Returns:
|
||||
Computed loss tensor
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""Allow calling strategy directly as a callable."""
|
||||
return self.compute_loss(batch)
|
||||
|
||||
|
||||
class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||
"""Factory class for creating training strategy instances.
|
||||
|
||||
Supports decorator-based registration for extensible strategy types.
|
||||
All default strategies (seq, sft, dpo, grpo) are automatically registered.
|
||||
|
||||
Example usage:
|
||||
@StrategyFactory.register("custom")
|
||||
class CustomStrategy(BaseStrategy):
|
||||
...
|
||||
|
||||
strategy = StrategyFactory.create("custom", model, device)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, strategy_cls: type):
|
||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||
if not issubclass(strategy_cls, BaseStrategy):
|
||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
||||
"""Create a strategy instance based on training type.
|
||||
|
||||
Args:
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
model: Model instance for the strategy
|
||||
device: Device to run the strategy on
|
||||
**kwargs: Additional arguments passed to strategy constructor
|
||||
|
||||
Returns:
|
||||
Strategy instance
|
||||
"""
|
||||
return super().create(train_type, model, device, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_strategies(cls) -> list:
|
||||
"""Return list of registered strategy names."""
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
# ============== Strategy Classes ==============
|
||||
# All strategies are registered at class definition time using the decorator
|
||||
|
||||
|
||||
@StrategyFactory.register("seq")
|
||||
class SEQStrategy(BaseStrategy):
|
||||
"""Standard next-token prediction training strategy.
|
||||
|
||||
Computes cross-entropy loss for next token prediction.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten(),
|
||||
label_smoothing=self.label_smoothing,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@StrategyFactory.register("sft")
|
||||
class SFTStrategy(BaseStrategy):
|
||||
"""Supervised Fine-tuning strategy with loss masking.
|
||||
|
||||
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
input_ids, target_ids, loss_mask = (
|
||||
batch["input_ids"],
|
||||
batch["target_ids"],
|
||||
batch["loss_mask"],
|
||||
)
|
||||
|
||||
ignore_index = -100
|
||||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten(),
|
||||
ignore_index=ignore_index,
|
||||
label_smoothing=self.label_smoothing,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@StrategyFactory.register("dpo")
|
||||
class DPOStrategy(BaseStrategy):
|
||||
"""Direct Preference Optimization strategy.
|
||||
|
||||
Implements the DPO loss from the paper "Direct Preference Optimization".
|
||||
Uses a reference model to compute KL divergence penalty.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
beta: float = 0.1,
|
||||
reduction: str = "mean",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(
|
||||
self.model_fn, self.executor.unwrap_model(model)
|
||||
).to(device=self.device)
|
||||
self.beta = beta
|
||||
self.reduction = reduction
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
|
||||
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||
|
||||
concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
||||
concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
||||
|
||||
log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
|
||||
|
||||
with torch.no_grad():
|
||||
log_ref = get_logprobs(
|
||||
self.ref_model, concat_ids, concat_mask, self.reduction
|
||||
)
|
||||
|
||||
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
|
||||
log_pi_rejected = log_pi[chosen_ids.shape[0] :]
|
||||
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
|
||||
log_ref_rejected = log_ref[chosen_ids.shape[0] :]
|
||||
|
||||
pi_log_ratio = log_pi_chosen - log_pi_rejected
|
||||
ref_log_ratio = log_ref_chosen - log_ref_rejected
|
||||
|
||||
ratio_diff = pi_log_ratio - ref_log_ratio
|
||||
dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
|
||||
|
||||
return dpo_loss
|
||||
|
||||
|
||||
@StrategyFactory.register("grpo")
|
||||
class GRPOStrategy(BaseStrategy):
|
||||
"""Group Relative Policy Optimization strategy.
|
||||
|
||||
On-policy GRPO following DeepSeek-R1: the policy model is updated while
|
||||
a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref),
|
||||
clipped PPO objective. Call ``sync_ref_model()`` after each data-generation round.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
clip_eps: float = 0.2,
|
||||
kl_coef: float = 0.01,
|
||||
group_size: int = 4,
|
||||
reduction: str = "mean",
|
||||
sync_interval: int = 200,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(
|
||||
self.model_fn, self.executor.unwrap_model(model)
|
||||
).to(device=self.device)
|
||||
self.clip_eps = clip_eps
|
||||
self.kl_coef = kl_coef
|
||||
self.group_size = group_size
|
||||
self.reduction = reduction
|
||||
self.sync_interval = sync_interval
|
||||
self._step = 0
|
||||
|
||||
def sync_ref_model(self):
|
||||
"""Copy current model weights to ref model."""
|
||||
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
self._step += 1
|
||||
if self._step % self.sync_interval == 0:
|
||||
self.sync_ref_model()
|
||||
|
||||
batch = move_to_device(batch, self.device)
|
||||
prompts = batch["prompts"]
|
||||
responses = batch["responses"]
|
||||
masks = batch["masks"]
|
||||
rewards = batch["rewards"]
|
||||
|
||||
batch_size, group_size, response_len = responses.shape
|
||||
responses_flat = responses.view(-1, response_len)
|
||||
masks_flat = masks.view(-1, response_len)
|
||||
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
|
||||
|
||||
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
||||
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
||||
|
||||
log_probs_policy = get_logprobs(
|
||||
self.model, full_sequences, full_masks, self.reduction
|
||||
)
|
||||
log_probs_policy = log_probs_policy.view(batch_size, group_size)
|
||||
|
||||
with torch.no_grad():
|
||||
log_probs_ref = get_logprobs(
|
||||
self.ref_model, full_sequences, full_masks, self.reduction
|
||||
)
|
||||
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
||||
|
||||
eps = torch.finfo(log_probs_policy.dtype).eps
|
||||
mean = rewards.mean(dim=-1, keepdim=True)
|
||||
std = rewards.std(dim=-1, keepdim=True)
|
||||
advantages = (rewards - mean) / (std + eps)
|
||||
|
||||
ratio = torch.exp(log_probs_policy - log_probs_ref)
|
||||
|
||||
surr1 = ratio * advantages
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
|
||||
policy_loss = -torch.min(surr1, surr2).mean()
|
||||
kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
|
||||
total_loss = policy_loss + kl_penalty
|
||||
|
||||
return total_loss
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
from tqdm import tqdm
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel import only_on_rank
|
||||
from astrai.parallel.setup import get_current_device, get_rank
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.metric_util import (
|
||||
ctx_get_grad_max,
|
||||
ctx_get_grad_mean,
|
||||
ctx_get_grad_min,
|
||||
ctx_get_grad_nan_num,
|
||||
ctx_get_grad_norm,
|
||||
ctx_get_grad_std,
|
||||
ctx_get_loss,
|
||||
ctx_get_lr,
|
||||
ctx_get_val_loss,
|
||||
)
|
||||
from astrai.trainer.train_context import TrainContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TrainCallback(Protocol):
|
||||
"""
|
||||
Callback interface for trainer.
|
||||
"""
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
"""Called at the beginning of training."""
|
||||
|
||||
def on_train_end(self, context: TrainContext):
|
||||
"""Called at the end of training."""
|
||||
|
||||
def on_epoch_begin(self, context: TrainContext):
|
||||
"""Called at the beginning of each epoch."""
|
||||
|
||||
def on_epoch_end(self, context: TrainContext):
|
||||
"""Called at the end of each epoch."""
|
||||
|
||||
def on_batch_begin(self, context: TrainContext):
|
||||
"""Called at the beginning of each batch."""
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
"""Called at the end of each batch."""
|
||||
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
"""Called on every optimizer step (sync step only)."""
|
||||
|
||||
def on_error(self, context: TrainContext):
|
||||
"""Called when an error occurs during training."""
|
||||
|
||||
|
||||
class CallbackFactory(BaseFactory[TrainCallback]):
|
||||
"""Factory for registering and creating training callbacks.
|
||||
|
||||
Example:
|
||||
@CallbackFactory.register("my_callback")
|
||||
class MyCallback(TrainCallback):
|
||||
...
|
||||
|
||||
callback = CallbackFactory.create("my_callback", **kwargs)
|
||||
"""
|
||||
|
||||
|
||||
@CallbackFactory.register("gradient_clipping")
|
||||
class GradientClippingCallback(TrainCallback):
|
||||
"""
|
||||
Gradient clipping callback for trainer.
|
||||
"""
|
||||
|
||||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
@CallbackFactory.register("gradient_checkpointing")
|
||||
class GradientCheckpointingCallback(TrainCallback):
|
||||
"""
|
||||
Activation checkpointing callback — trades compute for memory
|
||||
by recomputing specified module activations during the backward pass.
|
||||
|
||||
Args:
|
||||
modules: Module types to apply checkpointing to.
|
||||
"""
|
||||
|
||||
def __init__(self, modules: Optional[List[type]] = None):
|
||||
self.modules = tuple(modules) if modules else ()
|
||||
|
||||
def _enable(self, module: nn.Module):
|
||||
if self.modules and isinstance(module, self.modules):
|
||||
fn = module.forward
|
||||
module._original_forward = fn
|
||||
module.forward = lambda *a, **kw: torch_checkpoint(
|
||||
fn, *a, use_reentrant=False, **kw
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _disable(module: nn.Module):
|
||||
if hasattr(module, "_original_forward"):
|
||||
module.forward = module._original_forward
|
||||
del module._original_forward
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
context.model.apply(self._enable)
|
||||
logger.info("Gradient checkpointing enabled")
|
||||
|
||||
def on_train_end(self, context: TrainContext):
|
||||
context.model.apply(self._disable)
|
||||
|
||||
|
||||
@CallbackFactory.register("checkpoint")
|
||||
class CheckpointCallback(TrainCallback):
|
||||
"""
|
||||
Checkpoint callback for trainer.
|
||||
"""
|
||||
|
||||
extra_keys = ("optimizer", "scheduler")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
interval: int,
|
||||
weight_only: bool = False,
|
||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
def _save_checkpoint(self, context: TrainContext):
|
||||
state_dict = context.executor.unwrap_model(context.model)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
if get_rank() == 0:
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
extra = self.save_extra_fn(context)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
config=context.model_config,
|
||||
)
|
||||
context.checkpoint.save(save_path)
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||
self._save_checkpoint(context)
|
||||
|
||||
def on_train_end(self, context: TrainContext):
|
||||
if context.iteration != self.last_ckpt_iter:
|
||||
self._save_checkpoint(context)
|
||||
|
||||
def on_error(self, context: TrainContext):
|
||||
self._save_checkpoint(context)
|
||||
|
||||
@staticmethod
|
||||
def save_extra(context: TrainContext) -> dict:
|
||||
extra = {}
|
||||
for name in CheckpointCallback.extra_keys:
|
||||
obj = getattr(context, name, None)
|
||||
if obj:
|
||||
extra[name] = obj.state_dict()
|
||||
return extra
|
||||
|
||||
|
||||
@CallbackFactory.register("progress_bar")
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
"""
|
||||
Progress bar callback for trainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
||||
):
|
||||
self.num_epoch = num_epoch
|
||||
self.log_interval = log_interval
|
||||
self.file = file
|
||||
self.progress_bar: tqdm = None
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_epoch_begin(self, context: TrainContext):
|
||||
self.progress_bar = tqdm(
|
||||
context.dataloader,
|
||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||
dynamic_ncols=True,
|
||||
file=self.file or sys.stdout,
|
||||
)
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
postfix = {
|
||||
"loss": f"{context.loss:.4f}",
|
||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||
}
|
||||
if context.val_loss > 0:
|
||||
postfix["val_loss"] = f"{context.val_loss:.4f}"
|
||||
self.progress_bar.set_postfix(postfix)
|
||||
self.progress_bar.update(1)
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_epoch_end(self, context: TrainContext):
|
||||
_ = context
|
||||
if self.progress_bar:
|
||||
self.progress_bar.close()
|
||||
|
||||
|
||||
@CallbackFactory.register("metric_logger")
|
||||
class MetricLoggerCallback(TrainCallback):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir: str,
|
||||
save_interval: int,
|
||||
log_interval: int = 10,
|
||||
metrics: List[str] = None,
|
||||
):
|
||||
self.last_log_iter = 0
|
||||
self.save_interval = save_interval
|
||||
self.log_interval = log_interval
|
||||
self.metrics = metrics or ["loss", "lr"]
|
||||
|
||||
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.log_cache = []
|
||||
|
||||
self._metric_funcs = {
|
||||
"loss": ctx_get_loss,
|
||||
"lr": ctx_get_lr,
|
||||
"val_loss": ctx_get_val_loss,
|
||||
"grad_norm": ctx_get_grad_norm,
|
||||
"grad_std": ctx_get_grad_std,
|
||||
"grad_max": ctx_get_grad_max,
|
||||
"grad_min": ctx_get_grad_min,
|
||||
"grad_mean": ctx_get_grad_mean,
|
||||
"grad_nan_num": ctx_get_grad_nan_num,
|
||||
}
|
||||
|
||||
def _get_log_data(self, context: TrainContext):
|
||||
return {
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
"epoch": context.epoch,
|
||||
"iter": context.iteration,
|
||||
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
||||
}
|
||||
|
||||
@only_on_rank(0)
|
||||
def _add_log(self, log_data):
|
||||
self.log_cache.append(log_data)
|
||||
|
||||
@only_on_rank(0)
|
||||
def _save_log(self, epoch, iter):
|
||||
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
|
||||
|
||||
with open(log_file, "w") as f:
|
||||
for log in self.log_cache:
|
||||
f.write(json.dumps(log) + "\n")
|
||||
|
||||
def on_batch_end(self, context):
|
||||
if context.iteration % self.log_interval == 0:
|
||||
log_data = self._get_log_data(context)
|
||||
self._add_log(log_data)
|
||||
|
||||
if context.iteration - self.last_log_iter >= self.save_interval:
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
self.last_log_iter = context.iteration
|
||||
|
||||
def on_train_end(self, context):
|
||||
if context.iteration != self.last_log_iter:
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
|
||||
def on_error(self, context):
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
|
||||
|
||||
@CallbackFactory.register("validation")
|
||||
class ValidationCallback(TrainCallback):
|
||||
def _run_validation(self, context: TrainContext):
|
||||
context.model.eval()
|
||||
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in context.val_dataloader:
|
||||
loss = context.strategy(batch)
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
|
||||
if context.world_size > 1 and dist.is_initialized():
|
||||
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
|
||||
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
|
||||
avg_loss = loss_tensor.item()
|
||||
|
||||
context.val_loss = avg_loss
|
||||
context.model.train()
|
||||
|
||||
step_count = context.iteration // context.config.grad_accum_steps
|
||||
logger.info(
|
||||
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||
)
|
||||
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
if context.val_dataloader is None:
|
||||
return
|
||||
cfg = context.config
|
||||
if cfg.val_step <= 0:
|
||||
return
|
||||
step_count = context.iteration // cfg.grad_accum_steps
|
||||
if step_count % cfg.val_step == 0:
|
||||
self._run_validation(context)
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, Self
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.dataset import ResumableDistributedSampler
|
||||
from astrai.model.components.lora import inject_lora
|
||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainContext:
|
||||
model: nn.Module = field(default=None)
|
||||
strategy: BaseStrategy = field(default=None)
|
||||
dataloader: DataLoader = field(default=None)
|
||||
optimizer: OptimizerProtocol = field(default=None)
|
||||
scheduler: SchedulerProtocol = field(default=None)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
config: TrainConfig = field(default=None)
|
||||
model_config: dict = field(default_factory=dict)
|
||||
executor: BaseExecutor = field(default=None)
|
||||
|
||||
epoch: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
loss: float = field(default=0.0)
|
||||
val_dataloader: DataLoader = field(default=None)
|
||||
val_loss: float = field(default=0.0)
|
||||
|
||||
world_size: int = field(default=1)
|
||||
rank: int = field(default=0)
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class TrainContextBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
config: TrainConfig,
|
||||
):
|
||||
self.config = config
|
||||
self._resume_dir: Optional[str] = None
|
||||
|
||||
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
|
||||
self._resume_dir = resume_dir
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
cfg = self.config
|
||||
device = get_current_device()
|
||||
|
||||
executor = ExecutorFactory.create(
|
||||
cfg.parallel_mode,
|
||||
grad_accum_steps=cfg.grad_accum_steps,
|
||||
**cfg.executor_kwargs,
|
||||
)
|
||||
|
||||
model = cfg.model_fn()
|
||||
model = model.to(device=device)
|
||||
|
||||
model_config = {}
|
||||
if self._resume_dir:
|
||||
config_path = Path(self._resume_dir) / "config.json"
|
||||
if config_path.exists():
|
||||
model_config = load_json(config_path)
|
||||
|
||||
if not model_config and hasattr(model, "config"):
|
||||
model_config = model.config.to_dict()
|
||||
|
||||
context = TrainContext(
|
||||
model=model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
config=cfg,
|
||||
model_config=model_config,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
if self._resume_dir is not None:
|
||||
resume_path = Path(self._resume_dir)
|
||||
if (resume_path / "meta.json").exists():
|
||||
checkpoint = Checkpoint.load(self._resume_dir)
|
||||
state_dict = checkpoint.state_dict
|
||||
if checkpoint.config:
|
||||
context.model_config = checkpoint.config
|
||||
else:
|
||||
checkpoint = None
|
||||
state_dict = load_model_weights(self._resume_dir)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if checkpoint is not None:
|
||||
context.epoch = cfg.start_epoch
|
||||
context.iteration = cfg.start_batch
|
||||
context.checkpoint = checkpoint
|
||||
|
||||
if cfg.lora is not None:
|
||||
inject_lora(
|
||||
model,
|
||||
r=cfg.lora.r,
|
||||
alpha=cfg.lora.alpha,
|
||||
target_modules=set(cfg.lora.target_modules),
|
||||
)
|
||||
|
||||
context.optimizer = cfg.optimizer_fn(model)
|
||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||
|
||||
sampler_offset = context.iteration * cfg.batch_per_device
|
||||
sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.dataset,
|
||||
start_epoch=context.epoch,
|
||||
start_iter=sampler_offset,
|
||||
seed=cfg.random_seed,
|
||||
)
|
||||
context.dataloader = DataLoader(
|
||||
cfg.dataset,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=cfg.pin_memory,
|
||||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
if cfg.val_dataset is not None:
|
||||
val_sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.val_dataset,
|
||||
start_epoch=0,
|
||||
start_iter=0,
|
||||
seed=cfg.random_seed,
|
||||
shuffle=False,
|
||||
)
|
||||
context.val_dataloader = DataLoader(
|
||||
cfg.val_dataset,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=val_sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=cfg.pin_memory,
|
||||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||
executor.prepare(
|
||||
model,
|
||||
context.optimizer,
|
||||
context.dataloader,
|
||||
context.scheduler,
|
||||
)
|
||||
)
|
||||
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
extra = context.checkpoint.extra
|
||||
for name in ("optimizer", "scheduler"):
|
||||
if name in extra:
|
||||
obj = getattr(context, name, None)
|
||||
if obj is not None:
|
||||
obj.load_state_dict(extra[name])
|
||||
|
||||
context.strategy = StrategyFactory.create(
|
||||
model=context.model,
|
||||
train_type=cfg.strategy,
|
||||
device=device,
|
||||
executor=executor,
|
||||
model_fn=cfg.model_fn,
|
||||
**cfg.extra_kwargs,
|
||||
)
|
||||
|
||||
return context
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from astrai.config import TrainConfig
|
||||
from astrai.parallel.setup import spawn_parallel_fn
|
||||
from astrai.trainer.train_callback import (
|
||||
CallbackFactory,
|
||||
TrainCallback,
|
||||
)
|
||||
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
|
||||
):
|
||||
self.train_config = train_config
|
||||
default_callbacks = self._get_default_callbacks()
|
||||
self.callbacks = (
|
||||
default_callbacks + callbacks if callbacks else default_callbacks
|
||||
)
|
||||
|
||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||
cfg = self.train_config
|
||||
callbacks = [
|
||||
CallbackFactory.create(
|
||||
"gradient_checkpointing",
|
||||
modules=cfg.gradient_checkpointing_modules,
|
||||
),
|
||||
CallbackFactory.create(
|
||||
"checkpoint",
|
||||
cfg.ckpt_dir,
|
||||
cfg.ckpt_interval,
|
||||
),
|
||||
CallbackFactory.create(
|
||||
"metric_logger",
|
||||
log_dir=cfg.log_dir,
|
||||
save_interval=cfg.ckpt_interval,
|
||||
log_interval=cfg.log_interval,
|
||||
metrics=cfg.metrics,
|
||||
),
|
||||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
CallbackFactory.create("validation"),
|
||||
]
|
||||
return callbacks
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
for callback in self.callbacks:
|
||||
method = getattr(callback, method_name, None)
|
||||
if method:
|
||||
method(context)
|
||||
|
||||
def _trainer_loop(self, resume_dir: Optional[str] = None):
|
||||
context = (
|
||||
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
|
||||
)
|
||||
executor = context.executor
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
||||
try:
|
||||
context.model.train()
|
||||
|
||||
for epoch in range(context.epoch, context.config.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
|
||||
with executor.accumulate(context.model):
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / executor.grad_accum_steps
|
||||
executor.backward(stand_loss)
|
||||
context.iteration += 1
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
if executor.sync_gradients:
|
||||
self._call_callbacks("on_optimizer_step", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Training failed: %s", str(e), exc_info=True)
|
||||
self._call_callbacks("on_error", context)
|
||||
raise
|
||||
finally:
|
||||
self._call_callbacks("on_train_end", context)
|
||||
|
||||
def train(self, resume_dir: Optional[str] = None):
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._trainer_loop,
|
||||
backend=cfg.backend,
|
||||
world_size=cfg.nprocs,
|
||||
master_addr=cfg.master_addr,
|
||||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
start_method=cfg.start_method,
|
||||
resume_dir=resume_dir,
|
||||
)
|
||||
215
benchmark.py
215
benchmark.py
|
|
@ -1,215 +0,0 @@
|
|||
import torch
|
||||
from typing import Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from khaosz.core.transformer import TransformerConfig, Transformer
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
total_tokens: int
|
||||
total_time: float
|
||||
tokens_per_second: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class GenerationBenchmark:
|
||||
def __init__(
|
||||
self,
|
||||
config: TransformerConfig,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16
|
||||
):
|
||||
self.config = config
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||
self.model.eval()
|
||||
|
||||
def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list:
|
||||
"""初始化KV缓存"""
|
||||
kv_cache = []
|
||||
head_dim = self.config.n_dim // self.config.n_head
|
||||
for _ in range(self.config.n_layer):
|
||||
k_cache = torch.zeros(
|
||||
(batch_size, max_len, self.config.n_kvhead, head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(batch_size, max_len, self.config.n_kvhead, head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
kv_cache.append((k_cache, v_cache))
|
||||
return kv_cache
|
||||
|
||||
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||
prompt_ids = torch.randint(
|
||||
low=0,
|
||||
high=self.config.vocab_size,
|
||||
size=(batch_size, prompt_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
gen_ids = torch.randint(
|
||||
low=0,
|
||||
high=self.config.vocab_size,
|
||||
size=(batch_size, total_length - prompt_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
return prompt_ids, gen_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_prefill_benchmark(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
prompt_length: int = 512,
|
||||
num_trials: int = 10,
|
||||
) -> BenchmarkResult:
|
||||
|
||||
for _ in range(3):
|
||||
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
||||
_ = self.model(prompt_ids)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
total_time = 0.0
|
||||
total_tokens = batch_size * prompt_length * num_trials
|
||||
|
||||
for trial in range(num_trials):
|
||||
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
_ = self.model(prompt_ids)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||
total_time += trial_time
|
||||
|
||||
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||
f"({prompt_length / trial_time:.1f} tokens/s)")
|
||||
|
||||
return BenchmarkResult(
|
||||
total_tokens=total_tokens,
|
||||
total_time=total_time,
|
||||
tokens_per_second=total_tokens / total_time,
|
||||
metadata={
|
||||
"benchmark_type": "prefill",
|
||||
"batch_size": batch_size,
|
||||
"prompt_length": prompt_length,
|
||||
"dtype": self.dtype,
|
||||
"device": self.device,
|
||||
}
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_decoding_benchmark(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
prompt_length: int = 512,
|
||||
gen_length: int = 128,
|
||||
num_trials: int = 5,
|
||||
) -> BenchmarkResult:
|
||||
|
||||
total_time = 0.0
|
||||
total_tokens = batch_size * gen_length * num_trials
|
||||
|
||||
for trial in range(num_trials):
|
||||
|
||||
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
|
||||
kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len)
|
||||
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
|
||||
current_pos = prompt_length
|
||||
for i in range(gen_length):
|
||||
input_token = gen_ids[:, i:i+1]
|
||||
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos)
|
||||
current_pos += 1
|
||||
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||
total_time += trial_time
|
||||
|
||||
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||
f"({gen_length / trial_time:.1f} tokens/s)")
|
||||
|
||||
|
||||
return BenchmarkResult(
|
||||
total_tokens=total_tokens,
|
||||
total_time=total_time,
|
||||
tokens_per_second=total_tokens / total_time,
|
||||
metadata={
|
||||
"benchmark_type": "generation",
|
||||
"batch_size": batch_size,
|
||||
"prompt_length": prompt_length,
|
||||
"gen_length": gen_length,
|
||||
"dtype": self.dtype,
|
||||
"device": self.device,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def print_benchmark_result(result: BenchmarkResult):
|
||||
"""打印基准测试结果"""
|
||||
benchmark_type = result.metadata["benchmark_type"]
|
||||
|
||||
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
|
||||
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
||||
print(f"Time Consumed: {result.total_time:.3f}s")
|
||||
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
||||
|
||||
if benchmark_type == "prefill":
|
||||
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
|
||||
elif benchmark_type == "generation":
|
||||
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
|
||||
|
||||
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = TransformerConfig(
|
||||
vocab_size=10000,
|
||||
n_dim=1536,
|
||||
n_head=24,
|
||||
n_kvhead=4,
|
||||
d_ffn=6912,
|
||||
m_len=2048,
|
||||
n_layer=24,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
|
||||
benchmark = GenerationBenchmark(config)
|
||||
|
||||
print("=" * 80)
|
||||
print("Running Transformer Generation Benchmark")
|
||||
print("=" * 80)
|
||||
|
||||
prefill_result = benchmark.run_prefill_benchmark(
|
||||
batch_size=4,
|
||||
prompt_length=512,
|
||||
num_trials=5
|
||||
)
|
||||
print_benchmark_result(prefill_result)
|
||||
|
||||
gen_result = benchmark.run_decoding_benchmark(
|
||||
batch_size=4,
|
||||
prompt_length=512,
|
||||
gen_length=128,
|
||||
num_trials=5
|
||||
)
|
||||
print_benchmark_result(gen_result)
|
||||
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
services:
|
||||
server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
user: "${UID:-1000}:${GID:-1000}"
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- ./params:/app/params:ro
|
||||
command: python -m scripts.tools.server --port 8000 --device cuda
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
restart: unless-stopped
|
||||
|
||||
server-cpu:
|
||||
profiles: [cpu]
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
user: "${UID:-1000}:${GID:-1000}"
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- ./params:/app/params:ro
|
||||
command: python -m scripts.tools.server --port 8000 --device cpu
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 120s
|
||||
restart: unless-stopped
|
||||
101
generate.py
101
generate.py
|
|
@ -1,101 +0,0 @@
|
|||
import os
|
||||
import torch
|
||||
import json
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
from khaosz import Khaosz
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def batch_generate(
|
||||
model: Khaosz,
|
||||
queries: List[str],
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
batch_size: int,
|
||||
) -> List:
|
||||
assert batch_size > 0
|
||||
sorted_queries = sorted(queries, key=lambda x: len(x), reverse=True)
|
||||
original_indices = {query: idx for idx, query in enumerate(queries)}
|
||||
|
||||
responses = [None] * len(queries)
|
||||
total_batches = (len(sorted_queries) + batch_size - 1) // batch_size
|
||||
|
||||
for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
|
||||
batch_queries = sorted_queries[i: min(i + batch_size, len(queries))]
|
||||
if not isinstance(batch_queries, list):
|
||||
batch_queries = [batch_queries]
|
||||
|
||||
batch_responses = model.batch_generate(
|
||||
queries=batch_queries,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p
|
||||
)
|
||||
|
||||
for batch_query, batch_response in zip(batch_queries, batch_responses):
|
||||
print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})")
|
||||
|
||||
for query, response in zip(batch_queries, batch_responses):
|
||||
original_idx = original_indices[query]
|
||||
responses[original_idx] = response
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
def processor(
|
||||
model: Khaosz,
|
||||
input_json_file: str,
|
||||
output_json_file: str,
|
||||
batch_size: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
question_key: str="question",
|
||||
):
|
||||
with open(input_json_file, "r", encoding='utf-8') as f:
|
||||
input_dict = [json.loads(line) for line in f]
|
||||
queries = [item[question_key] for item in input_dict]
|
||||
|
||||
output_dict = batch_generate(
|
||||
model=model,
|
||||
queries=queries,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
with open(output_json_file, "w", encoding='utf-8') as f:
|
||||
json.dump(output_dict, f, indent=4, ensure_ascii=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||
|
||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
||||
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
|
||||
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
|
||||
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
|
||||
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
|
||||
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
|
||||
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
||||
|
||||
args = parser.parse_args()
|
||||
model = Khaosz(args.model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
processor(
|
||||
model,
|
||||
input_json_file=args.input_json_file,
|
||||
output_json_file=args.output_json_file,
|
||||
question_key=args.question_key,
|
||||
batch_size=args.batch_size,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p
|
||||
)
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
__version__ = "1.2.2"
|
||||
__author__ = "ViperEkura"
|
||||
|
||||
from khaosz.model import Khaosz
|
||||
from khaosz.core.transformer import Transformer, TransformerConfig
|
||||
from khaosz.utils.retriever import Retriever
|
||||
from khaosz.utils.splitter import (
|
||||
SemanticTextSplitter,
|
||||
PriorityTextSplitter
|
||||
)
|
||||
from khaosz.core.tokenizer import BpeTokenizer
|
||||
from khaosz.core.parameter import ParameterLoader
|
||||
from khaosz.core.generator import (
|
||||
TextGenerator,
|
||||
ChatGenerator,
|
||||
StreamGenerator,
|
||||
BatchGenerator,
|
||||
RetrievalGenerator,
|
||||
EmbeddingEncoder
|
||||
)
|
||||
from khaosz.trainer import (
|
||||
Trainer,
|
||||
DatasetLoader,
|
||||
TrainConfig,
|
||||
StrategyFactory,
|
||||
SchedulerFactory
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# model
|
||||
"Khaosz",
|
||||
|
||||
# module
|
||||
"Transformer",
|
||||
"TransformerConfig",
|
||||
"BpeTokenizer",
|
||||
"ParameterLoader",
|
||||
"TextGenerator",
|
||||
"ChatGenerator",
|
||||
"StreamGenerator",
|
||||
"BatchGenerator",
|
||||
"RetrievalGenerator",
|
||||
"EmbeddingEncoder",
|
||||
|
||||
# trainer
|
||||
"Trainer",
|
||||
"DatasetLoader",
|
||||
"TrainConfig",
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory",
|
||||
|
||||
# utils
|
||||
"Retriever",
|
||||
"SemanticTextSplitter",
|
||||
"PriorityTextSplitter",
|
||||
]
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
from khaosz.core.tokenizer import BpeTokenizer
|
||||
from khaosz.core.transformer import Transformer, TransformerConfig
|
||||
from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
|
||||
from khaosz.core.generator import (
|
||||
TextGenerator,
|
||||
ChatGenerator,
|
||||
StreamGenerator,
|
||||
BatchGenerator,
|
||||
RetrievalGenerator,
|
||||
EmbeddingEncoder
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Transformer",
|
||||
"TransformerConfig",
|
||||
"BpeTokenizer",
|
||||
"ParameterLoader",
|
||||
"ModelParameter",
|
||||
"Checkpoint",
|
||||
"TextGenerator",
|
||||
"ChatGenerator",
|
||||
"StreamGenerator",
|
||||
"BatchGenerator",
|
||||
"RetrievalGenerator",
|
||||
"EmbeddingEncoder"
|
||||
]
|
||||
|
|
@ -1,568 +0,0 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple, Union, Optional, Generator, Self
|
||||
from khaosz.core.parameter import ModelParameter
|
||||
|
||||
|
||||
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
|
||||
"""
|
||||
Build prompt for query and history
|
||||
|
||||
Args:
|
||||
query(str): query string
|
||||
history(Optional[List[Tuple[str, str]]]): history list of query and response
|
||||
|
||||
Returns:
|
||||
str: prompt string
|
||||
|
||||
"""
|
||||
prompt_parts = []
|
||||
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
for his_query, his_response in history:
|
||||
prompt_parts.append(f"<|user|> {his_query} <|system|> <bos>{his_response}<eos>")
|
||||
|
||||
if query is not None:
|
||||
prompt_parts.append(f"<|user|> {query} <|system|> <bos>")
|
||||
|
||||
return "\n".join(prompt_parts)
|
||||
|
||||
def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
|
||||
"""
|
||||
Pad a list of sequences to a fixed length.
|
||||
|
||||
Args:
|
||||
ids_list (List[List[int]]): A list of sequences.
|
||||
max_ids_len (int): The maximum length of sequences.
|
||||
pad_id (int): The id to pad sequences.
|
||||
|
||||
Returns:
|
||||
List[List[int]]: A list of padded sequences.
|
||||
|
||||
"""
|
||||
new_ids_list = []
|
||||
for ids in ids_list:
|
||||
pad_len = max_ids_len - len(ids)
|
||||
padded_seq = [pad_id] * pad_len + ids
|
||||
new_ids_list.append(padded_seq)
|
||||
|
||||
return new_ids_list
|
||||
|
||||
def apply_sampling_strategies(
|
||||
logits: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
filter_value: float = -float("inf")
|
||||
) -> Tensor:
|
||||
"""
|
||||
Apply sampling strategies to the logits tensor.
|
||||
|
||||
Args:
|
||||
logits (Tensor): The logits tensor.
|
||||
temperature (float): The temperature parameter.
|
||||
top_k (int): The top-k parameter.
|
||||
top_p (float): The top-p parameter.
|
||||
filter_value (float, optional): The filter value. Defaults to -float("inf").
|
||||
|
||||
Returns:
|
||||
Tensor: The sampled logits tensor.
|
||||
|
||||
"""
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.size(-1))
|
||||
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(
|
||||
dim=1,
|
||||
index=sorted_indices,
|
||||
src=sorted_indices_to_remove
|
||||
)
|
||||
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class KVCacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
batch_size: int,
|
||||
max_len: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.batch_size = batch_size
|
||||
self.max_len = max_len
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self._kv_cache: List[Tuple[Tensor, Tensor]] = None
|
||||
self._seq_mask: Tensor = None
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
self._kv_cache = []
|
||||
for _ in range(self.num_layers):
|
||||
k_cache = torch.zeros(
|
||||
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
self._kv_cache.append((k_cache, v_cache))
|
||||
|
||||
self._seq_mask = torch.ones(
|
||||
(self.batch_size, self.max_len),
|
||||
device=self.device, dtype=torch.bool
|
||||
)
|
||||
|
||||
def update(self, active_mask: Tensor):
|
||||
for i in range(self.num_layers):
|
||||
k_cache, v_cache = self._kv_cache[i]
|
||||
new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
|
||||
self._kv_cache[i] = (new_k_cache, new_v_cache)
|
||||
|
||||
self._seq_mask = self._seq_mask[active_mask]
|
||||
|
||||
def reset(self, full_reset=False):
|
||||
if full_reset:
|
||||
self._kv_cache = None
|
||||
self._seq_mask = None
|
||||
else:
|
||||
self._initialize()
|
||||
|
||||
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
bool_mask = (input_ids != pad_id)
|
||||
self._seq_mask[: batch_size, : seq_len] = bool_mask
|
||||
|
||||
def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
|
||||
return self._kv_cache
|
||||
|
||||
def get_seq_mask(self) -> Tensor:
|
||||
return self._seq_mask
|
||||
|
||||
|
||||
class GeneratorCore:
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
self.model = parameter.model
|
||||
self.tokenizer = parameter.tokenizer
|
||||
self.config = parameter.config
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||
start_pos: int = 0
|
||||
) -> Tuple[Tensor, int]:
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
|
||||
logits = outputs["logits"][:, -1, :]
|
||||
cache_increase = input_ids.size(-1)
|
||||
|
||||
return logits, cache_increase
|
||||
|
||||
def to(self, *args, **kargs) -> Self:
|
||||
self.model.to(*args, **kargs)
|
||||
return self
|
||||
|
||||
|
||||
class EmbeddingEncoderCore:
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
self.model = parameter.model
|
||||
self.tokenizer = parameter.tokenizer
|
||||
self.config = parameter.config
|
||||
|
||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||
with_batch = isinstance(sentence, list)
|
||||
ids = self.tokenizer.encode(sentence)
|
||||
batch_ids = ids if with_batch else [ids]
|
||||
max_model_len = self.config.m_len
|
||||
|
||||
all_fragments = []
|
||||
fragment_origin_idx = []
|
||||
|
||||
for i, seq in enumerate(batch_ids):
|
||||
if len(seq) > max_model_len:
|
||||
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
|
||||
all_fragments.extend(fragments)
|
||||
fragment_origin_idx.extend([i] * len(fragments))
|
||||
else:
|
||||
all_fragments.append(seq)
|
||||
fragment_origin_idx.append(i)
|
||||
|
||||
#if empty fragments
|
||||
if not all_fragments or not ids:
|
||||
return [] if with_batch else torch.tensor([])
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
|
||||
|
||||
padded_ids = []
|
||||
masks = []
|
||||
for seq in all_fragments:
|
||||
pad_len = max_len - len(seq)
|
||||
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
|
||||
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
|
||||
padded_ids.append(padded_seq)
|
||||
masks.append(mask)
|
||||
|
||||
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
|
||||
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
|
||||
# [num_fragments, seq_len, hidden_size]
|
||||
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
|
||||
|
||||
sentence_embs: List[Tensor] = []
|
||||
for i in range(len(batch_ids)):
|
||||
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
|
||||
if indices is not None:
|
||||
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
|
||||
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
|
||||
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
||||
sentence_embs.append(emb.flatten())
|
||||
|
||||
if with_batch:
|
||||
return [emb.flatten() for emb in sentence_embs]
|
||||
else:
|
||||
return sentence_embs[0].flatten()
|
||||
|
||||
def to(self, *args, **kargs) -> Self:
|
||||
self.model.to(*args, **kargs)
|
||||
return self
|
||||
|
||||
|
||||
class TextGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
query: str,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
) -> str:
|
||||
assert temperature >= 0.0
|
||||
assert top_k >= 0
|
||||
assert top_p >= 0.0 and top_p <= 1.0
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(
|
||||
num_layers=self.config.n_layer,
|
||||
batch_size=1,
|
||||
max_len=self.config.m_len,
|
||||
num_heads=self.config.n_kvhead,
|
||||
head_dim=self.config.n_dim // self.config.n_head,
|
||||
device=device,
|
||||
)
|
||||
|
||||
ids = self.tokenizer.encode(query)
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
|
||||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
self.model.eval()
|
||||
|
||||
while len(ids) < self.config.m_len:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
logits, cache_increase = self.compute_logits(
|
||||
input_ids,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
cur_cache_pos += cache_increase
|
||||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
break
|
||||
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
||||
class ChatGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
) -> str:
|
||||
|
||||
assert temperature >= 0.0
|
||||
assert top_k >= 0
|
||||
assert top_p >= 0.0 and top_p <= 1.0
|
||||
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(
|
||||
num_layers=self.config.n_layer,
|
||||
batch_size=1,
|
||||
max_len=self.config.m_len,
|
||||
num_heads=self.config.n_kvhead,
|
||||
head_dim=self.config.n_dim // self.config.n_head,
|
||||
device=device,
|
||||
)
|
||||
ids = self.tokenizer.encode(build_prompt(query, history))
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
cpy_history = history.copy()
|
||||
|
||||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
self.model.eval()
|
||||
|
||||
|
||||
while len(ids) < self.config.m_len:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
logits, cache_increase = self.compute_logits(
|
||||
input_ids,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
cur_cache_pos += cache_increase
|
||||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
break
|
||||
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
cpy_history.append((query, response))
|
||||
|
||||
return response, cpy_history
|
||||
|
||||
|
||||
class StreamGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
|
||||
|
||||
assert temperature >= 0.0
|
||||
assert top_k >= 0
|
||||
assert top_p >= 0.0 and top_p <= 1.0
|
||||
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(
|
||||
num_layers=self.config.n_layer,
|
||||
batch_size=1,
|
||||
max_len=self.config.m_len,
|
||||
num_heads=self.config.n_kvhead,
|
||||
head_dim=self.config.n_dim // self.config.n_head,
|
||||
device=device,
|
||||
)
|
||||
ids = self.tokenizer.encode(build_prompt(query, history))
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
cpy_history = history.copy()
|
||||
|
||||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
self.model.eval()
|
||||
|
||||
|
||||
while len(ids) < self.config.m_len:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
logits, cache_increase = self.compute_logits(
|
||||
input_ids,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
cur_cache_pos += cache_increase
|
||||
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
yield response, cpy_history + [(query, response)]
|
||||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
yield response + "\n", cpy_history + [(query, response)]
|
||||
break
|
||||
|
||||
|
||||
class BatchGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
queries: List[str],
|
||||
histories: List[List[Tuple[str, str]]],
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float
|
||||
) -> List[str]:
|
||||
|
||||
assert temperature >= 0.0
|
||||
assert top_k >= 0
|
||||
assert top_p >= 0.0 and top_p <= 1.0
|
||||
|
||||
batch_size = len(queries)
|
||||
if histories is None:
|
||||
histories = [[] for _ in range(batch_size)]
|
||||
|
||||
prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
|
||||
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
||||
max_ids_len = max(len(ids) for ids in ids_list)
|
||||
ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(
|
||||
num_layers=self.config.n_layer,
|
||||
batch_size=batch_size,
|
||||
max_len=self.config.m_len,
|
||||
num_heads=self.config.n_kvhead,
|
||||
head_dim=self.config.n_dim // self.config.n_head,
|
||||
device=device,
|
||||
)
|
||||
|
||||
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
||||
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
||||
activate_task_mask = [True] * batch_size
|
||||
|
||||
start_cache_pos = max_ids_len
|
||||
cur_cache_pos = 0
|
||||
|
||||
while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
attn_mask =cache_manager.get_seq_mask()
|
||||
|
||||
logits, cache_increase = self.compute_logits(
|
||||
input_tensor,
|
||||
attn_mask=attn_mask,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
|
||||
cur_cache_pos += cache_increase
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
active_mask = []
|
||||
c_ids = 0
|
||||
|
||||
for i in range(batch_size):
|
||||
if activate_task_mask[i]:
|
||||
token = next_token_id[c_ids, :].item()
|
||||
ids_list[i].append(token)
|
||||
c_ids += 1
|
||||
|
||||
is_active = not token in self.tokenizer.stop_ids
|
||||
activate_task_mask[i] = is_active
|
||||
active_mask.append(is_active)
|
||||
|
||||
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
|
||||
cache_manager.update(active_mask)
|
||||
input_tensor = next_token_id[active_mask, :]
|
||||
|
||||
max_ids_len += 1
|
||||
|
||||
|
||||
responses = [str()] * batch_size
|
||||
for i in range(batch_size):
|
||||
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
|
||||
histories[i].append((queries[i], responses[i]))
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
|
||||
class RetrievalGenerator(GeneratorCore):
|
||||
def __init__(self, retriever_parameter: ModelParameter):
|
||||
super().__init__(retriever_parameter)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
retrieved: List[str],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
) -> str:
|
||||
assert temperature >= 0.0
|
||||
assert top_k >= 0
|
||||
assert top_p >= 0.0 and top_p <= 1.0
|
||||
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
|
||||
retrieved_query = f"{retrieved}<eos>\n\n根据以上内容回答: {query}" if retrieved else query
|
||||
parameter = ModelParameter(self.model, self.tokenizer, self.config)
|
||||
|
||||
return ChatGenerator(parameter).generate(
|
||||
retrieved_query,
|
||||
history,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
class EmbeddingEncoder(EmbeddingEncoderCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||
return super().encode(sentence)
|
||||
|
||||
|
|
@ -1,237 +0,0 @@
|
|||
import pickle as pkl
|
||||
import matplotlib.pyplot as plt
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Self, Union
|
||||
from pathlib import Path
|
||||
|
||||
from khaosz.core.tokenizer import BpeTokenizer
|
||||
from khaosz.core.transformer import TransformerConfig, Transformer
|
||||
|
||||
|
||||
class BaseModelIO:
|
||||
"""Base class for model I/O operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[nn.Module] = None,
|
||||
tokenizer: Optional[BpeTokenizer] = None,
|
||||
config: Optional[TransformerConfig] = None
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer or BpeTokenizer()
|
||||
self.config = config or TransformerConfig()
|
||||
|
||||
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||
"""Get standardized file paths for model components."""
|
||||
dir_path = Path(directory)
|
||||
return {
|
||||
"model": dir_path / "model.safetensors",
|
||||
"config": dir_path / "config.json",
|
||||
"tokenizer": dir_path / "tokenizer.json"
|
||||
}
|
||||
|
||||
def save_components(self, save_dir: Union[str, Path]):
|
||||
"""Save core model components."""
|
||||
paths = self._get_file_paths(save_dir)
|
||||
paths["model"].parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.model is not None:
|
||||
st.save_file(self.model.state_dict(), str(paths["model"]))
|
||||
self.config.save(str(paths["config"]))
|
||||
self.tokenizer.save(str(paths["tokenizer"]))
|
||||
|
||||
def load_components(self, load_dir: Union[str, Path]) -> Self:
|
||||
"""Load core model components."""
|
||||
paths = self._get_file_paths(load_dir)
|
||||
|
||||
self.config.load(str(paths["config"]))
|
||||
self.tokenizer.load(str(paths["tokenizer"]))
|
||||
|
||||
if paths["model"].exists():
|
||||
state_dict = st.load_file(str(paths["model"]))
|
||||
if self.model is None:
|
||||
self.model = Transformer(self.config)
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
return self
|
||||
|
||||
def to(self, *args, **kwargs) -> Self:
|
||||
"""Move model to device."""
|
||||
if self.model is not None:
|
||||
self.model.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelParameter(BaseModelIO):
|
||||
"""Container for model parameters with serialization capabilities."""
|
||||
|
||||
model: Optional[nn.Module] = field(
|
||||
default=None,
|
||||
metadata={"help": "Transformer model."}
|
||||
)
|
||||
tokenizer: BpeTokenizer = field(
|
||||
default_factory=BpeTokenizer,
|
||||
metadata={"help": "Tokenizer for the model."}
|
||||
)
|
||||
config: TransformerConfig = field(
|
||||
default_factory=TransformerConfig,
|
||||
metadata={"help": "Transformer model configuration."}
|
||||
)
|
||||
|
||||
def save(self, save_dir: Union[str, Path]):
|
||||
self.save_components(save_dir)
|
||||
|
||||
def load(self, load_dir: Union[str, Path]) -> Self:
|
||||
return self.load_components(load_dir)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint(BaseModelIO):
|
||||
"""Extended model parameters with training state."""
|
||||
|
||||
model: Optional[nn.Module] = field(
|
||||
default=None,
|
||||
metadata={"help": "Transformer model."}
|
||||
)
|
||||
tokenizer: BpeTokenizer = field(
|
||||
default_factory=BpeTokenizer,
|
||||
metadata={"help": "Tokenizer for the model."}
|
||||
)
|
||||
config: TransformerConfig = field(
|
||||
default_factory=TransformerConfig,
|
||||
metadata={"help": "Transformer model configuration."}
|
||||
)
|
||||
optim_state: Dict[str, Any] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optimizer state."}
|
||||
)
|
||||
sampler_state: Dict[str, Any] = field(
|
||||
default=None,
|
||||
metadata={"help": "Sampler state."}
|
||||
)
|
||||
loss_list: List[float] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "List of training losses."}
|
||||
)
|
||||
|
||||
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||
paths = self._get_file_paths(directory)
|
||||
paths.update({
|
||||
"loss_list": paths["model"].parent / "loss.pkl",
|
||||
"loss_plot": paths["model"].parent / "loss.png",
|
||||
"optim_state": paths["model"].parent / "optim_state.pkl",
|
||||
"sampler_state": paths["model"].parent / "sampler_state.pkl"
|
||||
})
|
||||
return paths
|
||||
|
||||
def save_training_state(self, save_dir: Union[str, Path]):
|
||||
paths = self._get_training_paths(save_dir)
|
||||
|
||||
# Save loss plot
|
||||
self._plot_loss(str(paths["loss_plot"]))
|
||||
|
||||
# Save loss list
|
||||
with open(str(paths["loss_list"]), "wb") as f:
|
||||
pkl.dump(self.loss_list, f)
|
||||
|
||||
# Save optimizer state
|
||||
with open(str(paths["optim_state"]), "wb") as f:
|
||||
pkl.dump(self.optim_state, f)
|
||||
|
||||
# Save sampler state
|
||||
with open(str(paths["sampler_state"]), "wb") as f:
|
||||
pkl.dump(self.sampler_state, f)
|
||||
|
||||
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
|
||||
paths = self._get_training_paths(load_dir)
|
||||
|
||||
# Load loss list
|
||||
if paths["loss_list"].exists():
|
||||
with open(str(paths["loss_list"]), "rb") as f:
|
||||
self.loss_list = pkl.load(f)
|
||||
|
||||
# Load optimizer state
|
||||
if paths["optim_state"].exists():
|
||||
with open(str(paths["optim_state"]), "rb") as f:
|
||||
self.optim_state = pkl.load(f)
|
||||
|
||||
# Load sampler state
|
||||
if paths["sampler_state"].exists():
|
||||
with open(str(paths["sampler_state"]), "rb") as f:
|
||||
self.sampler_state = pkl.load(f)
|
||||
|
||||
return self
|
||||
|
||||
def _plot_loss(self, save_path: str):
|
||||
"""Plot and save loss curve."""
|
||||
if not self.loss_list:
|
||||
return
|
||||
|
||||
current_iter = len(self.loss_list)
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(self.loss_list)
|
||||
plt.title(f"Training Loss - Iteration {current_iter}")
|
||||
plt.xlabel("Batch")
|
||||
plt.ylabel("Loss")
|
||||
plt.grid(True)
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
||||
def save(self, save_dir: Union[str, Path]):
|
||||
"""Save complete checkpoint."""
|
||||
self.save_components(save_dir)
|
||||
self.save_training_state(save_dir)
|
||||
|
||||
def load(self, load_dir: Union[str, Path]) -> Self:
|
||||
"""Load complete checkpoint."""
|
||||
self.load_components(load_dir)
|
||||
self.load_training_state(load_dir)
|
||||
return self
|
||||
|
||||
|
||||
class ParameterLoader:
|
||||
"""Factory class for loading model parameters or checkpoints."""
|
||||
|
||||
@staticmethod
|
||||
def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
|
||||
"""Load either ModelParameter or Checkpoint based on directory contents."""
|
||||
load_dir = Path(load_dir)
|
||||
|
||||
# Check for training-specific files
|
||||
loss_file = load_dir / "loss.pkl"
|
||||
has_training_data = loss_file.exists()
|
||||
|
||||
# Create appropriate instance
|
||||
if has_training_data:
|
||||
checkpoint = Checkpoint()
|
||||
checkpoint.load(str(load_dir))
|
||||
return checkpoint
|
||||
else:
|
||||
params = ModelParameter()
|
||||
params.load(str(load_dir))
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def create_checkpoint(
|
||||
model: nn.Module,
|
||||
tokenizer: BpeTokenizer,
|
||||
config: TransformerConfig,
|
||||
loss_list: Optional[list[float]] = None,
|
||||
optimizer: Optional[optim.Optimizer] = None,
|
||||
) -> Checkpoint:
|
||||
"""Convenience method to create a training checkpoint."""
|
||||
return Checkpoint(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
loss_list=loss_list or [],
|
||||
optimizer_state=optimizer
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
from tokenizers import Tokenizer, Encoding
|
||||
from tokenizers import decoders, processors, normalizers, pre_tokenizers
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
class BpeTokenizer:
|
||||
def __init__(self, path=None):
|
||||
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
|
||||
self._special_tokens = ["<|user|>", "<|system|>"]
|
||||
model = BPE()
|
||||
tokenizer = Tokenizer(model)
|
||||
tokenizer.normalizer = normalizers.Sequence([
|
||||
normalizers.NFC()
|
||||
])
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
|
||||
pre_tokenizers.Punctuation(behavior="isolated"),
|
||||
pre_tokenizers.Metaspace(prepend_scheme="never"),
|
||||
pre_tokenizers.Split(pattern=r"(\d+|[a-zA-Z]+|(?:'s|'t|'re|'ve|'m|'ll|'d))", behavior="isolated"),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
|
||||
])
|
||||
tokenizer.decoder = decoders.Sequence([
|
||||
decoders.ByteLevel(),
|
||||
decoders.Metaspace(prepend_scheme="never")
|
||||
])
|
||||
tokenizer.post_processor = processors.Sequence([
|
||||
processors.ByteLevel(trim_offsets=False)
|
||||
])
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
if path is not None:
|
||||
self._tokenizer = Tokenizer.from_file(path)
|
||||
|
||||
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int) -> tuple:
|
||||
assert reserved_token_size > len(self._special_tokens)
|
||||
reserved_tokens = [f"<|rsv{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
|
||||
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
|
||||
|
||||
alphabet = pre_tokenizers.ByteLevel.alphabet()
|
||||
min_size = len(alphabet) + len(self._control_tokens)
|
||||
assert detail_vocab_size > min_size
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=detail_vocab_size,
|
||||
min_frequency=min_freq,
|
||||
limit_alphabet=detail_vocab_size // 4,
|
||||
max_token_length=18,
|
||||
special_tokens=self._control_tokens,
|
||||
show_progress=True,
|
||||
initial_alphabet=alphabet,
|
||||
)
|
||||
|
||||
return trainer, detail_vocab_size, reserved_tokens
|
||||
|
||||
def train(self, files, vocab_size, min_freq, reserved_token_size=100):
|
||||
trainer, _, reserved_tokens = self._prepare_trainer(
|
||||
vocab_size=vocab_size,
|
||||
min_freq=min_freq,
|
||||
reserved_token_size=reserved_token_size
|
||||
)
|
||||
self._tokenizer.train(files=files, trainer=trainer)
|
||||
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
||||
|
||||
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
|
||||
trainer, _, reserved_tokens = self._prepare_trainer(
|
||||
vocab_size=vocab_size,
|
||||
min_freq=min_freq,
|
||||
reserved_token_size=reserved_token_size
|
||||
)
|
||||
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
|
||||
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
||||
|
||||
def save(self, path):
|
||||
self._tokenizer.save(path)
|
||||
|
||||
def load(self, path):
|
||||
self._tokenizer = Tokenizer.from_file(path)
|
||||
|
||||
def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
|
||||
if isinstance(tokens, str):
|
||||
encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
|
||||
return encoded.ids if out_ids else encoded.tokens
|
||||
elif isinstance(tokens, list):
|
||||
encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
|
||||
return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
|
||||
|
||||
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
|
||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._tokenizer.get_vocab_size()
|
||||
|
||||
@property
|
||||
def stop_ids(self) -> List[int]:
|
||||
stop_ids = []
|
||||
for token in self._control_tokens:
|
||||
stop_ids.append(self._tokenizer.token_to_id(token))
|
||||
return stop_ids
|
||||
|
||||
@property
|
||||
def bos_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<bos>")
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<eos>")
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<pad>")
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<|user|>")
|
||||
|
||||
@property
|
||||
def system_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<|system|>")
|
||||
|
|
@ -1,346 +0,0 @@
|
|||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import init
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List, Optional, Self, Tuple
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
"""
|
||||
Repeat k times along the dimension for attention heads.
|
||||
Args:
|
||||
x (Tensor): The input tensor.
|
||||
n_rep (int): The number of repetitions.
|
||||
Returns:
|
||||
Tensor: The repeated tensor.
|
||||
"""
|
||||
|
||||
bs, slen, n_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x[:, :, :, None, :]
|
||||
.expand(bs, slen, n_heads, n_rep, head_dim)
|
||||
.reshape(bs, slen, n_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
def get_rotary_emb(
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
device: torch.device = "cuda",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get the rotary embedding for the given dimension and maximum length.
|
||||
Args:
|
||||
dim (int): The dimension of the input.
|
||||
max_len (int): The maximum length of the input.
|
||||
base (float, optional): The base for the frequency. Defaults to 10000.
|
||||
device (torch.device, optional): The device to use. Defaults to "cuda".
|
||||
Returns:
|
||||
Tensor: The rotary embedding tensor.
|
||||
"""
|
||||
|
||||
theta = base ** (-torch.arange(0, dim, 2, device=device).float() / dim)
|
||||
t = torch.arange(0, max_len, device=device).float()
|
||||
freqs = torch.outer(t, theta)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
"""
|
||||
Apply rotary embedding to the input tensor.
|
||||
Args:
|
||||
x (Tensor): The input tensor.
|
||||
freqs_cis (Tensor): The rotary embedding tensor.
|
||||
Returns:
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
|
||||
dtype = x.dtype
|
||||
seq_len = x.size(1)
|
||||
|
||||
x_complex = torch.view_as_complex(x.view(*x.shape[:-1], -1, 2).float())
|
||||
freqs_cis = freqs_cis.reshape(1, seq_len, 1, -1)
|
||||
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.to(dtype)
|
||||
|
||||
def process_attention_mask(
|
||||
seq_mask: Tensor,
|
||||
start_pos: int = 0,
|
||||
seq_len: int = 0,
|
||||
is_causal: bool = False,
|
||||
device: torch.device = "cuda",
|
||||
dtype: torch.dtype = torch.float32
|
||||
) -> Tensor:
|
||||
"""
|
||||
Create attention mask for GQA
|
||||
Args:
|
||||
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
|
||||
start_pos (int): The starting position of the sequence.
|
||||
seq_len (int): The length of the sequence.
|
||||
is_causal (bool): Whether the attention is causal or not.
|
||||
device (torch.device): The device to use.
|
||||
Returns:
|
||||
Tensor: The attention mask tensor.
|
||||
"""
|
||||
|
||||
if seq_mask is None:
|
||||
if start_pos != 0:
|
||||
# for single prompt chat
|
||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
||||
else:
|
||||
return None
|
||||
|
||||
if seq_mask.dim() > 2:
|
||||
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
||||
# if ndim > 2, it's 4D tensor
|
||||
return seq_mask
|
||||
|
||||
batch_size = seq_mask.size(0)
|
||||
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||
# (bsz, start_pos + seq_len)
|
||||
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
|
||||
# (bsz, seq_len, start_pos + seq_len)
|
||||
|
||||
if is_causal:
|
||||
causal_mask = torch.tril(
|
||||
torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
|
||||
diagonal=start_pos
|
||||
)
|
||||
causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
|
||||
expanded_mask = expanded_mask & causal_mask
|
||||
|
||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
|
||||
# (bsz, 1, seq_len, seq_len + start_pos)
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerConfig:
|
||||
# basic config
|
||||
vocab_size: Optional[int] = None
|
||||
n_dim: Optional[int] = None
|
||||
n_head: Optional[int] = None
|
||||
n_layer: Optional[int] = None
|
||||
m_len: Optional[int] = None
|
||||
norm_eps: Optional[float] = None
|
||||
d_ffn: Optional[int] = None
|
||||
|
||||
# GQA
|
||||
n_kvhead: Optional[int] = None
|
||||
|
||||
|
||||
def load(self, config_path: str) -> Self:
|
||||
with open(config_path, 'r') as f:
|
||||
config: dict = json.load(f)
|
||||
for key, value in config.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
def save(self, config_path: str) -> None:
|
||||
config_dict = asdict(self)
|
||||
config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||
init.normal_(self.weight, mean=0, std=0.006)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, n_dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(n_dim))
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
|
||||
norm = x * torch.rsqrt(mean_square + self.norm_eps)
|
||||
norm = norm.to(dtype)
|
||||
out = norm * self.weight
|
||||
return out
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, n_dim: int, d_ffn: int):
|
||||
super().__init__()
|
||||
self.up = Linear(n_dim, d_ffn)
|
||||
self.gate = Linear(n_dim, d_ffn)
|
||||
self.down = Linear(d_ffn, n_dim)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
out = self.down(gated)
|
||||
return out
|
||||
|
||||
|
||||
class GQA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_dim: int,
|
||||
n_head: int,
|
||||
n_kvhead: int,
|
||||
):
|
||||
super().__init__()
|
||||
assert n_dim % n_head == 0
|
||||
assert n_head % n_kvhead == 0
|
||||
|
||||
self.head_dim = n_dim // n_head
|
||||
self.n_dim = n_dim
|
||||
self.n_heads = n_head
|
||||
self.n_kvheads = n_kvhead
|
||||
self.n_rep = n_head // n_kvhead
|
||||
|
||||
self.q_proj = Linear(n_dim, n_head * self.head_dim)
|
||||
self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
|
||||
self.v_proj = Linear(n_dim, n_kvhead * self.head_dim)
|
||||
self.o_proj = Linear(n_dim, n_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
freqs_cis: Tensor,
|
||||
mask: Tensor = None,
|
||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
start_pos: int = 0
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||
k = self._split_heads(self.k_proj(x), self.n_kvheads)
|
||||
v = self._split_heads(self.v_proj(x), self.n_kvheads)
|
||||
q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis)
|
||||
|
||||
if kv_cache is not None:
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
# copy to cache
|
||||
k_cache[:bsz, start_pos:start_pos + seq_len] = k
|
||||
v_cache[:bsz, start_pos:start_pos + seq_len] = v
|
||||
|
||||
# get cache
|
||||
k = k_cache[:bsz, :start_pos + seq_len]
|
||||
v = v_cache[:bsz, :start_pos + seq_len]
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask == None)).permute(0, 2, 1, 3)
|
||||
out = self.o_proj(sdqa_out.contiguous().view(bsz, seq_len, -1))
|
||||
|
||||
return out
|
||||
|
||||
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
||||
return x
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps):
|
||||
super().__init__()
|
||||
self.attention = GQA(n_dim, n_head, n_kvhead)
|
||||
self.norm_attn = RMSNorm(n_dim, norm_eps)
|
||||
self.ffn = MLP(n_dim, d_ffn)
|
||||
self.norm_ffn = RMSNorm(n_dim, norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
freqs_cis: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
start_pos: int = 0
|
||||
) -> Tensor:
|
||||
# attention
|
||||
attn_output = self.attention(
|
||||
self.norm_attn(x),
|
||||
freqs_cis,
|
||||
attention_mask,
|
||||
kv_cache,
|
||||
start_pos
|
||||
)
|
||||
x = attn_output + x
|
||||
|
||||
# feed forward
|
||||
x = self.ffn(self.norm_ffn(x)) + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config: TransformerConfig):
|
||||
super().__init__()
|
||||
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
|
||||
self.layers = nn.ModuleList([
|
||||
DecoderBlock(
|
||||
config.n_dim,
|
||||
config.n_head,
|
||||
config.d_ffn,
|
||||
config.n_kvhead,
|
||||
config.norm_eps
|
||||
)
|
||||
for _ in range(config.n_layer)
|
||||
])
|
||||
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
||||
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
|
||||
init.normal_(self.embedding, mean=0, std=0.02)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor]=None,
|
||||
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
|
||||
start_pos: int = 0
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
seq_len = input_ids.size(-1)
|
||||
x = F.embedding(input_ids, self.embedding)
|
||||
|
||||
self.freq_cis = self.freq_cis.to(x.device)
|
||||
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
|
||||
has_kvcache = persistent_key_values is not None
|
||||
|
||||
attn_mask = process_attention_mask(
|
||||
input_mask,
|
||||
start_pos=start_pos,
|
||||
seq_len=seq_len,
|
||||
is_causal=has_kvcache,
|
||||
device=x.device,
|
||||
dtype=x.dtype
|
||||
)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
kv_cache = persistent_key_values[i] if persistent_key_values else None
|
||||
x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
logits = F.linear(hidden_states, self.embedding)
|
||||
|
||||
return {
|
||||
"logits": logits,
|
||||
"hidden_states": hidden_states
|
||||
}
|
||||
|
||||
112
khaosz/model.py
112
khaosz/model.py
|
|
@ -1,112 +0,0 @@
|
|||
from torch import Tensor
|
||||
from typing import List, Tuple, Generator, Union
|
||||
|
||||
from khaosz.core.generator import (
|
||||
TextGenerator,
|
||||
ChatGenerator,
|
||||
StreamGenerator,
|
||||
BatchGenerator,
|
||||
RetrievalGenerator,
|
||||
EmbeddingEncoder
|
||||
)
|
||||
from khaosz.core.parameter import ParameterLoader
|
||||
|
||||
|
||||
class Khaosz:
|
||||
def __init__(self, model_dir: str):
|
||||
self.parameter = ParameterLoader.load(model_dir)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
self.parameter.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def generate(
|
||||
self,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]]=None,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> str:
|
||||
generator = ChatGenerator(self.parameter)
|
||||
return generator.generate(
|
||||
query,
|
||||
history=history,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
def batch_generate(
|
||||
self,
|
||||
queries: List[str],
|
||||
histories: List[Tuple[str, str]]=None,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> List[str]:
|
||||
generator = BatchGenerator(self.parameter)
|
||||
return generator.generate(
|
||||
queries,
|
||||
histories=histories,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
|
||||
def stream_generate(
|
||||
self,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]]=None,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
|
||||
stream_generator = StreamGenerator(self.parameter)
|
||||
return stream_generator.generate(
|
||||
query,
|
||||
history=history,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
def retrieve_generate(
|
||||
self,
|
||||
retrieved,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]] = None,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> str:
|
||||
generator = RetrievalGenerator(self.parameter)
|
||||
return generator.generate(
|
||||
retrieved,
|
||||
query,
|
||||
history=history,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
def text_generate(
|
||||
self,
|
||||
query: str,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> str:
|
||||
generator = TextGenerator(self.parameter)
|
||||
|
||||
return generator.generate(
|
||||
query,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||
encoder = EmbeddingEncoder(self.parameter)
|
||||
return encoder.encode(sentence)
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
from khaosz.trainer.data_util import DatasetLoader
|
||||
from khaosz.trainer.trainer import Trainer
|
||||
from khaosz.trainer.strategy import (
|
||||
TrainConfig,
|
||||
CosineScheduleConfig,
|
||||
SgdrScheduleConfig,
|
||||
StrategyFactory,
|
||||
SchedulerFactory
|
||||
)
|
||||
from khaosz.trainer.trainer_callback import (
|
||||
TrainerCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
TrainerCallback,
|
||||
SchedulerCallback
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# strategy
|
||||
"DatasetLoader",
|
||||
"Trainer",
|
||||
"TrainConfig",
|
||||
"CosineScheduleConfig",
|
||||
"SgdrScheduleConfig",
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory",
|
||||
|
||||
# callback
|
||||
"TrainerCallback",
|
||||
"ProgressBarCallback",
|
||||
"CheckpointCallback",
|
||||
"TrainerCallback",
|
||||
"SchedulerCallback",
|
||||
]
|
||||
|
|
@ -1,326 +0,0 @@
|
|||
import torch
|
||||
import bisect
|
||||
import pickle as pkl
|
||||
from abc import ABC, abstractmethod
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
from typing import Callable, List, Dict, Literal, Union
|
||||
|
||||
MutiSeg = Dict[str, List[Tensor]]
|
||||
Seg = Dict[str, Tensor]
|
||||
|
||||
def load_pkl_files(paths: List[str]):
|
||||
segments: MutiSeg = {}
|
||||
total_samples = 0
|
||||
|
||||
for path in paths:
|
||||
with open(path, "rb") as f:
|
||||
pkl_file: Seg = pkl.load(f)
|
||||
for key, value in pkl_file.items():
|
||||
if key not in segments:
|
||||
segments[key] = []
|
||||
segments[key].append(value)
|
||||
first_key = list(pkl_file.keys())[0]
|
||||
total_samples += pkl_file[first_key].numel()
|
||||
|
||||
return segments, total_samples
|
||||
|
||||
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
|
||||
seq_len = input_ids.size(0)
|
||||
turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
|
||||
|
||||
iq = turn_id.view(seq_len, 1)
|
||||
ik = turn_id.view(1, seq_len)
|
||||
|
||||
# fix the causual attention mask(iq >= ik condition)
|
||||
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
|
||||
attention_mask = torch.tril(seq_mask)
|
||||
|
||||
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
|
||||
return attention_mask.unsqueeze(0)
|
||||
|
||||
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
||||
|
||||
is_bos_token = input_ids.eq(bos_token_id)
|
||||
is_eos_token = input_ids.eq(eos_token_id)
|
||||
|
||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||
token_markers[is_bos_token] = 1
|
||||
token_markers[is_eos_token] = -1
|
||||
|
||||
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
||||
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
|
||||
loss_mask = cumulative_markers - min_cumulative
|
||||
|
||||
return loss_mask.to(dtype=torch.bool)
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
total = 0
|
||||
for seg in segments:
|
||||
total += len(seg)
|
||||
self.cum_lengths.append(total)
|
||||
self.total_length = total if segments else 0
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
||||
raise ValueError("begin_idx or end_idx out of bounds")
|
||||
if begin_idx >= end_idx:
|
||||
return torch.tensor([], dtype=torch.long)
|
||||
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1)
|
||||
|
||||
result_segments = []
|
||||
|
||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
||||
start = max(begin_idx - prev_cum, 0)
|
||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
||||
result_segments.append(self.segments[i][start:end])
|
||||
|
||||
return torch.cat(result_segments, dim=0)
|
||||
|
||||
|
||||
class MutiSegmentFetcher:
|
||||
def __init__(self, muti_segments: MutiSeg):
|
||||
self.muti_keys = list(muti_segments.keys())
|
||||
self.muti_fetchers = {
|
||||
key: BaseSegmentFetcher(segments)
|
||||
for key, segments in muti_segments.items()
|
||||
}
|
||||
|
||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
for key in keys:
|
||||
fetcher = self.muti_fetchers[key]
|
||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||
fetch_dict[key] = fetch_tensor
|
||||
|
||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
|
||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
def __init__(self, chunk_size: int, device: str):
|
||||
super().__init__()
|
||||
self.segments: MutiSeg = {}
|
||||
self.chunk_size = chunk_size
|
||||
self.total_samples = 0
|
||||
self.device = device
|
||||
|
||||
def save(self, save_path: str):
|
||||
keys = list(self.segments.keys())
|
||||
if not keys:
|
||||
return
|
||||
|
||||
first_item = self.segments[keys[0]]
|
||||
segment_size = len(first_item)
|
||||
|
||||
for i in range(segment_size):
|
||||
formated_segment = {key: self.segments[key][i] for key in keys}
|
||||
pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
|
||||
|
||||
|
||||
def load(self, load_path: Union[str, List[str]]):
|
||||
paths = [load_path] if isinstance(load_path, str) else load_path
|
||||
self.segments, self.total_samples = load_pkl_files(paths)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
assert self.total_samples // self.chunk_size > 0
|
||||
return self.total_samples // self.chunk_size
|
||||
|
||||
|
||||
class SeqDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size,
|
||||
device='cuda'
|
||||
):
|
||||
super().__init__(chunk_size, device)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx = index * self.chunk_size
|
||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
|
||||
|
||||
return {"input_ids": x, "target_ids": y}
|
||||
|
||||
|
||||
|
||||
class SftDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size,
|
||||
bos_token_id,
|
||||
eos_token_id,
|
||||
user_token_id,
|
||||
multi_turn=False,
|
||||
device='cuda'
|
||||
):
|
||||
super().__init__(chunk_size, device)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.user_token_id = user_token_id
|
||||
self.multi_turn = multi_turn
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx = index * self.chunk_size
|
||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
|
||||
|
||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
||||
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
|
||||
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
|
||||
|
||||
|
||||
class DpoDataset(BaseDataset):
|
||||
def __init__(self, chunk_size: int, device="cuda"):
|
||||
super().__init__(chunk_size, device)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
start_idx = index * self.chunk_size
|
||||
end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
|
||||
|
||||
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long)
|
||||
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long)
|
||||
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
|
||||
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
|
||||
|
||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||
|
||||
|
||||
class PpoDataset(BaseDataset):
|
||||
def __init__(self, chunk_size: int, device="cuda"):
|
||||
super().__init__(chunk_size, device)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
|
||||
begin_idx = index * self.chunk_size
|
||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||
|
||||
|
||||
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device),
|
||||
actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device),
|
||||
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device),
|
||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
|
||||
|
||||
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
|
||||
|
||||
|
||||
class DatasetLoader:
|
||||
@staticmethod
|
||||
def load(
|
||||
train_type: Literal["seq", "sft", "dpo"],
|
||||
load_path: Union[str, List[str]],
|
||||
max_len: int,
|
||||
device: str,
|
||||
**kwargs
|
||||
) -> BaseDataset:
|
||||
|
||||
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
|
||||
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
|
||||
"sft": lambda m_len, device: SftDataset(
|
||||
m_len,
|
||||
device=device,
|
||||
bos_token_id=kwargs.get("bos_token_id"),
|
||||
eos_token_id=kwargs.get("eos_token_id"),
|
||||
user_token_id=kwargs.get("user_token_id"),
|
||||
multi_turn=kwargs.get("multi_turn")
|
||||
),
|
||||
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
|
||||
}
|
||||
dataset = dataset_router[train_type](max_len, device)
|
||||
dataset.load(load_path)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
class RandomSampler(Sampler[int]):
|
||||
def __init__(self, data_source, generator=None, seed=42):
|
||||
self.data_source = data_source
|
||||
self.seed = seed
|
||||
self.epoch = 0
|
||||
self.current_iter = 0
|
||||
self._indices = None
|
||||
|
||||
if generator is None:
|
||||
self.generator = torch.Generator()
|
||||
self.generator.manual_seed(seed)
|
||||
else:
|
||||
self.generator = generator
|
||||
|
||||
def _generate_indices(self):
|
||||
n = len(self.data_source)
|
||||
self._indices = torch.randperm(n, generator=self.generator).tolist()
|
||||
|
||||
def __iter__(self):
|
||||
n = len(self.data_source)
|
||||
|
||||
if self._indices is None:
|
||||
self._generate_indices()
|
||||
|
||||
start = self.current_iter % n
|
||||
for i in range(start, n):
|
||||
yield self._indices[i]
|
||||
self.current_iter += 1
|
||||
|
||||
self.epoch += 1
|
||||
self._indices = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
'epoch': self.epoch,
|
||||
'current_iter': self.current_iter,
|
||||
'seed': self.seed,
|
||||
'generator_state': self.generator.get_state() if self.generator else None,
|
||||
'indices': self._indices
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.epoch = state_dict['epoch']
|
||||
self.current_iter = state_dict['current_iter']
|
||||
self.seed = state_dict['seed']
|
||||
|
||||
if self.generator and state_dict['generator_state'] is not None:
|
||||
self.generator.set_state(state_dict['generator_state'])
|
||||
|
||||
self._indices = state_dict['indices']
|
||||
|
|
@ -1,396 +0,0 @@
|
|||
import copy
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Any, Literal, Tuple, Callable, Dict
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
|
||||
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
|
||||
input_mask = input_ids.ne(pad_token_id)
|
||||
logits = model(input_ids, input_mask)["logits"]
|
||||
log_probs = torch.log_softmax(logits, dim=-1)
|
||||
|
||||
shifted_log_probs = log_probs[:, :-1, :]
|
||||
shifted_input_ids = input_ids[:, 1:]
|
||||
shifted_response_mask = mask[:, 1:]
|
||||
|
||||
token_logprobs = torch.gather(
|
||||
shifted_log_probs,
|
||||
dim=-1,
|
||||
index=shifted_input_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
prompt_mask = input_mask[:, 1:]
|
||||
valid_mask = (prompt_mask & shifted_response_mask).float()
|
||||
|
||||
return (token_logprobs * valid_mask).sum(dim=-1)
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
def __init__(self, model: nn.Module):
|
||||
self.model = model
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
return self.compute_loss(batch)
|
||||
|
||||
|
||||
class SeqStrategy(BaseStrategy):
|
||||
def __init__(self, model):
|
||||
super().__init__(model)
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||
B, L = input_ids.size()
|
||||
logits: Tensor = self.model(input_ids=input_ids)["logits"]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.view(B * L, -1),
|
||||
target=target_ids.flatten()
|
||||
)
|
||||
return loss
|
||||
|
||||
|
||||
class SftStrategy(BaseStrategy):
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__(model)
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
|
||||
|
||||
ignore_index = -100
|
||||
B, L = input_ids.size()
|
||||
|
||||
logits: Tensor = self.model(
|
||||
input_ids=input_ids,
|
||||
input_mask=attn_mask
|
||||
)["logits"]
|
||||
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.view(B * L, -1),
|
||||
target=target_ids.flatten(),
|
||||
ignore_index=ignore_index
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class DpoStrategy(BaseStrategy):
|
||||
def __init__(self, model, pad_token_id, beta):
|
||||
super().__init__(model)
|
||||
ref_model = copy.deepcopy(self.model)
|
||||
ref_model.requires_grad_(False)
|
||||
ref_model.eval()
|
||||
|
||||
self.ref_model = ref_model
|
||||
self.pad_token_id = pad_token_id
|
||||
self.beta = beta
|
||||
|
||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
good_ids, bad_ids = batch["chosen"], batch["rejected"]
|
||||
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||
|
||||
log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
|
||||
log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)
|
||||
|
||||
with torch.no_grad():
|
||||
log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id)
|
||||
log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id)
|
||||
|
||||
pi_log_ratio = log_pi_good - log_pi_bad
|
||||
ref_log_ratio = log_ref_good - log_ref_bad
|
||||
|
||||
ratio_diff = pi_log_ratio - ref_log_ratio
|
||||
|
||||
dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
|
||||
return dpo_loss
|
||||
|
||||
|
||||
class PpoStrategy(BaseStrategy):
|
||||
def __init__(self, model, pad_token_id, epsilon):
|
||||
super().__init__(model)
|
||||
ref_model = copy.deepcopy(self.model)
|
||||
ref_model.requires_grad_(False)
|
||||
ref_model.eval()
|
||||
|
||||
self.ref_model = ref_model
|
||||
self.pad_token_id = pad_token_id
|
||||
self.epsilon = epsilon
|
||||
|
||||
def ppo_clip_loss_masked(
|
||||
self,
|
||||
log_probs: Tensor,
|
||||
old_log_probs: Tensor,
|
||||
advantages: Tensor,
|
||||
values: Tensor,
|
||||
returns: Tensor,
|
||||
mask: Tensor,
|
||||
clip_eps: float=0.2,
|
||||
):
|
||||
ratio = torch.exp(log_probs - old_log_probs)
|
||||
surr1 = ratio * advantages
|
||||
surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
|
||||
policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()
|
||||
|
||||
value_loss = F.mse_loss(values.masked_select(mask),
|
||||
returns.masked_select(mask))
|
||||
|
||||
entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
|
||||
entropy_loss = -entropy
|
||||
return policy_loss, value_loss, entropy_loss
|
||||
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
|
||||
def load(model, train_type, **kwargs):
|
||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||
"seq": lambda: SeqStrategy(model),
|
||||
"sft": lambda: SftStrategy(model),
|
||||
"dpo": lambda: DpoStrategy(
|
||||
model,
|
||||
kwargs.get("pad_token_id"),
|
||||
kwargs.get("dpo_beta")
|
||||
)
|
||||
}
|
||||
strategy = train_strategy[train_type]()
|
||||
return strategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
|
||||
strategy: BaseStrategy = field(
|
||||
default=None,
|
||||
metadata={"help": "Training strategy."}
|
||||
)
|
||||
dataset: Dataset = field(
|
||||
default=None,
|
||||
metadata={"help": "Dataset for training."}
|
||||
)
|
||||
optimizer: Optimizer = field(
|
||||
default=None,
|
||||
metadata={"help": "Optimizer for training."}
|
||||
)
|
||||
checkpoint_dir: str = field(
|
||||
default="./checkpoint",
|
||||
metadata={"help": "Checkpoint directory."}
|
||||
)
|
||||
n_epoch: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of epochs for training."}
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Batch size for training."}
|
||||
)
|
||||
checkpoint_interval: int = field(
|
||||
default=5000,
|
||||
metadata={"help": "Number of iterations between checkpoints."}
|
||||
)
|
||||
accumulation_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
random_seed: int = field(
|
||||
default=3407,
|
||||
metadata={"help": "Random seed."}
|
||||
)
|
||||
|
||||
def get_kwargs(self)-> Dict[str, Any]:
|
||||
config_dict = asdict(self)
|
||||
return {k: v for k, v in config_dict.items() if v is not None}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScheduleConfig(ABC):
|
||||
schedule_type: str = field(
|
||||
default="cosine",
|
||||
metadata={
|
||||
"help": "Type of learning rate schedule.",
|
||||
"choices": ["cosine", "sgdr"]
|
||||
}
|
||||
)
|
||||
warmup_steps: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Number of warmup steps."}
|
||||
)
|
||||
min_rate: float = field(
|
||||
default=0.05,
|
||||
metadata={"help": "Minimum learning rate multiplier."}
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_kwargs(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
if self.warmup_steps < 0:
|
||||
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
|
||||
if not 0 <= self.min_rate <= 1:
|
||||
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CosineScheduleConfig(ScheduleConfig):
|
||||
total_steps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Total training steps for cosine schedule."}
|
||||
)
|
||||
schedule_type: Literal["cosine"] = "cosine"
|
||||
|
||||
def get_kwargs(self) -> Dict[str, Any]:
|
||||
if self.total_steps is None:
|
||||
raise ValueError("total_steps must be specified for cosine schedule")
|
||||
|
||||
return {
|
||||
"schedule_type": self.schedule_type,
|
||||
"warmup_steps": self.warmup_steps,
|
||||
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
||||
"min_rate": self.min_rate
|
||||
}
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
||||
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SgdrScheduleConfig(ScheduleConfig):
|
||||
cycle_length: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Length of the first cycle in steps."}
|
||||
)
|
||||
t_mult: int = field(
|
||||
default=2,
|
||||
metadata={"help": "Multiplier for cycle length growth."}
|
||||
)
|
||||
schedule_type: Literal["sgdr"] = "sgdr"
|
||||
|
||||
def get_kwargs(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"schedule_type": self.schedule_type,
|
||||
"warmup_steps": self.warmup_steps,
|
||||
"cycle_length": self.cycle_length,
|
||||
"min_rate": self.min_rate,
|
||||
"t_mult": self.t_mult
|
||||
}
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
if self.cycle_length <= 0:
|
||||
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
|
||||
if self.t_mult < 1:
|
||||
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
||||
|
||||
|
||||
class SchedulerFactory:
|
||||
"""Factory for creating learning rate schedule functions."""
|
||||
|
||||
@staticmethod
|
||||
def get_sgdr_schedule(
|
||||
warmup_steps: int,
|
||||
cycle_length: int,
|
||||
min_rate: float = 0.05,
|
||||
t_mult: int = 2
|
||||
) -> Callable[[int], float]:
|
||||
"""
|
||||
Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.
|
||||
|
||||
Args:
|
||||
warmup_steps: Number of warmup steps
|
||||
cycle_length: Length of the first cycle
|
||||
min_rate: Minimum learning rate multiplier
|
||||
t_mult: Cycle length multiplier
|
||||
|
||||
Returns:
|
||||
Schedule function that takes current step and returns LR multiplier
|
||||
"""
|
||||
|
||||
def sgdr_schedule(current_step: int) -> float:
|
||||
# Warmup phase
|
||||
if current_step < warmup_steps:
|
||||
return max(min_rate, current_step / warmup_steps)
|
||||
|
||||
# SGDR phase
|
||||
steps_since_warmup = current_step - warmup_steps
|
||||
|
||||
# Find current cycle and position within cycle
|
||||
cycle_start = 0
|
||||
current_cycle_length = cycle_length
|
||||
cycle_index = 0
|
||||
|
||||
while steps_since_warmup >= cycle_start + current_cycle_length:
|
||||
cycle_start += current_cycle_length
|
||||
current_cycle_length *= t_mult
|
||||
cycle_index += 1
|
||||
|
||||
position_in_cycle = steps_since_warmup - cycle_start
|
||||
progress = position_in_cycle / current_cycle_length
|
||||
|
||||
# Cosine annealing within cycle
|
||||
return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))
|
||||
|
||||
return sgdr_schedule
|
||||
|
||||
@staticmethod
|
||||
def get_cosine_schedule(
|
||||
warmup_steps: int,
|
||||
lr_decay_steps: int,
|
||||
min_rate: float = 0.05
|
||||
) -> Callable[[int], float]:
|
||||
"""
|
||||
Create cosine decay schedule with warmup.
|
||||
|
||||
Args:
|
||||
warmup_steps: Number of warmup steps
|
||||
lr_decay_steps: Number of steps for cosine decay after warmup
|
||||
min_rate: Minimum learning rate multiplier
|
||||
|
||||
Returns:
|
||||
Schedule function that takes current step and returns LR multiplier
|
||||
"""
|
||||
|
||||
def cosine_schedule(current_step: int) -> float:
|
||||
if current_step < warmup_steps:
|
||||
# Linear warmup
|
||||
return max(min_rate, current_step / warmup_steps)
|
||||
else:
|
||||
# Cosine decay
|
||||
decay_progress = (current_step - warmup_steps) / lr_decay_steps
|
||||
decay_progress = min(decay_progress, 1.0) # Clamp at 1.0
|
||||
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))
|
||||
|
||||
return cosine_schedule
|
||||
|
||||
@staticmethod
|
||||
def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]:
|
||||
kwargs = scedule_config.get_kwargs()
|
||||
schedule_type = kwargs.pop("schedule_type")
|
||||
|
||||
if schedule_type == "cosine":
|
||||
return SchedulerFactory.get_cosine_schedule(**kwargs)
|
||||
elif schedule_type == "sgdr":
|
||||
return SchedulerFactory.get_sgdr_schedule(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported schedule type: {schedule_type}")
|
||||
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
import logging
|
||||
from typing import Optional, List, cast
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from khaosz.core import ModelParameter, Checkpoint
|
||||
from khaosz.trainer.data_util import RandomSampler
|
||||
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
|
||||
from khaosz.trainer.trainer_callback import (
|
||||
TrainerCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
GradientClippingCallback,
|
||||
SchedulerCallback
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
parameter: ModelParameter,
|
||||
train_config: TrainConfig,
|
||||
schedule_config: ScheduleConfig,
|
||||
callbacks: Optional[List[TrainerCallback]] = None
|
||||
):
|
||||
self.parameter = parameter
|
||||
self.train_config = train_config
|
||||
self.schedule_config = schedule_config
|
||||
self.callbacks = callbacks or self._get_default_callbacks()
|
||||
|
||||
def _get_default_callbacks(self) -> List[TrainerCallback]:
|
||||
return [
|
||||
ProgressBarCallback(),
|
||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||
GradientClippingCallback(),
|
||||
SchedulerCallback(self.schedule_config),
|
||||
]
|
||||
|
||||
def _set_train_kwargs(self, kwargs: dict):
|
||||
seed = self.train_config.random_seed
|
||||
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
|
||||
optim = self.train_config.optimizer
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None))
|
||||
|
||||
if checkpoint is None:
|
||||
checkpoint = Checkpoint(
|
||||
model=self.parameter.model,
|
||||
tokenizer=self.parameter.tokenizer,
|
||||
config=self.parameter.config,
|
||||
sampler_state=None,
|
||||
optim_state=None,
|
||||
loss_list=[]
|
||||
)
|
||||
|
||||
sampler_state = checkpoint.sampler_state
|
||||
optim_state = checkpoint.optim_state
|
||||
|
||||
if sampler_state:
|
||||
sampler.load_state_dict(sampler_state)
|
||||
|
||||
if optim_state:
|
||||
optim.load_state_dict(optim_state)
|
||||
|
||||
checkpoint.optim_state = optim.state_dict()
|
||||
checkpoint.sampler_state = sampler.state_dict()
|
||||
|
||||
dataloader = DataLoader(
|
||||
self.train_config.dataset,
|
||||
batch_size=self.train_config.batch_size,
|
||||
sampler=sampler
|
||||
)
|
||||
|
||||
kwargs["dataloader"] = dataloader
|
||||
kwargs["optimizer"] = self.train_config.optimizer
|
||||
kwargs["epoch"] = sampler.epoch
|
||||
kwargs["current_iter"] = sampler.current_iter
|
||||
kwargs["sampler"] = sampler
|
||||
kwargs["checkpoint"] = checkpoint
|
||||
|
||||
def _call_callbacks(self, method_name: str, **kwargs):
|
||||
for callback in self.callbacks:
|
||||
method = getattr(callback, method_name, None)
|
||||
if method:
|
||||
method(self, **kwargs)
|
||||
|
||||
def train(
|
||||
self,
|
||||
checkpoint: Optional[Checkpoint] = None
|
||||
) -> Checkpoint:
|
||||
|
||||
# train
|
||||
train_kwargs = {
|
||||
'checkpoint': checkpoint,
|
||||
'dataloader': None,
|
||||
'optimizer': None,
|
||||
'sampler': None,
|
||||
'epoch': 0,
|
||||
'current_iter': 0,
|
||||
'loss': 0.0,
|
||||
}
|
||||
|
||||
self._set_train_kwargs(train_kwargs)
|
||||
self._call_callbacks('on_train_begin', **train_kwargs)
|
||||
|
||||
dataloader = train_kwargs['dataloader']
|
||||
checkpoint = train_kwargs['checkpoint']
|
||||
start_epoch = train_kwargs['epoch']
|
||||
|
||||
try:
|
||||
self.parameter.model.train()
|
||||
for epoch in range(start_epoch, self.train_config.n_epoch):
|
||||
# epoch
|
||||
train_kwargs["epoch"] = epoch
|
||||
self._call_callbacks('on_epoch_begin', **train_kwargs)
|
||||
for batch in dataloader:
|
||||
|
||||
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
|
||||
# step
|
||||
self._call_callbacks('on_step_begin', **train_kwargs)
|
||||
self.train_config.optimizer.step()
|
||||
self.train_config.optimizer.zero_grad()
|
||||
self._call_callbacks('on_step_end', **train_kwargs)
|
||||
|
||||
# batch
|
||||
self._call_callbacks('on_batch_begin', **train_kwargs)
|
||||
loss = self.train_config.strategy(batch)
|
||||
train_kwargs["loss"] = loss.item()
|
||||
train_kwargs["current_iter"] += 1
|
||||
loss.backward()
|
||||
|
||||
self._call_callbacks('on_batch_end', **train_kwargs)
|
||||
|
||||
self._call_callbacks('on_epoch_end', **train_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self._call_callbacks('on_train_end', **train_kwargs)
|
||||
return checkpoint
|
||||
|
|
@ -1,169 +0,0 @@
|
|||
import os
|
||||
import torch.optim as optim
|
||||
|
||||
from tqdm import tqdm
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from typing import Optional, cast, TYPE_CHECKING
|
||||
from khaosz.core.parameter import Checkpoint
|
||||
from khaosz.trainer.data_util import RandomSampler
|
||||
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from khaosz.trainer.trainer import Trainer
|
||||
|
||||
|
||||
class TrainerCallback:
|
||||
"""
|
||||
Callback interface for trainer.
|
||||
and we use '_' to ignore unused parameters.
|
||||
"""
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||
""" Called at the beginning of training. """
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_train_end(self, trainer: 'Trainer', **kwargs):
|
||||
""" Called at the end of training. """
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
|
||||
""" Called at the beginning of each epoch. """
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_epoch_end(self, trainer: 'Trainer', **kwargs):
|
||||
""" Called at the end of each epoch. """
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_batch_begin(self, trainer: 'Trainer', **kwargs):
|
||||
""" Called at the beginning of each batch. """
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
""" Called at the end of each batch. """
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_step_begin(self, trainer: 'Trainer', **kwargs):
|
||||
""" Called at the beginning of each step. """
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_step_end(self, trainer: 'Trainer', **kwargs):
|
||||
""" Called at the end of each step."""
|
||||
_ = trainer, kwargs
|
||||
|
||||
|
||||
class ProgressBarCallback(TrainerCallback):
|
||||
"""
|
||||
Progress bar callback for trainer.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.progress_bar: tqdm = None
|
||||
|
||||
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
|
||||
epoch = kwargs.get('epoch')
|
||||
dataloader = kwargs.get('dataloader')
|
||||
self.progress_bar = tqdm(
|
||||
dataloader,
|
||||
desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
|
||||
dynamic_ncols=True
|
||||
)
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer
|
||||
loss = kwargs.get('loss')
|
||||
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
|
||||
self.progress_bar.set_postfix({
|
||||
"loss": f"{loss:.4f}",
|
||||
"lr": f"{optimizer.param_groups[-1]['lr']:.2e}"
|
||||
})
|
||||
self.progress_bar.update(1)
|
||||
|
||||
def on_epoch_end(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer, kwargs
|
||||
if self.progress_bar:
|
||||
self.progress_bar.close()
|
||||
|
||||
|
||||
class CheckpointCallback(TrainerCallback):
|
||||
"""
|
||||
Checkpoint callback for trainer.
|
||||
"""
|
||||
def __init__(self, checkpoint_interval: int):
|
||||
self.checkpoint_interval = checkpoint_interval
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@staticmethod
|
||||
def _save_checkpoint(trainer: 'Trainer', **kwargs):
|
||||
current_iter = kwargs.get('current_iter')
|
||||
random_sampler = cast(RandomSampler, kwargs.get('sampler'))
|
||||
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
|
||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
|
||||
checkpoint.sampler_state = random_sampler.state_dict()
|
||||
checkpoint.optim_state = optimizer.state_dict()
|
||||
|
||||
checkpoint.save(save_path)
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
current_iter = kwargs.get('current_iter')
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
loss = kwargs.get('loss')
|
||||
checkpoint.loss_list.append(loss)
|
||||
|
||||
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
||||
CheckpointCallback._save_checkpoint(trainer, **kwargs)
|
||||
self.last_ckpt_iter = current_iter
|
||||
|
||||
def on_train_end(self, trainer: 'Trainer', **kwargs):
|
||||
current_iter = kwargs.get('current_iter')
|
||||
if current_iter != self.last_ckpt_iter:
|
||||
CheckpointCallback._save_checkpoint(trainer, **kwargs)
|
||||
self.last_ckpt_iter = current_iter
|
||||
|
||||
|
||||
class GradientClippingCallback(TrainerCallback):
|
||||
"""
|
||||
Gradient clipping callback for trainer.
|
||||
"""
|
||||
def on_step_begin(self, trainer: 'Trainer', **kwargs):
|
||||
_ = kwargs
|
||||
clip_grad_norm_(
|
||||
trainer.parameter.model.parameters(),
|
||||
trainer.train_config.max_grad_norm
|
||||
)
|
||||
|
||||
|
||||
class SchedulerCallback(TrainerCallback):
|
||||
"""
|
||||
Scheduler callback for trainer.
|
||||
"""
|
||||
def __init__(self, schedule_config: ScheduleConfig):
|
||||
self.schedule_config = schedule_config
|
||||
self.scheduler: Optional[LambdaLR] = None
|
||||
self.current_iter = 0
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||
self.current_iter = kwargs.get('current_iter')
|
||||
|
||||
for group in trainer.train_config.optimizer.param_groups:
|
||||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
self.schedule_config.validate()
|
||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
||||
self.schedule_config
|
||||
)
|
||||
|
||||
self.scheduler = LambdaLR(
|
||||
trainer.train_config.optimizer,
|
||||
lambda_scheduler_fn,
|
||||
last_epoch=self.current_iter - 1
|
||||
)
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer, kwargs
|
||||
|
||||
if self.scheduler:
|
||||
self.scheduler.step()
|
||||
self.current_iter += 1
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
import torch
|
||||
import sqlite3
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class Retriever:
|
||||
def __init__(self, db_path=None):
|
||||
self.data: Dict[str, Tensor] = {}
|
||||
self.embedding_cache: Tensor = None
|
||||
self.is_caculated: bool = False
|
||||
|
||||
if db_path is not None:
|
||||
self.load(db_path)
|
||||
|
||||
def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
|
||||
if not self.data:
|
||||
return []
|
||||
|
||||
query = query.flatten().unsqueeze(1) # [dim, 1]
|
||||
norm_embeddings = self._embeddings.to(
|
||||
device=query.device,
|
||||
dtype=query.dtype
|
||||
) # [n_vectors, dim]
|
||||
sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
|
||||
|
||||
top_k = min(top_k, len(self.data))
|
||||
indices = sim_scores.topk(top_k).indices
|
||||
keys = list(self.data.keys())
|
||||
|
||||
return [(keys[i], sim_scores[i].item()) for i in indices]
|
||||
|
||||
def add_vector(self, key: str, vector_data: Tensor):
|
||||
self.is_caculated = False
|
||||
self.data[key] = vector_data.flatten().float().cpu()
|
||||
|
||||
def delete_vector(self, key: str):
|
||||
self.is_caculated = False
|
||||
self.data.pop(key, None)
|
||||
|
||||
def save(self, db_path):
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
self._init_db(cursor)
|
||||
cursor.execute('DELETE FROM vectors')
|
||||
|
||||
for item, vec in self.data.items():
|
||||
vec_bytes = vec.numpy().tobytes()
|
||||
cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
|
||||
(item, vec_bytes))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def load(self, db_path):
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
self._init_db(cursor)
|
||||
cursor.execute('SELECT key, vector FROM vectors')
|
||||
rows = cursor.fetchall()
|
||||
self.data = {}
|
||||
|
||||
for row in rows:
|
||||
key, vec_bytes = row
|
||||
vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
|
||||
vec = torch.from_numpy(vec_numpy)
|
||||
self.data[key] = vec
|
||||
|
||||
conn.close()
|
||||
|
||||
def _init_db(self,cursor: sqlite3.Cursor):
|
||||
# Create table if not exists (in case loading from a new database)
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS vectors (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
key TEXT UNIQUE NOT NULL,
|
||||
vector BLOB NOT NULL
|
||||
)''')
|
||||
|
||||
@property
|
||||
def _embeddings(self) -> Tensor:
|
||||
if not self.is_caculated:
|
||||
embeddings = torch.stack(list(self.data.values()))
|
||||
norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
|
||||
self.embedding_cache = norm_embeddings
|
||||
|
||||
return self.embedding_cache
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue