docs: 修正 assets/docs/ 类图、数据流、参数文档及贡献指南

- design.md: 新增 ProtocolHandler/OpenAIHandler/AnthropicHandler 等缺失类
- design.md: 新增 Template Method、Storage 设计模式
- dataflow.md: 修正 GQA/MLA 为独立条目,补充 JSON 存储后端
- params.md: 标注 label_smoothing CLI 默认与 strategy 默认差异
- introduction.md: 修正 max_tokens 默认值 1024→2048
- CONTRIBUTING.md: 重写(纯 Python 无 conda、补充 CI 步骤与常见问题)
- .github/PULL_REQUEST_TEMPLATE.md: 修正 lint 命令,去除多余注释要求
- .github/ISSUE_TEMPLATE/bug_report.md: 修正 label(enhancement→bug)
This commit is contained in:
ViperEkura 2026-05-15 22:45:37 +08:00
parent e12f1a7ee5
commit c169659611
7 changed files with 307 additions and 143 deletions

View File

@ -2,7 +2,7 @@
name: Bug report
about: Create a report to help us improve
title: "[BUG]"
labels: enhancement
labels: bug
assignees: ''
---

View File

@ -16,9 +16,9 @@ Please delete options that are not relevant.
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
## Checklist:
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check --fix .`)
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] Code is self-documenting (no unnecessary comments)
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works

View File

@ -1,68 +1,100 @@
# Contributing to AstrAI
Thank you for your interest in contributing to AstrAI! This document provides guidelines and steps for contributing.
Thank you for your interest in contributing! This document provides step-by-step guidelines.
## How to Contribute
## Quick Start
### Reporting Issues
If you encounter a bug or have a feature request, please open an issue on GitHub. Include as much detail as possible:
- A clear description of the problem or request.
- Steps to reproduce (for bugs).
- Your environment (Python version, OS, etc.).
### Submitting Changes
1. **Fork** the repository.
2. **Clone** your fork:
```bash
git clone https://github.com/your-username/AstrAI.git
cd AstrAI
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
```
3. **Create a feature branch**:
## Before You Commit
Run the following checks **in order** — CI will reject if any fail.
### 1. Format
```bash
git checkout -b feature/your-feature-name
ruff format .
```
4. **Make your changes**. Follow the code style guidelines below.
5. **Commit your changes** with a descriptive commit message:
> **Note**: `ruff format` may rename parameters (e.g. `mask``attn_mask`).
> Always review the diff after formatting.
### 2. Import sorting
```bash
git commit -m "Add: brief description of the change"
ruff check . --select I
```
6. **Push** to your fork:
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
```bash
git push origin feature/your-feature-name
ruff check . --select I --fix .
ruff format . # re-format after fix
```
7. **Open a Pull Request** (PR) against the `main` branch of the upstream repository.
## Code Style
### 3. Run tests
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
- Run Ruff to format and lint (requires conda environment `nlp`):
```bash
conda run -n nlp ruff format .
conda run -n nlp ruff check --fix .
python -u -m pytest tests/ -v
```
- The project uses **double quotes** for strings and **4space indentation** (as configured in `pyproject.toml`).
## Testing
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
If you add or modify functionality, please include appropriate tests.
### 4. (Optional) Full pre-commit check
If you have Git Bash available:
- Run the test suite with:
```bash
conda run -n nlp python -u -m pytest
bash scripts/pre_commit.sh
```
- Ensure all tests pass before submitting your PR.
This runs format check, import sort check, and tests in one go.
## Commit Style
```
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
- bullet point body (each ~60 chars)
```
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
- **Subject line** ends with no period.
- **Body** uses bullet points starting with `-`.
- No `(scope)` parentheses.
## Common Issues
| Problem | Cause | Fix |
|---------|-------|-----|
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
## Submitting Changes
1. Fork the repo.
2. Create a feature branch: `git checkout -b feat/my-feature`
3. Make changes following the steps above.
4. Commit with the commit style above.
5. Push: `git push origin feat/my-feature`
6. Open a Pull Request against `main`.
## Code Review
All submissions will be reviewed. We may request changes or discuss alternatives. Please be responsive to feedback.
- All PRs are reviewed. We may request changes.
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
- Ensure all tests pass.
## License
By contributing, you agree that your contributions will be licensed under the same [GPL-3.0 License](LICENSE) that covers the project.
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
---
If you have any questions, feel free to ask in the [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
Happy contributing!
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.

View File

@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe
## Overview
AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, storage backends
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
@ -71,7 +71,8 @@ flowchart LR
- **`BaseDataset`**: Abstract base class for windowed sequence sampling
- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range
- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen_mask"/"rejected_mask"` (DPO), `"masks"` (GRPO)
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen"/"rejected"` (DPO), `"prompts"/"responses"/"masks"/"rewards"` (GRPO)
- Storage backends: HDF5 (`.h5`) or JSON (`.json`/`.jsonl`), auto-detected by `detect_format()`
#### 2.2 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last
@ -85,8 +86,9 @@ flowchart LR
- RoPE position encoding, optional weight tying
#### 3.2 Submodules (`module.py`)
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
- **`DecoderBlock`**: Pre-LN (norm→attention→residual, norm→MLP→residual), uses `AttnFactory` / `FFNFactory`
- **`GQA`**: Grouped Query Attention (q_heads ÷ kv_heads = n_rep)
- **`MLA`**: Multi-head Latent Attention with KV compression (kv_lora_rank)
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
- **`RMSNorm`**: Layer normalization
@ -107,10 +109,12 @@ on_train_begin
for each accumulation window of batches: ← step phase
on_step_begin
for each batch in window: ← batch phase
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
on_batch_begin → strategy(batch) → loss
(loss / window_size).backward() → on_batch_end
iteration += 1
on_step_end
optimizer.step() → zero_grad
scheduler.step() ← per step, not per batch
on_epoch_end
on_train_end
@ -120,7 +124,7 @@ Key points:
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
- `on_batch_*` fires every batch, wrapping loss computation
- `GradientClippingCallback` fires on `on_step_end`
- LR scheduler steps inline (no `SchedulerCallback` class)
- LR scheduler steps inline (no `SchedulerCallback` class), once per optimizer step
#### 4.3 Strategy (`strategy.py`)
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
@ -167,7 +171,7 @@ Background thread runs continuously:
### 6. Tokenizer Module
- **`AutoTokenizer`**: Wraps HuggingFace tokenizers (BBPE); `encode`/`decode`/`apply_chat_template`
- **`AutoTokenizer`**: Wraps HuggingFace `tokenizers.Tokenizer` (not `transformers`); `encode`/`decode`/`apply_chat_template`
- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat
### 7. Factory & Parallel
@ -195,14 +199,14 @@ Background thread runs continuously:
- Computes task-specific loss (cross-entropy, DPO, GRPO)
5. **Backward & Accumulation**
- `loss = raw_loss / accumulation_steps`
- `loss.backward()` accumulates gradients
- `stand_loss = loss / step_batch_nums` (divide by actual batch count in this window)
- `stand_loss.backward()` accumulates gradients
- Every `accumulation_steps` batches: `optimizer.step()``zero_grad()`
- Every batch: `scheduler.step()` updates learning rate
- Every optimizer step: `scheduler.step()` updates learning rate
6. **Checkpoint**
- `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations
- Does NOT save optimizer/scheduler state (resume resets those)
- Does NOT save optimizer/scheduler state by default; `Checkpoint.extra` or `save_extra_fn` can store arbitrary additional data
## Inference Data Flow — Detailed Steps
@ -230,8 +234,8 @@ Background thread runs continuously:
## Checkpoint & Serialization
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata. Optimizer/scheduler state is NOT persisted.
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata + optional extras. Optimizer/scheduler state is NOT persisted by default but can be stored via `extra`.
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
> Document Update Time: 2026-05-14
> Document Update Time: 2026-05-15

View File

@ -22,6 +22,12 @@ classDiagram
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+str attn_type
+str ffn_type
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+str moe_topk_method
+load(config_path) ModelConfig
+save(config_path)
}
@ -42,7 +48,7 @@ classDiagram
+int ckpt_interval
+int random_seed
+int num_workers
+int prefetch_factor
+Optional[int] prefetch_factor
+bool pin_memory
+int nprocs
+str backend
@ -118,8 +124,8 @@ classDiagram
}
class ResumableDistributedSampler {
+int epoch
+int iter
+int start_epoch
+int start_iter
}
class DatasetFactory {
@ -135,6 +141,7 @@ classDiagram
+dict state_dict
+int epoch
+int iteration
+dict extra
+save(save_dir)
+load(save_dir) Checkpoint
}
@ -158,15 +165,15 @@ classDiagram
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
+load_state_dict(state_dict)
+state_dict()
}
class DecoderBlock {
+GQA attention
+nn.Module attention # GQA or MLA via AttnFactory
+RMSNorm input_norm
+MLP mlp
+nn.Module mlp # MLP or DeepSeekMoE via FFNFactory
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
}
@ -175,8 +182,12 @@ classDiagram
+int n_heads
+int n_kv_heads
+int head_dim
+int n_rep
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm
+Linear gate # only if use_gated_attention
+RMSNorm q_norm, k_norm # only if use_qk_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
@ -187,8 +198,11 @@ classDiagram
+int kv_lora_rank
+int qk_nope_head_dim
+int qk_rope_head_dim
+int n_rep
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+Linear gate # only if use_gated_attention
+RMSNorm kv_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
@ -198,6 +212,25 @@ classDiagram
+forward(x) Tensor
}
class DeepSeekMoE {
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+str topk_method
+Linear router
+ModuleList shared_experts
+ModuleList routed_experts
+forward(x) Tensor
}
class AttnFactory {
+create(attn_type, **kwargs) nn.Module
}
class FFNFactory {
+create(ffn_type, dim, dim_ffn, **kwargs) nn.Module
}
class RMSNorm {
+Parameter weight
+float norm_eps
@ -206,7 +239,7 @@ classDiagram
class Linear {
+Parameter weight
+Parameter bias
+Optional[Parameter] bias # only if bias=True
+forward(x) Tensor
}
@ -365,7 +398,7 @@ classDiagram
class GradientClippingCallback {
+float max_grad_norm
+on_step_begin(context)
+on_step_end(context)
}
class CheckpointCallback {
@ -410,15 +443,24 @@ classDiagram
+shutdown()
}
class InferenceScheduler {
+nn.Module model
class Executor {
+AutoModel model
+AutoTokenizer tokenizer
+KVCache page_cache
+execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int]
}
class InferenceScheduler {
+KVCache _page_cache
+Executor _executor
+TaskManager _task_mgr
+bool _running
+Thread _loop_thread
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+TaskManager _task_mgr
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
@ -428,8 +470,8 @@ classDiagram
class Allocator {
+int _free_mask
+int refs_count
+LRU _lru
+List[int] _refs
+OrderedDict _lru
+alloc() int
+free(idx, keep_cached)
+inc_ref(idx)
@ -564,9 +606,9 @@ classDiagram
+List[bool] _done
+append(token, idx)
+get_results() List[str]
+pop_all() List[str]
+pop_all() List[Tuple[int, str]]
+wait(timeout) bool
+wait_completion()
+wait_completion(timeout)
}
class ChatMessage {
@ -584,6 +626,65 @@ classDiagram
+Optional[str] stop
+Optional[int] n
}
class AnthropicMessage {
+str role
+Union[str, List[Dict]] content
}
class MessagesRequest {
+List[AnthropicMessage] messages
+Optional[str] system
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[List[str]] stop_sequences
}
class ProtocolHandler {
<<abstract>>
+build_prompt() str
+create_response_id() str
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+handle() Union[StreamingResponse, Dict]
}
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+List[str] stop_sequences
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
}
class StopChecker {
+check(text) Optional[str]
+trim(text, matched) str
}
class StreamContext {
+str resp_id
+int created
+str model
+int prompt_tokens
+int completion_tokens
+str accumulated
+Optional[str] stop_matched
}
class app {
<<singleton>>
+FastAPI app
}
}
namespace parallel {
@ -610,79 +711,104 @@ classDiagram
}
}
%% Relationships
TrainConfig --> BaseDataset : uses
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
%% Relationships — UML notation: <|-- generalization, *-- composition, o-- aggregation, --> association, ..> dependency
%% --- Generalization (inheritance) ---
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : uses
Trainer --> TrainContextBuilder : uses
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
TrainContextBuilder --> StrategyFactory : uses
Checkpoint ..> Checkpoint : serializes
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
PagePool --> Allocator : composes
PagePool --> PrefixCache : composes
KVCache --> PagePool : composes
KVCache --> Storage : composes
KVCache --> TaskTable : composes
KvcacheView --> Storage : wraps
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
InferenceEngine --> GenerateResult : creates
InferenceScheduler --> Task : manages
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> KVCache : uses
InferenceScheduler --> Transformer : uses
Task --> TaskStatus : uses
InferenceEngine --> Transformer : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseDataset --> BaseStorage : uses
MultiSegmentFetcher --> BaseSegmentFetcher : uses
AutoModel <|-- Transformer
AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses
Transformer --> RotaryEmbedding : uses
Transformer --> Embedding : uses
DecoderBlock --> GQA : uses
DecoderBlock --> MLP : uses
DecoderBlock --> RMSNorm : uses
TrainContextBuilder --> ResumableDistributedSampler : creates
ResumableDistributedSampler --> BaseDataset : samples
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
AutoModel <|-- Transformer
BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
BaseFactory <|-- FFNFactory
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
%% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
KVCache *-- Allocator
KVCache *-- PrefixCache
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
TrainContextBuilder *-- TrainContext
Transformer *-- DecoderBlock
Transformer *-- RotaryEmbedding
Transformer *-- Embedding
DecoderBlock *-- RMSNorm
BaseDataset *-- BaseStorage
%% --- Aggregation (weak ownership) ---
AutoModel o-- ModelConfig
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint
AutoTokenizer o-- ChatTemplate
KvcacheView o-- Storage
BaseFactory o-- Registry
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
SchedulerFactory ..> BaseScheduler : creates
DatasetFactory ..> BaseDataset : creates
CallbackFactory ..> TrainCallback : creates
AttnFactory ..> GQA : creates
AttnFactory ..> MLA : creates
FFNFactory ..> MLP : creates
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
Trainer ..> TrainContextBuilder : uses
Trainer ..> Functions : spawns
TrainContextBuilder ..> StrategyFactory : uses
TrainContextBuilder ..> ResumableDistributedSampler : creates
Checkpoint ..> Checkpoint : serializes
CheckpointCallback ..> Checkpoint : creates
KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates
%% --- Association (general usage) ---
Trainer --> TrainConfig
DPOStrategy --> Transformer
GRPOStrategy --> Transformer
InferenceScheduler --> Task
InferenceScheduler --> TaskStatus
Task --> TaskStatus
InferenceEngine --> Transformer
Executor --> Transformer
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
```
### Module Overview
@ -690,14 +816,14 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5, save_json, load_json, create_storage, detect_format | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample, ChatMessage, ChatCompletionRequest, AnthropicMessage, MessagesRequest, OpenAIHandler, AnthropicHandler, ProtocolHandler, StreamContext, StopChecker, app, run_server | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, only_on_rank, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory[T] | Generic component registration with decorator pattern |
### Design Patterns
@ -706,7 +832,7 @@ classDiagram
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, gradient clipping, metrics) |
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
@ -715,6 +841,8 @@ classDiagram
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | `handle()` template with stream/non-stream branches, protocol-specific format hooks |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage`, `_STORAGE_REGISTRY` | Format-agnostic data access with registry-dispatch (HDF5 / JSON) |
### Core Relationships
@ -723,7 +851,7 @@ classDiagram
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 and JSON loading via `BaseStorage` (`H5Storage` / `JSONStorage`) with `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
@ -776,4 +904,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-05-14
> Document Update Time: 2026-05-15

View File

@ -224,7 +224,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) |
| `top_p` | float | 1.0 | Nucleus sampling threshold |
| `top_k` | int | 50 | Top-k sampling parameter |
| `max_tokens` | int | 1024 | Maximum tokens to generate |
| `max_tokens` | int | 2048 | Maximum tokens to generate |
| `stream` | bool | false | Enable streaming response |
**Response (non-streaming):**
@ -331,4 +331,4 @@ curl http://localhost:8000/stats
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
```
> Document Update Time: 2026-05-14
> Document Update Time: 2026-05-15

View File

@ -60,7 +60,7 @@
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` |
| `--group_size` | GRPO group size | 4 | `grpo` |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
@ -98,7 +98,7 @@ python scripts/tools/train.py \
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_tokens` | Maximum generation length | None (unlimited) |
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
| `stream` | Whether to stream output | False |
### Usage Example
@ -155,4 +155,4 @@ result = engine.generate(
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-05-14
> Document Update Time: 2026-05-15