docs: 修正 assets/docs/ 类图、数据流、参数文档及贡献指南

- design.md: 新增 ProtocolHandler/OpenAIHandler/AnthropicHandler 等缺失类
- design.md: 新增 Template Method、Storage 设计模式
- dataflow.md: 修正 GQA/MLA 为独立条目,补充 JSON 存储后端
- params.md: 标注 label_smoothing CLI 默认与 strategy 默认差异
- introduction.md: 修正 max_tokens 默认值 1024→2048
- CONTRIBUTING.md: 重写(纯 Python 无 conda、补充 CI 步骤与常见问题)
- .github/PULL_REQUEST_TEMPLATE.md: 修正 lint 命令,去除多余注释要求
- .github/ISSUE_TEMPLATE/bug_report.md: 修正 label(enhancement→bug)
This commit is contained in:
ViperEkura 2026-05-15 22:45:37 +08:00
parent e12f1a7ee5
commit c169659611
7 changed files with 307 additions and 143 deletions

View File

@ -2,7 +2,7 @@
name: Bug report name: Bug report
about: Create a report to help us improve about: Create a report to help us improve
title: "[BUG]" title: "[BUG]"
labels: enhancement labels: bug
assignees: '' assignees: ''
--- ---

View File

@ -16,9 +16,9 @@ Please delete options that are not relevant.
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
## Checklist: ## Checklist:
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check --fix .`) - [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
- [ ] I have performed a self-review of my own code - [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas - [ ] Code is self-documenting (no unnecessary comments)
- [ ] I have made corresponding changes to the documentation - [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings - [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added tests that prove my fix is effective or that my feature works

View File

@ -1,68 +1,100 @@
# Contributing to AstrAI # Contributing to AstrAI
Thank you for your interest in contributing to AstrAI! This document provides guidelines and steps for contributing. Thank you for your interest in contributing! This document provides step-by-step guidelines.
## How to Contribute ## Quick Start
### Reporting Issues
If you encounter a bug or have a feature request, please open an issue on GitHub. Include as much detail as possible:
- A clear description of the problem or request.
- Steps to reproduce (for bugs).
- Your environment (Python version, OS, etc.).
### Submitting Changes
1. **Fork** the repository.
2. **Clone** your fork:
```bash ```bash
git clone https://github.com/your-username/AstrAI.git git clone https://github.com/your-username/AstrAI.git
cd AstrAI cd AstrAI
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
``` ```
3. **Create a feature branch**:
## Before You Commit
Run the following checks **in order** — CI will reject if any fail.
### 1. Format
```bash ```bash
git checkout -b feature/your-feature-name ruff format .
``` ```
4. **Make your changes**. Follow the code style guidelines below.
5. **Commit your changes** with a descriptive commit message: > **Note**: `ruff format` may rename parameters (e.g. `mask``attn_mask`).
> Always review the diff after formatting.
### 2. Import sorting
```bash ```bash
git commit -m "Add: brief description of the change" ruff check . --select I
``` ```
6. **Push** to your fork:
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
```bash ```bash
git push origin feature/your-feature-name ruff check . --select I --fix .
ruff format . # re-format after fix
``` ```
7. **Open a Pull Request** (PR) against the `main` branch of the upstream repository.
## Code Style ### 3. Run tests
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
- Run Ruff to format and lint (requires conda environment `nlp`):
```bash ```bash
conda run -n nlp ruff format . python -u -m pytest tests/ -v
conda run -n nlp ruff check --fix .
``` ```
- The project uses **double quotes** for strings and **4space indentation** (as configured in `pyproject.toml`).
## Testing > Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
If you add or modify functionality, please include appropriate tests. ### 4. (Optional) Full pre-commit check
If you have Git Bash available:
- Run the test suite with:
```bash ```bash
conda run -n nlp python -u -m pytest bash scripts/pre_commit.sh
``` ```
- Ensure all tests pass before submitting your PR.
This runs format check, import sort check, and tests in one go.
## Commit Style
```
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
- bullet point body (each ~60 chars)
```
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
- **Subject line** ends with no period.
- **Body** uses bullet points starting with `-`.
- No `(scope)` parentheses.
## Common Issues
| Problem | Cause | Fix |
|---------|-------|-----|
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
## Submitting Changes
1. Fork the repo.
2. Create a feature branch: `git checkout -b feat/my-feature`
3. Make changes following the steps above.
4. Commit with the commit style above.
5. Push: `git push origin feat/my-feature`
6. Open a Pull Request against `main`.
## Code Review ## Code Review
All submissions will be reviewed. We may request changes or discuss alternatives. Please be responsive to feedback. - All PRs are reviewed. We may request changes.
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
- Ensure all tests pass.
## License ## License
By contributing, you agree that your contributions will be licensed under the same [GPL-3.0 License](LICENSE) that covers the project. By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
--- ---
If you have any questions, feel free to ask in the [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue. Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
Happy contributing!

View File

@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe
## Overview ## Overview
AstrAI adopts a modular design with the following main components: AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, storage backends
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules - **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation - **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
@ -71,7 +71,8 @@ flowchart LR
- **`BaseDataset`**: Abstract base class for windowed sequence sampling - **`BaseDataset`**: Abstract base class for windowed sequence sampling
- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range - **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range
- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`) - **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen_mask"/"rejected_mask"` (DPO), `"masks"` (GRPO) - Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen"/"rejected"` (DPO), `"prompts"/"responses"/"masks"/"rewards"` (GRPO)
- Storage backends: HDF5 (`.h5`) or JSON (`.json`/`.jsonl`), auto-detected by `detect_format()`
#### 2.2 Sampler (`sampler.py`) #### 2.2 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last - **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last
@ -85,8 +86,9 @@ flowchart LR
- RoPE position encoding, optional weight tying - RoPE position encoding, optional weight tying
#### 3.2 Submodules (`module.py`) #### 3.2 Submodules (`module.py`)
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm - **`DecoderBlock`**: Pre-LN (norm→attention→residual, norm→MLP→residual), uses `AttnFactory` / `FFNFactory`
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention) - **`GQA`**: Grouped Query Attention (q_heads ÷ kv_heads = n_rep)
- **`MLA`**: Multi-head Latent Attention with KV compression (kv_lora_rank)
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection - **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis) - **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
- **`RMSNorm`**: Layer normalization - **`RMSNorm`**: Layer normalization
@ -107,10 +109,12 @@ on_train_begin
for each accumulation window of batches: ← step phase for each accumulation window of batches: ← step phase
on_step_begin on_step_begin
for each batch in window: ← batch phase for each batch in window: ← batch phase
on_batch_begin → strategy(batch) → loss → backward → on_batch_end on_batch_begin → strategy(batch) → loss
(loss / window_size).backward() → on_batch_end
iteration += 1 iteration += 1
on_step_end on_step_end
optimizer.step() → zero_grad optimizer.step() → zero_grad
scheduler.step() ← per step, not per batch
on_epoch_end on_epoch_end
on_train_end on_train_end
@ -120,7 +124,7 @@ Key points:
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook - `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
- `on_batch_*` fires every batch, wrapping loss computation - `on_batch_*` fires every batch, wrapping loss computation
- `GradientClippingCallback` fires on `on_step_end` - `GradientClippingCallback` fires on `on_step_end`
- LR scheduler steps inline (no `SchedulerCallback` class) - LR scheduler steps inline (no `SchedulerCallback` class), once per optimizer step
#### 4.3 Strategy (`strategy.py`) #### 4.3 Strategy (`strategy.py`)
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing - **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
@ -167,7 +171,7 @@ Background thread runs continuously:
### 6. Tokenizer Module ### 6. Tokenizer Module
- **`AutoTokenizer`**: Wraps HuggingFace tokenizers (BBPE); `encode`/`decode`/`apply_chat_template` - **`AutoTokenizer`**: Wraps HuggingFace `tokenizers.Tokenizer` (not `transformers`); `encode`/`decode`/`apply_chat_template`
- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat - **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat
### 7. Factory & Parallel ### 7. Factory & Parallel
@ -195,14 +199,14 @@ Background thread runs continuously:
- Computes task-specific loss (cross-entropy, DPO, GRPO) - Computes task-specific loss (cross-entropy, DPO, GRPO)
5. **Backward & Accumulation** 5. **Backward & Accumulation**
- `loss = raw_loss / accumulation_steps` - `stand_loss = loss / step_batch_nums` (divide by actual batch count in this window)
- `loss.backward()` accumulates gradients - `stand_loss.backward()` accumulates gradients
- Every `accumulation_steps` batches: `optimizer.step()``zero_grad()` - Every `accumulation_steps` batches: `optimizer.step()``zero_grad()`
- Every batch: `scheduler.step()` updates learning rate - Every optimizer step: `scheduler.step()` updates learning rate
6. **Checkpoint** 6. **Checkpoint**
- `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations - `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations
- Does NOT save optimizer/scheduler state (resume resets those) - Does NOT save optimizer/scheduler state by default; `Checkpoint.extra` or `save_extra_fn` can store arbitrary additional data
## Inference Data Flow — Detailed Steps ## Inference Data Flow — Detailed Steps
@ -230,8 +234,8 @@ Background thread runs continuously:
## Checkpoint & Serialization ## Checkpoint & Serialization
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata. Optimizer/scheduler state is NOT persisted. - **Training Checkpoint**: safetensors weights + epoch/iteration metadata + optional extras. Optimizer/scheduler state is NOT persisted by default but can be stored via `extra`.
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format. - **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data. - **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
> Document Update Time: 2026-05-14 > Document Update Time: 2026-05-15

View File

@ -22,6 +22,12 @@ classDiagram
+int n_kv_heads +int n_kv_heads
+bool use_qk_norm +bool use_qk_norm
+bool use_gated_attention +bool use_gated_attention
+str attn_type
+str ffn_type
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+str moe_topk_method
+load(config_path) ModelConfig +load(config_path) ModelConfig
+save(config_path) +save(config_path)
} }
@ -42,7 +48,7 @@ classDiagram
+int ckpt_interval +int ckpt_interval
+int random_seed +int random_seed
+int num_workers +int num_workers
+int prefetch_factor +Optional[int] prefetch_factor
+bool pin_memory +bool pin_memory
+int nprocs +int nprocs
+str backend +str backend
@ -118,8 +124,8 @@ classDiagram
} }
class ResumableDistributedSampler { class ResumableDistributedSampler {
+int epoch +int start_epoch
+int iter +int start_iter
} }
class DatasetFactory { class DatasetFactory {
@ -135,6 +141,7 @@ classDiagram
+dict state_dict +dict state_dict
+int epoch +int epoch
+int iteration +int iteration
+dict extra
+save(save_dir) +save(save_dir)
+load(save_dir) Checkpoint +load(save_dir) Checkpoint
} }
@ -158,15 +165,15 @@ classDiagram
+ModuleList layers +ModuleList layers
+RMSNorm norm +RMSNorm norm
+Linear lm_head +Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor +forward(input_ids, input_mask, paged_cache, position_ids) Dict
+load_state_dict(state_dict) +load_state_dict(state_dict)
+state_dict() +state_dict()
} }
class DecoderBlock { class DecoderBlock {
+GQA attention +nn.Module attention # GQA or MLA via AttnFactory
+RMSNorm input_norm +RMSNorm input_norm
+MLP mlp +nn.Module mlp # MLP or DeepSeekMoE via FFNFactory
+RMSNorm post_attention_norm +RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor +forward(x, rotary_emb, attention_mask, paged_cache) Tensor
} }
@ -175,8 +182,12 @@ classDiagram
+int n_heads +int n_heads
+int n_kv_heads +int n_kv_heads
+int head_dim +int head_dim
+int n_rep
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, k_proj, v_proj, o_proj +Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm +Linear gate # only if use_gated_attention
+RMSNorm q_norm, k_norm # only if use_qk_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor +forward(x, rotary_emb, attn_mask, paged_cache) Tensor
} }
@ -187,8 +198,11 @@ classDiagram
+int kv_lora_rank +int kv_lora_rank
+int qk_nope_head_dim +int qk_nope_head_dim
+int qk_rope_head_dim +int qk_rope_head_dim
+int n_rep
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj +Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj +Linear o_proj
+Linear gate # only if use_gated_attention
+RMSNorm kv_norm +RMSNorm kv_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor +forward(x, rotary_emb, attn_mask, paged_cache) Tensor
} }
@ -198,6 +212,25 @@ classDiagram
+forward(x) Tensor +forward(x) Tensor
} }
class DeepSeekMoE {
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+str topk_method
+Linear router
+ModuleList shared_experts
+ModuleList routed_experts
+forward(x) Tensor
}
class AttnFactory {
+create(attn_type, **kwargs) nn.Module
}
class FFNFactory {
+create(ffn_type, dim, dim_ffn, **kwargs) nn.Module
}
class RMSNorm { class RMSNorm {
+Parameter weight +Parameter weight
+float norm_eps +float norm_eps
@ -206,7 +239,7 @@ classDiagram
class Linear { class Linear {
+Parameter weight +Parameter weight
+Parameter bias +Optional[Parameter] bias # only if bias=True
+forward(x) Tensor +forward(x) Tensor
} }
@ -365,7 +398,7 @@ classDiagram
class GradientClippingCallback { class GradientClippingCallback {
+float max_grad_norm +float max_grad_norm
+on_step_begin(context) +on_step_end(context)
} }
class CheckpointCallback { class CheckpointCallback {
@ -410,15 +443,24 @@ classDiagram
+shutdown() +shutdown()
} }
class InferenceScheduler { class Executor {
+nn.Module model +AutoModel model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+KVCache page_cache
+execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int]
}
class InferenceScheduler {
+KVCache _page_cache +KVCache _page_cache
+Executor _executor
+TaskManager _task_mgr
+bool _running
+Thread _loop_thread
+int max_batch_size +int max_batch_size
+int max_seq_len +int max_seq_len
+int max_prompt_len +int max_prompt_len
+int page_size +int page_size
+TaskManager _task_mgr
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id) +remove_task(task_id)
+start() +start()
@ -428,8 +470,8 @@ classDiagram
class Allocator { class Allocator {
+int _free_mask +int _free_mask
+int refs_count +List[int] _refs
+LRU _lru +OrderedDict _lru
+alloc() int +alloc() int
+free(idx, keep_cached) +free(idx, keep_cached)
+inc_ref(idx) +inc_ref(idx)
@ -564,9 +606,9 @@ classDiagram
+List[bool] _done +List[bool] _done
+append(token, idx) +append(token, idx)
+get_results() List[str] +get_results() List[str]
+pop_all() List[str] +pop_all() List[Tuple[int, str]]
+wait(timeout) bool +wait(timeout) bool
+wait_completion() +wait_completion(timeout)
} }
class ChatMessage { class ChatMessage {
@ -584,6 +626,65 @@ classDiagram
+Optional[str] stop +Optional[str] stop
+Optional[int] n +Optional[int] n
} }
class AnthropicMessage {
+str role
+Union[str, List[Dict]] content
}
class MessagesRequest {
+List[AnthropicMessage] messages
+Optional[str] system
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[List[str]] stop_sequences
}
class ProtocolHandler {
<<abstract>>
+build_prompt() str
+create_response_id() str
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+handle() Union[StreamingResponse, Dict]
}
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+List[str] stop_sequences
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
}
class StopChecker {
+check(text) Optional[str]
+trim(text, matched) str
}
class StreamContext {
+str resp_id
+int created
+str model
+int prompt_tokens
+int completion_tokens
+str accumulated
+Optional[str] stop_matched
}
class app {
<<singleton>>
+FastAPI app
}
} }
namespace parallel { namespace parallel {
@ -610,79 +711,104 @@ classDiagram
} }
} }
%% Relationships %% Relationships — UML notation: <|-- generalization, *-- composition, o-- aggregation, --> association, ..> dependency
TrainConfig --> BaseDataset : uses
TrainConfig ..> BaseStrategy : selects %% --- Generalization (inheritance) ---
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : uses
Trainer --> TrainContextBuilder : uses
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
TrainContextBuilder --> StrategyFactory : uses
Checkpoint ..> Checkpoint : serializes
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
PagePool --> Allocator : composes
PagePool --> PrefixCache : composes
KVCache --> PagePool : composes
KVCache --> Storage : composes
KVCache --> TaskTable : composes
KvcacheView --> Storage : wraps
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
InferenceEngine --> GenerateResult : creates
InferenceScheduler --> Task : manages
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> KVCache : uses
InferenceScheduler --> Transformer : uses
Task --> TaskStatus : uses
InferenceEngine --> Transformer : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
BaseDataset <|-- SEQDataset BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates
BaseStorage <|-- H5Storage BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage BaseStorage <|-- JSONStorage
BaseDataset --> BaseStorage : uses BaseSamplingStrategy <|-- TemperatureStrategy
MultiSegmentFetcher --> BaseSegmentFetcher : uses BaseSamplingStrategy <|-- TopKStrategy
AutoModel <|-- Transformer BaseSamplingStrategy <|-- TopPStrategy
AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses
Transformer --> RotaryEmbedding : uses
Transformer --> Embedding : uses
DecoderBlock --> GQA : uses
DecoderBlock --> MLP : uses
DecoderBlock --> RMSNorm : uses
TrainContextBuilder --> ResumableDistributedSampler : creates
ResumableDistributedSampler --> BaseDataset : samples
ParallelModel <|-- RowParallelLinear ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses AutoModel <|-- Transformer
BaseFactory <|-- AutoModel BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
BaseFactory <|-- FFNFactory
BaseFactory <|-- DatasetFactory BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory BaseFactory <|-- CallbackFactory
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
%% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
KVCache *-- Allocator
KVCache *-- PrefixCache
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
TrainContextBuilder *-- TrainContext
Transformer *-- DecoderBlock
Transformer *-- RotaryEmbedding
Transformer *-- Embedding
DecoderBlock *-- RMSNorm
BaseDataset *-- BaseStorage
%% --- Aggregation (weak ownership) ---
AutoModel o-- ModelConfig
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint
AutoTokenizer o-- ChatTemplate
KvcacheView o-- Storage
BaseFactory o-- Registry
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
SchedulerFactory ..> BaseScheduler : creates
DatasetFactory ..> BaseDataset : creates
CallbackFactory ..> TrainCallback : creates
AttnFactory ..> GQA : creates
AttnFactory ..> MLA : creates
FFNFactory ..> MLP : creates
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
Trainer ..> TrainContextBuilder : uses
Trainer ..> Functions : spawns
TrainContextBuilder ..> StrategyFactory : uses
TrainContextBuilder ..> ResumableDistributedSampler : creates
Checkpoint ..> Checkpoint : serializes
CheckpointCallback ..> Checkpoint : creates
KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates
%% --- Association (general usage) ---
Trainer --> TrainConfig
DPOStrategy --> Transformer
GRPOStrategy --> Transformer
InferenceScheduler --> Task
InferenceScheduler --> TaskStatus
Task --> TaskStatus
InferenceEngine --> Transformer
Executor --> Transformer
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
``` ```
### Module Overview ### Module Overview
@ -690,14 +816,14 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management | | **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management | | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5, save_json, load_json, create_storage, detect_format | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management | | **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample, ChatMessage, ChatCompletionRequest, AnthropicMessage, MessagesRequest, OpenAIHandler, AnthropicHandler, ProtocolHandler, StreamContext, StopChecker, app, run_server | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, only_on_rank, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration | | **astrai.factory** | Registry, BaseFactory[T] | Generic component registration with decorator pattern |
### Design Patterns ### Design Patterns
@ -706,7 +832,7 @@ classDiagram
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO | | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components | | **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks | | **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, gradient clipping, metrics) |
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint | | **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
@ -715,6 +841,8 @@ classDiagram
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation | | **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | `handle()` template with stream/non-stream branches, protocol-specific format hooks |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage`, `_STORAGE_REGISTRY` | Format-agnostic data access with registry-dispatch (HDF5 / JSON) |
### Core Relationships ### Core Relationships
@ -723,7 +851,7 @@ classDiagram
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming 4. **Inference Flow**: `InferenceEngine``InferenceScheduler``Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer` 5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 and JSON loading via `BaseStorage` (`H5Storage` / `JSONStorage`) with `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler) 8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration 9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
@ -776,4 +904,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence. Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-05-14 > Document Update Time: 2026-05-15

View File

@ -224,7 +224,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) | | `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) |
| `top_p` | float | 1.0 | Nucleus sampling threshold | | `top_p` | float | 1.0 | Nucleus sampling threshold |
| `top_k` | int | 50 | Top-k sampling parameter | | `top_k` | int | 50 | Top-k sampling parameter |
| `max_tokens` | int | 1024 | Maximum tokens to generate | | `max_tokens` | int | 2048 | Maximum tokens to generate |
| `stream` | bool | false | Enable streaming response | | `stream` | bool | false | Enable streaming response |
**Response (non-streaming):** **Response (non-streaming):**
@ -331,4 +331,4 @@ curl http://localhost:8000/stats
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0} # {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
``` ```
> Document Update Time: 2026-05-14 > Document Update Time: 2026-05-15

View File

@ -60,7 +60,7 @@
| Parameter | Description | Default | Used by | | Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------| |-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` | | `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` | | `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` |
| `--group_size` | GRPO group size | 4 | `grpo` | | `--group_size` | GRPO group size | 4 | `grpo` |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` | | `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` | | `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
@ -98,7 +98,7 @@ python scripts/tools/train.py \
| `temperature` | Sampling temperature (higher = more random) | 1.0 | | `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 | | `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 | | `top_k` | Top-k sampling count | 50 |
| `max_tokens` | Maximum generation length | None (unlimited) | | `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
| `stream` | Whether to stream output | False | | `stream` | Whether to stream output | False |
### Usage Example ### Usage Example
@ -155,4 +155,4 @@ result = engine.generate(
| `stream=True` | Streaming output, yields token by token | | `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result | | `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-05-14 > Document Update Time: 2026-05-15