docs: 修正 assets/docs/ 类图、数据流、参数文档及贡献指南
- design.md: 新增 ProtocolHandler/OpenAIHandler/AnthropicHandler 等缺失类 - design.md: 新增 Template Method、Storage 设计模式 - dataflow.md: 修正 GQA/MLA 为独立条目,补充 JSON 存储后端 - params.md: 标注 label_smoothing CLI 默认与 strategy 默认差异 - introduction.md: 修正 max_tokens 默认值 1024→2048 - CONTRIBUTING.md: 重写(纯 Python 无 conda、补充 CI 步骤与常见问题) - .github/PULL_REQUEST_TEMPLATE.md: 修正 lint 命令,去除多余注释要求 - .github/ISSUE_TEMPLATE/bug_report.md: 修正 label(enhancement→bug)
This commit is contained in:
parent
e12f1a7ee5
commit
c169659611
|
|
@ -2,7 +2,7 @@
|
||||||
name: Bug report
|
name: Bug report
|
||||||
about: Create a report to help us improve
|
about: Create a report to help us improve
|
||||||
title: "[BUG]"
|
title: "[BUG]"
|
||||||
labels: enhancement
|
labels: bug
|
||||||
assignees: ''
|
assignees: ''
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,9 @@ Please delete options that are not relevant.
|
||||||
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
|
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
|
||||||
|
|
||||||
## Checklist:
|
## Checklist:
|
||||||
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check --fix .`)
|
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
|
||||||
- [ ] I have performed a self-review of my own code
|
- [ ] I have performed a self-review of my own code
|
||||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
- [ ] Code is self-documenting (no unnecessary comments)
|
||||||
- [ ] I have made corresponding changes to the documentation
|
- [ ] I have made corresponding changes to the documentation
|
||||||
- [ ] My changes generate no new warnings
|
- [ ] My changes generate no new warnings
|
||||||
- [ ] I have added tests that prove my fix is effective or that my feature works
|
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||||
|
|
|
||||||
128
CONTRIBUTING.md
128
CONTRIBUTING.md
|
|
@ -1,68 +1,100 @@
|
||||||
# Contributing to AstrAI
|
# Contributing to AstrAI
|
||||||
|
|
||||||
Thank you for your interest in contributing to AstrAI! This document provides guidelines and steps for contributing.
|
Thank you for your interest in contributing! This document provides step-by-step guidelines.
|
||||||
|
|
||||||
## How to Contribute
|
## Quick Start
|
||||||
|
|
||||||
### Reporting Issues
|
```bash
|
||||||
If you encounter a bug or have a feature request, please open an issue on GitHub. Include as much detail as possible:
|
git clone https://github.com/your-username/AstrAI.git
|
||||||
- A clear description of the problem or request.
|
cd AstrAI
|
||||||
- Steps to reproduce (for bugs).
|
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
|
||||||
- Your environment (Python version, OS, etc.).
|
```
|
||||||
|
|
||||||
### Submitting Changes
|
## Before You Commit
|
||||||
1. **Fork** the repository.
|
|
||||||
2. **Clone** your fork:
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/your-username/AstrAI.git
|
|
||||||
cd AstrAI
|
|
||||||
```
|
|
||||||
3. **Create a feature branch**:
|
|
||||||
```bash
|
|
||||||
git checkout -b feature/your-feature-name
|
|
||||||
```
|
|
||||||
4. **Make your changes**. Follow the code style guidelines below.
|
|
||||||
5. **Commit your changes** with a descriptive commit message:
|
|
||||||
```bash
|
|
||||||
git commit -m "Add: brief description of the change"
|
|
||||||
```
|
|
||||||
6. **Push** to your fork:
|
|
||||||
```bash
|
|
||||||
git push origin feature/your-feature-name
|
|
||||||
```
|
|
||||||
7. **Open a Pull Request** (PR) against the `main` branch of the upstream repository.
|
|
||||||
|
|
||||||
## Code Style
|
Run the following checks **in order** — CI will reject if any fail.
|
||||||
|
|
||||||
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
|
### 1. Format
|
||||||
|
|
||||||
- Run Ruff to format and lint (requires conda environment `nlp`):
|
```bash
|
||||||
```bash
|
ruff format .
|
||||||
conda run -n nlp ruff format .
|
```
|
||||||
conda run -n nlp ruff check --fix .
|
|
||||||
```
|
|
||||||
- The project uses **double quotes** for strings and **4‑space indentation** (as configured in `pyproject.toml`).
|
|
||||||
|
|
||||||
## Testing
|
> **Note**: `ruff format` may rename parameters (e.g. `mask` → `attn_mask`).
|
||||||
|
> Always review the diff after formatting.
|
||||||
|
|
||||||
If you add or modify functionality, please include appropriate tests.
|
### 2. Import sorting
|
||||||
|
|
||||||
- Run the test suite with:
|
```bash
|
||||||
```bash
|
ruff check . --select I
|
||||||
conda run -n nlp python -u -m pytest
|
```
|
||||||
```
|
|
||||||
- Ensure all tests pass before submitting your PR.
|
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ruff check . --select I --fix .
|
||||||
|
ruff format . # re-format after fix
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Run tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -u -m pytest tests/ -v
|
||||||
|
```
|
||||||
|
|
||||||
|
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
|
||||||
|
|
||||||
|
### 4. (Optional) Full pre-commit check
|
||||||
|
|
||||||
|
If you have Git Bash available:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash scripts/pre_commit.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
This runs format check, import sort check, and tests in one go.
|
||||||
|
|
||||||
|
## Commit Style
|
||||||
|
|
||||||
|
```
|
||||||
|
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
|
||||||
|
|
||||||
|
- bullet point body (each ~60 chars)
|
||||||
|
```
|
||||||
|
|
||||||
|
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
|
||||||
|
- **Subject line** ends with no period.
|
||||||
|
- **Body** uses bullet points starting with `-`.
|
||||||
|
- No `(scope)` parentheses.
|
||||||
|
|
||||||
|
## Common Issues
|
||||||
|
|
||||||
|
| Problem | Cause | Fix |
|
||||||
|
|---------|-------|-----|
|
||||||
|
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
|
||||||
|
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
|
||||||
|
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
|
||||||
|
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
|
||||||
|
|
||||||
|
## Submitting Changes
|
||||||
|
|
||||||
|
1. Fork the repo.
|
||||||
|
2. Create a feature branch: `git checkout -b feat/my-feature`
|
||||||
|
3. Make changes following the steps above.
|
||||||
|
4. Commit with the commit style above.
|
||||||
|
5. Push: `git push origin feat/my-feature`
|
||||||
|
6. Open a Pull Request against `main`.
|
||||||
|
|
||||||
## Code Review
|
## Code Review
|
||||||
|
|
||||||
All submissions will be reviewed. We may request changes or discuss alternatives. Please be responsive to feedback.
|
- All PRs are reviewed. We may request changes.
|
||||||
|
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
|
||||||
|
- Ensure all tests pass.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
By contributing, you agree that your contributions will be licensed under the same [GPL-3.0 License](LICENSE) that covers the project.
|
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
If you have any questions, feel free to ask in the [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
|
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
|
||||||
|
|
||||||
Happy contributing!
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
AstrAI adopts a modular design with the following main components:
|
AstrAI adopts a modular design with the following main components:
|
||||||
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
|
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, storage backends
|
||||||
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
||||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
||||||
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
||||||
|
|
@ -71,7 +71,8 @@ flowchart LR
|
||||||
- **`BaseDataset`**: Abstract base class for windowed sequence sampling
|
- **`BaseDataset`**: Abstract base class for windowed sequence sampling
|
||||||
- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range
|
- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range
|
||||||
- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
|
- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
|
||||||
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen_mask"/"rejected_mask"` (DPO), `"masks"` (GRPO)
|
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen"/"rejected"` (DPO), `"prompts"/"responses"/"masks"/"rewards"` (GRPO)
|
||||||
|
- Storage backends: HDF5 (`.h5`) or JSON (`.json`/`.jsonl`), auto-detected by `detect_format()`
|
||||||
|
|
||||||
#### 2.2 Sampler (`sampler.py`)
|
#### 2.2 Sampler (`sampler.py`)
|
||||||
- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last
|
- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last
|
||||||
|
|
@ -85,8 +86,9 @@ flowchart LR
|
||||||
- RoPE position encoding, optional weight tying
|
- RoPE position encoding, optional weight tying
|
||||||
|
|
||||||
#### 3.2 Submodules (`module.py`)
|
#### 3.2 Submodules (`module.py`)
|
||||||
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
- **`DecoderBlock`**: Pre-LN (norm→attention→residual, norm→MLP→residual), uses `AttnFactory` / `FFNFactory`
|
||||||
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
- **`GQA`**: Grouped Query Attention (q_heads ÷ kv_heads = n_rep)
|
||||||
|
- **`MLA`**: Multi-head Latent Attention with KV compression (kv_lora_rank)
|
||||||
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
||||||
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
|
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
|
||||||
- **`RMSNorm`**: Layer normalization
|
- **`RMSNorm`**: Layer normalization
|
||||||
|
|
@ -107,10 +109,12 @@ on_train_begin
|
||||||
for each accumulation window of batches: ← step phase
|
for each accumulation window of batches: ← step phase
|
||||||
on_step_begin
|
on_step_begin
|
||||||
for each batch in window: ← batch phase
|
for each batch in window: ← batch phase
|
||||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
on_batch_begin → strategy(batch) → loss
|
||||||
|
(loss / window_size).backward() → on_batch_end
|
||||||
iteration += 1
|
iteration += 1
|
||||||
on_step_end
|
on_step_end
|
||||||
optimizer.step() → zero_grad
|
optimizer.step() → zero_grad
|
||||||
|
scheduler.step() ← per step, not per batch
|
||||||
|
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
|
|
@ -120,7 +124,7 @@ Key points:
|
||||||
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
|
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
|
||||||
- `on_batch_*` fires every batch, wrapping loss computation
|
- `on_batch_*` fires every batch, wrapping loss computation
|
||||||
- `GradientClippingCallback` fires on `on_step_end`
|
- `GradientClippingCallback` fires on `on_step_end`
|
||||||
- LR scheduler steps inline (no `SchedulerCallback` class)
|
- LR scheduler steps inline (no `SchedulerCallback` class), once per optimizer step
|
||||||
|
|
||||||
#### 4.3 Strategy (`strategy.py`)
|
#### 4.3 Strategy (`strategy.py`)
|
||||||
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
||||||
|
|
@ -167,7 +171,7 @@ Background thread runs continuously:
|
||||||
|
|
||||||
### 6. Tokenizer Module
|
### 6. Tokenizer Module
|
||||||
|
|
||||||
- **`AutoTokenizer`**: Wraps HuggingFace tokenizers (BBPE); `encode`/`decode`/`apply_chat_template`
|
- **`AutoTokenizer`**: Wraps HuggingFace `tokenizers.Tokenizer` (not `transformers`); `encode`/`decode`/`apply_chat_template`
|
||||||
- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat
|
- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat
|
||||||
|
|
||||||
### 7. Factory & Parallel
|
### 7. Factory & Parallel
|
||||||
|
|
@ -195,14 +199,14 @@ Background thread runs continuously:
|
||||||
- Computes task-specific loss (cross-entropy, DPO, GRPO)
|
- Computes task-specific loss (cross-entropy, DPO, GRPO)
|
||||||
|
|
||||||
5. **Backward & Accumulation**
|
5. **Backward & Accumulation**
|
||||||
- `loss = raw_loss / accumulation_steps`
|
- `stand_loss = loss / step_batch_nums` (divide by actual batch count in this window)
|
||||||
- `loss.backward()` accumulates gradients
|
- `stand_loss.backward()` accumulates gradients
|
||||||
- Every `accumulation_steps` batches: `optimizer.step()` → `zero_grad()`
|
- Every `accumulation_steps` batches: `optimizer.step()` → `zero_grad()`
|
||||||
- Every batch: `scheduler.step()` updates learning rate
|
- Every optimizer step: `scheduler.step()` updates learning rate
|
||||||
|
|
||||||
6. **Checkpoint**
|
6. **Checkpoint**
|
||||||
- `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations
|
- `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations
|
||||||
- Does NOT save optimizer/scheduler state (resume resets those)
|
- Does NOT save optimizer/scheduler state by default; `Checkpoint.extra` or `save_extra_fn` can store arbitrary additional data
|
||||||
|
|
||||||
## Inference Data Flow — Detailed Steps
|
## Inference Data Flow — Detailed Steps
|
||||||
|
|
||||||
|
|
@ -230,8 +234,8 @@ Background thread runs continuously:
|
||||||
|
|
||||||
## Checkpoint & Serialization
|
## Checkpoint & Serialization
|
||||||
|
|
||||||
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata. Optimizer/scheduler state is NOT persisted.
|
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata + optional extras. Optimizer/scheduler state is NOT persisted by default but can be stored via `extra`.
|
||||||
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
|
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
|
||||||
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
|
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
|
||||||
|
|
||||||
> Document Update Time: 2026-05-14
|
> Document Update Time: 2026-05-15
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,12 @@ classDiagram
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+bool use_qk_norm
|
+bool use_qk_norm
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
|
+str attn_type
|
||||||
|
+str ffn_type
|
||||||
|
+int n_routed_experts
|
||||||
|
+int n_shared_experts
|
||||||
|
+int n_activated_experts
|
||||||
|
+str moe_topk_method
|
||||||
+load(config_path) ModelConfig
|
+load(config_path) ModelConfig
|
||||||
+save(config_path)
|
+save(config_path)
|
||||||
}
|
}
|
||||||
|
|
@ -42,7 +48,7 @@ classDiagram
|
||||||
+int ckpt_interval
|
+int ckpt_interval
|
||||||
+int random_seed
|
+int random_seed
|
||||||
+int num_workers
|
+int num_workers
|
||||||
+int prefetch_factor
|
+Optional[int] prefetch_factor
|
||||||
+bool pin_memory
|
+bool pin_memory
|
||||||
+int nprocs
|
+int nprocs
|
||||||
+str backend
|
+str backend
|
||||||
|
|
@ -118,8 +124,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
class ResumableDistributedSampler {
|
||||||
+int epoch
|
+int start_epoch
|
||||||
+int iter
|
+int start_iter
|
||||||
}
|
}
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
|
|
@ -135,6 +141,7 @@ classDiagram
|
||||||
+dict state_dict
|
+dict state_dict
|
||||||
+int epoch
|
+int epoch
|
||||||
+int iteration
|
+int iteration
|
||||||
|
+dict extra
|
||||||
+save(save_dir)
|
+save(save_dir)
|
||||||
+load(save_dir) Checkpoint
|
+load(save_dir) Checkpoint
|
||||||
}
|
}
|
||||||
|
|
@ -158,15 +165,15 @@ classDiagram
|
||||||
+ModuleList layers
|
+ModuleList layers
|
||||||
+RMSNorm norm
|
+RMSNorm norm
|
||||||
+Linear lm_head
|
+Linear lm_head
|
||||||
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor
|
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
|
||||||
+load_state_dict(state_dict)
|
+load_state_dict(state_dict)
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
class DecoderBlock {
|
class DecoderBlock {
|
||||||
+GQA attention
|
+nn.Module attention # GQA or MLA via AttnFactory
|
||||||
+RMSNorm input_norm
|
+RMSNorm input_norm
|
||||||
+MLP mlp
|
+nn.Module mlp # MLP or DeepSeekMoE via FFNFactory
|
||||||
+RMSNorm post_attention_norm
|
+RMSNorm post_attention_norm
|
||||||
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
|
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
|
||||||
}
|
}
|
||||||
|
|
@ -175,8 +182,12 @@ classDiagram
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
|
+int n_rep
|
||||||
|
+bool use_qk_norm
|
||||||
|
+bool use_gated_attention
|
||||||
+Linear q_proj, k_proj, v_proj, o_proj
|
+Linear q_proj, k_proj, v_proj, o_proj
|
||||||
+RMSNorm q_norm, k_norm
|
+Linear gate # only if use_gated_attention
|
||||||
|
+RMSNorm q_norm, k_norm # only if use_qk_norm
|
||||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -187,8 +198,11 @@ classDiagram
|
||||||
+int kv_lora_rank
|
+int kv_lora_rank
|
||||||
+int qk_nope_head_dim
|
+int qk_nope_head_dim
|
||||||
+int qk_rope_head_dim
|
+int qk_rope_head_dim
|
||||||
|
+int n_rep
|
||||||
|
+bool use_gated_attention
|
||||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||||
+Linear o_proj
|
+Linear o_proj
|
||||||
|
+Linear gate # only if use_gated_attention
|
||||||
+RMSNorm kv_norm
|
+RMSNorm kv_norm
|
||||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||||
}
|
}
|
||||||
|
|
@ -198,6 +212,25 @@ classDiagram
|
||||||
+forward(x) Tensor
|
+forward(x) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class DeepSeekMoE {
|
||||||
|
+int n_routed_experts
|
||||||
|
+int n_shared_experts
|
||||||
|
+int n_activated_experts
|
||||||
|
+str topk_method
|
||||||
|
+Linear router
|
||||||
|
+ModuleList shared_experts
|
||||||
|
+ModuleList routed_experts
|
||||||
|
+forward(x) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class AttnFactory {
|
||||||
|
+create(attn_type, **kwargs) nn.Module
|
||||||
|
}
|
||||||
|
|
||||||
|
class FFNFactory {
|
||||||
|
+create(ffn_type, dim, dim_ffn, **kwargs) nn.Module
|
||||||
|
}
|
||||||
|
|
||||||
class RMSNorm {
|
class RMSNorm {
|
||||||
+Parameter weight
|
+Parameter weight
|
||||||
+float norm_eps
|
+float norm_eps
|
||||||
|
|
@ -206,7 +239,7 @@ classDiagram
|
||||||
|
|
||||||
class Linear {
|
class Linear {
|
||||||
+Parameter weight
|
+Parameter weight
|
||||||
+Parameter bias
|
+Optional[Parameter] bias # only if bias=True
|
||||||
+forward(x) Tensor
|
+forward(x) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -365,7 +398,7 @@ classDiagram
|
||||||
|
|
||||||
class GradientClippingCallback {
|
class GradientClippingCallback {
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
+on_step_begin(context)
|
+on_step_end(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
|
|
@ -410,15 +443,24 @@ classDiagram
|
||||||
+shutdown()
|
+shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
class InferenceScheduler {
|
class Executor {
|
||||||
+nn.Module model
|
+AutoModel model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
|
+KVCache page_cache
|
||||||
|
+execute_prefill(tasks, prompt_len, start_pos)
|
||||||
|
+execute_decode(tasks) List[int]
|
||||||
|
}
|
||||||
|
|
||||||
|
class InferenceScheduler {
|
||||||
+KVCache _page_cache
|
+KVCache _page_cache
|
||||||
|
+Executor _executor
|
||||||
|
+TaskManager _task_mgr
|
||||||
|
+bool _running
|
||||||
|
+Thread _loop_thread
|
||||||
+int max_batch_size
|
+int max_batch_size
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+int max_prompt_len
|
+int max_prompt_len
|
||||||
+int page_size
|
+int page_size
|
||||||
+TaskManager _task_mgr
|
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
+remove_task(task_id)
|
+remove_task(task_id)
|
||||||
+start()
|
+start()
|
||||||
|
|
@ -428,8 +470,8 @@ classDiagram
|
||||||
|
|
||||||
class Allocator {
|
class Allocator {
|
||||||
+int _free_mask
|
+int _free_mask
|
||||||
+int refs_count
|
+List[int] _refs
|
||||||
+LRU _lru
|
+OrderedDict _lru
|
||||||
+alloc() int
|
+alloc() int
|
||||||
+free(idx, keep_cached)
|
+free(idx, keep_cached)
|
||||||
+inc_ref(idx)
|
+inc_ref(idx)
|
||||||
|
|
@ -564,9 +606,9 @@ classDiagram
|
||||||
+List[bool] _done
|
+List[bool] _done
|
||||||
+append(token, idx)
|
+append(token, idx)
|
||||||
+get_results() List[str]
|
+get_results() List[str]
|
||||||
+pop_all() List[str]
|
+pop_all() List[Tuple[int, str]]
|
||||||
+wait(timeout) bool
|
+wait(timeout) bool
|
||||||
+wait_completion()
|
+wait_completion(timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatMessage {
|
class ChatMessage {
|
||||||
|
|
@ -584,6 +626,65 @@ classDiagram
|
||||||
+Optional[str] stop
|
+Optional[str] stop
|
||||||
+Optional[int] n
|
+Optional[int] n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class AnthropicMessage {
|
||||||
|
+str role
|
||||||
|
+Union[str, List[Dict]] content
|
||||||
|
}
|
||||||
|
|
||||||
|
class MessagesRequest {
|
||||||
|
+List[AnthropicMessage] messages
|
||||||
|
+Optional[str] system
|
||||||
|
+float temperature
|
||||||
|
+float top_p
|
||||||
|
+int top_k
|
||||||
|
+int max_tokens
|
||||||
|
+bool stream
|
||||||
|
+Optional[List[str]] stop_sequences
|
||||||
|
}
|
||||||
|
|
||||||
|
class ProtocolHandler {
|
||||||
|
<<abstract>>
|
||||||
|
+build_prompt() str
|
||||||
|
+create_response_id() str
|
||||||
|
+format_stream_start(ctx) List[str]
|
||||||
|
+format_stream_token(ctx, token) str
|
||||||
|
+format_stream_end(ctx) List[str]
|
||||||
|
+format_non_stream_response(ctx, content) Dict
|
||||||
|
+handle() Union[StreamingResponse, Dict]
|
||||||
|
}
|
||||||
|
|
||||||
|
class OpenAIHandler {
|
||||||
|
+build_prompt() str
|
||||||
|
+create_response_id() str
|
||||||
|
}
|
||||||
|
|
||||||
|
class AnthropicHandler {
|
||||||
|
+List[str] stop_sequences
|
||||||
|
+build_prompt() str
|
||||||
|
+create_response_id() str
|
||||||
|
+on_token(ctx, token, stop_checker) Optional[str]
|
||||||
|
}
|
||||||
|
|
||||||
|
class StopChecker {
|
||||||
|
+check(text) Optional[str]
|
||||||
|
+trim(text, matched) str
|
||||||
|
}
|
||||||
|
|
||||||
|
class StreamContext {
|
||||||
|
+str resp_id
|
||||||
|
+int created
|
||||||
|
+str model
|
||||||
|
+int prompt_tokens
|
||||||
|
+int completion_tokens
|
||||||
|
+str accumulated
|
||||||
|
+Optional[str] stop_matched
|
||||||
|
}
|
||||||
|
|
||||||
|
class app {
|
||||||
|
<<singleton>>
|
||||||
|
+FastAPI app
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
|
|
@ -610,79 +711,104 @@ classDiagram
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
%% Relationships
|
%% Relationships — UML notation: <|-- generalization, *-- composition, o-- aggregation, --> association, ..> dependency
|
||||||
TrainConfig --> BaseDataset : uses
|
|
||||||
TrainConfig ..> BaseStrategy : selects
|
%% --- Generalization (inheritance) ---
|
||||||
StrategyFactory ..> BaseStrategy : creates
|
|
||||||
BaseStrategy <|-- SEQStrategy
|
BaseStrategy <|-- SEQStrategy
|
||||||
BaseStrategy <|-- SFTStrategy
|
BaseStrategy <|-- SFTStrategy
|
||||||
BaseStrategy <|-- DPOStrategy
|
BaseStrategy <|-- DPOStrategy
|
||||||
BaseStrategy <|-- GRPOStrategy
|
BaseStrategy <|-- GRPOStrategy
|
||||||
DPOStrategy --> Transformer : uses
|
|
||||||
GRPOStrategy --> Transformer : uses
|
|
||||||
Trainer --> TrainConfig : uses
|
|
||||||
Trainer --> TrainContextBuilder : uses
|
|
||||||
Trainer --> TrainCallback : manages
|
|
||||||
TrainContextBuilder --> TrainContext : creates
|
|
||||||
TrainContextBuilder --> StrategyFactory : uses
|
|
||||||
Checkpoint ..> Checkpoint : serializes
|
|
||||||
TrainContext --> Checkpoint : manages
|
|
||||||
TrainContext --> BaseStrategy : uses
|
|
||||||
TrainContext --> BaseScheduler : uses
|
|
||||||
SchedulerFactory ..> BaseScheduler : creates
|
|
||||||
BaseScheduler <|-- CosineScheduler
|
BaseScheduler <|-- CosineScheduler
|
||||||
BaseScheduler <|-- SGDRScheduler
|
BaseScheduler <|-- SGDRScheduler
|
||||||
CallbackFactory ..> TrainCallback : creates
|
|
||||||
TrainCallback <|-- GradientClippingCallback
|
TrainCallback <|-- GradientClippingCallback
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
PagePool --> Allocator : composes
|
|
||||||
PagePool --> PrefixCache : composes
|
|
||||||
KVCache --> PagePool : composes
|
|
||||||
KVCache --> Storage : composes
|
|
||||||
KVCache --> TaskTable : composes
|
|
||||||
KvcacheView --> Storage : wraps
|
|
||||||
InferenceEngine --> InferenceScheduler : uses
|
|
||||||
InferenceEngine --> GenerationRequest : uses
|
|
||||||
InferenceEngine --> GenerateResult : creates
|
|
||||||
InferenceScheduler --> Task : manages
|
|
||||||
InferenceScheduler --> TaskStatus : uses
|
|
||||||
InferenceScheduler --> KVCache : uses
|
|
||||||
InferenceScheduler --> Transformer : uses
|
|
||||||
Task --> TaskStatus : uses
|
|
||||||
InferenceEngine --> Transformer : uses
|
|
||||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
|
||||||
BaseSamplingStrategy <|-- TopKStrategy
|
|
||||||
BaseSamplingStrategy <|-- TopPStrategy
|
|
||||||
SamplingPipeline --> BaseSamplingStrategy : composes
|
|
||||||
BaseDataset <|-- SEQDataset
|
BaseDataset <|-- SEQDataset
|
||||||
BaseDataset <|-- SFTDataset
|
BaseDataset <|-- SFTDataset
|
||||||
BaseDataset <|-- DPODataset
|
BaseDataset <|-- DPODataset
|
||||||
BaseDataset <|-- GRPODataset
|
BaseDataset <|-- GRPODataset
|
||||||
DatasetFactory ..> BaseDataset : creates
|
|
||||||
BaseStorage <|-- H5Storage
|
BaseStorage <|-- H5Storage
|
||||||
BaseStorage <|-- JSONStorage
|
BaseStorage <|-- JSONStorage
|
||||||
BaseDataset --> BaseStorage : uses
|
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||||
MultiSegmentFetcher --> BaseSegmentFetcher : uses
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
AutoModel <|-- Transformer
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
AutoModel --> ModelConfig : contains
|
|
||||||
Transformer --> DecoderBlock : uses
|
|
||||||
Transformer --> RotaryEmbedding : uses
|
|
||||||
Transformer --> Embedding : uses
|
|
||||||
DecoderBlock --> GQA : uses
|
|
||||||
DecoderBlock --> MLP : uses
|
|
||||||
DecoderBlock --> RMSNorm : uses
|
|
||||||
TrainContextBuilder --> ResumableDistributedSampler : creates
|
|
||||||
ResumableDistributedSampler --> BaseDataset : samples
|
|
||||||
ParallelModel <|-- RowParallelLinear
|
ParallelModel <|-- RowParallelLinear
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
ParallelModel <|-- ColumnParallelLinear
|
||||||
AutoTokenizer --> ChatTemplate : uses
|
AutoModel <|-- Transformer
|
||||||
BaseFactory <|-- AutoModel
|
BaseFactory <|-- AutoModel
|
||||||
|
BaseFactory <|-- AttnFactory
|
||||||
|
BaseFactory <|-- FFNFactory
|
||||||
BaseFactory <|-- DatasetFactory
|
BaseFactory <|-- DatasetFactory
|
||||||
BaseFactory <|-- StrategyFactory
|
BaseFactory <|-- StrategyFactory
|
||||||
BaseFactory <|-- SchedulerFactory
|
BaseFactory <|-- SchedulerFactory
|
||||||
BaseFactory <|-- CallbackFactory
|
BaseFactory <|-- CallbackFactory
|
||||||
|
ProtocolHandler <|-- OpenAIHandler
|
||||||
|
ProtocolHandler <|-- AnthropicHandler
|
||||||
|
|
||||||
|
%% --- Composition (strong ownership, part destroyed with whole) ---
|
||||||
|
KVCache *-- PagePool
|
||||||
|
KVCache *-- Storage
|
||||||
|
KVCache *-- TaskTable
|
||||||
|
KVCache *-- Allocator
|
||||||
|
KVCache *-- PrefixCache
|
||||||
|
InferenceEngine *-- InferenceScheduler
|
||||||
|
InferenceScheduler *-- KVCache
|
||||||
|
InferenceScheduler *-- Executor
|
||||||
|
InferenceScheduler *-- TaskManager
|
||||||
|
SamplingPipeline *-- BaseSamplingStrategy
|
||||||
|
TrainContextBuilder *-- TrainContext
|
||||||
|
Transformer *-- DecoderBlock
|
||||||
|
Transformer *-- RotaryEmbedding
|
||||||
|
Transformer *-- Embedding
|
||||||
|
DecoderBlock *-- RMSNorm
|
||||||
|
BaseDataset *-- BaseStorage
|
||||||
|
|
||||||
|
%% --- Aggregation (weak ownership) ---
|
||||||
|
AutoModel o-- ModelConfig
|
||||||
|
Trainer o-- TrainCallback
|
||||||
|
TrainContext o-- BaseStrategy
|
||||||
|
TrainContext o-- BaseScheduler
|
||||||
|
TrainContext o-- Checkpoint
|
||||||
|
AutoTokenizer o-- ChatTemplate
|
||||||
|
KvcacheView o-- Storage
|
||||||
|
BaseFactory o-- Registry
|
||||||
|
|
||||||
|
%% --- Dependency (uses temporarily) ---
|
||||||
|
TrainConfig ..> BaseStrategy : selects
|
||||||
|
StrategyFactory ..> BaseStrategy : creates
|
||||||
|
SchedulerFactory ..> BaseScheduler : creates
|
||||||
|
DatasetFactory ..> BaseDataset : creates
|
||||||
|
CallbackFactory ..> TrainCallback : creates
|
||||||
|
AttnFactory ..> GQA : creates
|
||||||
|
AttnFactory ..> MLA : creates
|
||||||
|
FFNFactory ..> MLP : creates
|
||||||
|
FFNFactory ..> DeepSeekMoE : creates
|
||||||
|
DecoderBlock ..> AttnFactory : uses
|
||||||
|
DecoderBlock ..> FFNFactory : uses
|
||||||
|
Trainer ..> TrainContextBuilder : uses
|
||||||
|
Trainer ..> Functions : spawns
|
||||||
|
TrainContextBuilder ..> StrategyFactory : uses
|
||||||
|
TrainContextBuilder ..> ResumableDistributedSampler : creates
|
||||||
|
Checkpoint ..> Checkpoint : serializes
|
||||||
|
CheckpointCallback ..> Checkpoint : creates
|
||||||
|
KVCache ..> KvcacheView : binds
|
||||||
|
InferenceEngine ..> GenerationRequest : uses
|
||||||
|
InferenceEngine ..> GenerateResult : creates
|
||||||
|
|
||||||
|
%% --- Association (general usage) ---
|
||||||
|
Trainer --> TrainConfig
|
||||||
|
DPOStrategy --> Transformer
|
||||||
|
GRPOStrategy --> Transformer
|
||||||
|
InferenceScheduler --> Task
|
||||||
|
InferenceScheduler --> TaskStatus
|
||||||
|
Task --> TaskStatus
|
||||||
|
InferenceEngine --> Transformer
|
||||||
|
Executor --> Transformer
|
||||||
|
Executor --> AutoTokenizer
|
||||||
|
TaskManager --> AutoTokenizer
|
||||||
|
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||||
|
ResumableDistributedSampler --> BaseDataset
|
||||||
```
|
```
|
||||||
|
|
||||||
### Module Overview
|
### Module Overview
|
||||||
|
|
@ -690,14 +816,14 @@ classDiagram
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5, save_json, load_json, create_storage, detect_format | Dataset loading and management |
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
|
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
|
||||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample, ChatMessage, ChatCompletionRequest, AnthropicMessage, MessagesRequest, OpenAIHandler, AnthropicHandler, ProtocolHandler, StreamContext, StopChecker, app, run_server | Inference service with continuous batching and paged KV cache |
|
||||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, only_on_rank, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
| **astrai.factory** | Registry, BaseFactory[T] | Generic component registration with decorator pattern |
|
||||||
|
|
||||||
### Design Patterns
|
### Design Patterns
|
||||||
|
|
||||||
|
|
@ -706,7 +832,7 @@ classDiagram
|
||||||
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
|
||||||
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
|
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
|
||||||
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
|
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
|
||||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, gradient clipping, metrics) |
|
||||||
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
|
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
|
||||||
|
|
@ -715,6 +841,8 @@ classDiagram
|
||||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
||||||
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
|
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
|
||||||
|
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | `handle()` template with stream/non-stream branches, protocol-specific format hooks |
|
||||||
|
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage`, `_STORAGE_REGISTRY` | Format-agnostic data access with registry-dispatch (HDF5 / JSON) |
|
||||||
|
|
||||||
### Core Relationships
|
### Core Relationships
|
||||||
|
|
||||||
|
|
@ -723,7 +851,7 @@ classDiagram
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||||
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
|
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 and JSON loading via `BaseStorage` (`H5Storage` / `JSONStorage`) with `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||||
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
|
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
|
||||||
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
|
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
|
||||||
|
|
@ -776,4 +904,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
|
||||||
|
|
||||||
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
|
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
|
||||||
|
|
||||||
> Document Update Time: 2026-05-14
|
> Document Update Time: 2026-05-15
|
||||||
|
|
|
||||||
|
|
@ -224,7 +224,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) |
|
| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) |
|
||||||
| `top_p` | float | 1.0 | Nucleus sampling threshold |
|
| `top_p` | float | 1.0 | Nucleus sampling threshold |
|
||||||
| `top_k` | int | 50 | Top-k sampling parameter |
|
| `top_k` | int | 50 | Top-k sampling parameter |
|
||||||
| `max_tokens` | int | 1024 | Maximum tokens to generate |
|
| `max_tokens` | int | 2048 | Maximum tokens to generate |
|
||||||
| `stream` | bool | false | Enable streaming response |
|
| `stream` | bool | false | Enable streaming response |
|
||||||
|
|
||||||
**Response (non-streaming):**
|
**Response (non-streaming):**
|
||||||
|
|
@ -331,4 +331,4 @@ curl http://localhost:8000/stats
|
||||||
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
|
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
|
||||||
```
|
```
|
||||||
|
|
||||||
> Document Update Time: 2026-05-14
|
> Document Update Time: 2026-05-15
|
||||||
|
|
@ -60,7 +60,7 @@
|
||||||
| Parameter | Description | Default | Used by |
|
| Parameter | Description | Default | Used by |
|
||||||
|-----------|-------------|---------|---------|
|
|-----------|-------------|---------|---------|
|
||||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
|
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` |
|
||||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||||
|
|
@ -98,7 +98,7 @@ python scripts/tools/train.py \
|
||||||
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
||||||
| `top_p` | Nucleus sampling threshold | 1.0 |
|
| `top_p` | Nucleus sampling threshold | 1.0 |
|
||||||
| `top_k` | Top-k sampling count | 50 |
|
| `top_k` | Top-k sampling count | 50 |
|
||||||
| `max_tokens` | Maximum generation length | None (unlimited) |
|
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
|
||||||
| `stream` | Whether to stream output | False |
|
| `stream` | Whether to stream output | False |
|
||||||
|
|
||||||
### Usage Example
|
### Usage Example
|
||||||
|
|
@ -155,4 +155,4 @@ result = engine.generate(
|
||||||
| `stream=True` | Streaming output, yields token by token |
|
| `stream=True` | Streaming output, yields token by token |
|
||||||
| `stream=False` | Non-streaming output, returns complete result |
|
| `stream=False` | Non-streaming output, returns complete result |
|
||||||
|
|
||||||
> Document Update Time: 2026-05-14
|
> Document Update Time: 2026-05-15
|
||||||
Loading…
Reference in New Issue