docs: 修正文档错误并补充训练参数说明

- README: 补充训练参数速查表,完善训练命令示例
- design.md: 同步 inference 类图(SlotAllocator、GenerationParams、采样策略等
  新增类),修正参数名和类型错误,统一泛型符号
- params.md: 修正默认值(batch_size=1、num_workers=4),移除不存在参数
  (grpo_*、model_type、resume_dir),补充完整示例
- dataflow.md: _RadixNode 命名修正
This commit is contained in:
ViperEkura 2026-05-08 18:07:57 +08:00
parent 44d7a4e959
commit 78dc2bd41c
5 changed files with 213 additions and 100 deletions

View File

@ -27,9 +27,6 @@
## 📖 Table of Contents
<details open>
<summary><b>English</b></summary>
- [Features](#features)
- [Quick Start](#quick-start)
- [Documentation](#documentation)
@ -37,8 +34,6 @@
- [Community](#community)
- [License](#license)
</details>
---
<a id="english"></a>
@ -75,7 +70,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/param_path
--param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
```
#### Generate Text
@ -84,6 +86,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
#### Training Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
#### Docker
Build and run with Docker (recommended for GPU environments):

View File

@ -76,7 +76,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/param_path
--param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
```
#### 文本生成
@ -85,6 +92,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
#### 训练参数
| 参数 | 说明 | 默认值 |
|------|------|--------|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo` | 必填 |
| `--data_root_path` | 数据集根目录 | 必填 |
| `--param_path` | 模型参数或断点路径 | 必填 |
| `--n_epoch` | 训练轮数 | 1 |
| `--batch_size` | 批次大小 | 1 |
| `--accumulation_steps` | 梯度累积步数 | 1 |
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
| `--warmup_steps` | 预热步数 | 1000 |
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
| `--num_workers` | 数据加载线程数 | 4 |
| `--nprocs` | GPU 数量 | 1 |
完整参数列表见[参数说明](./params.md#training-parameters)。
#### Docker
使用 Docker 构建和运行(推荐用于 GPU 环境):

View File

@ -176,7 +176,7 @@ flowchart LR
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- **`RadixNode`**: Tree node structure for prefix caching
- **`_RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
#### 6.3 Server (`server.py`)

View File

@ -85,8 +85,8 @@ classDiagram
}
class BaseSegmentFetcher {
+List~Tensor~ segments
+List~int~ cum_lengths
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
@ -191,7 +191,7 @@ classDiagram
+int dim
+int max_len
+float base
+forward(x, start_pos) Tuple~Tensor, Tensor~
+forward(x, start_pos) Tuple[Tensor, Tensor]
}
class Embedding {
@ -202,14 +202,14 @@ classDiagram
namespace tokenize {
class AutoTokenizer {
+List~str~ stop_ids
+List[str] stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int
+encode(tokens, out_ids, add_special_tokens) List~int~
+encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union~str, List[int]~
+apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
@ -228,7 +228,7 @@ classDiagram
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List~str~
+list_names() List[str]
}
class BaseFactory {
@ -242,10 +242,10 @@ classDiagram
namespace trainer {
class Trainer {
+TrainConfig train_config
+List~TrainCallback~ callbacks
+List[TrainCallback] callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List~TrainCallback~
+_get_default_callbacks() List[TrainCallback]
}
class TrainContext {
@ -308,7 +308,7 @@ classDiagram
}
class BaseScheduler {
+get_lr() List~float~
+get_lr() List[float]
+step()
}
@ -390,10 +390,8 @@ classDiagram
+InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+int max_prefix_len
+int max_prompt_len
+int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+get_stats() Dict
@ -403,10 +401,10 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+ModelConfig config
+Tuple kv_cache
+Tensor seq_mask
+PrefixCacheManager prefix_cache
+SlotAllocator slot_allocator
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -417,21 +415,22 @@ classDiagram
}
class PrefixCacheManager {
+RadixNode root
+_RadixNode root
+int max_capacity
+List lru
+insert(token_ids, slot)
+find_longest_prefix(token_ids) Tuple[int, int]
+OrderedDict _lru
+insert(token_ids, slot, slot_ver)
+find(token_ids) Tuple[int, int, int]
+pin(token_ids)
+release(token_ids)
+copy_kv(token_ids, target_slot, kv_cache, n_layers)
}
class RadixNode {
class _RadixNode {
+Dict children
+int hash
+int slot
+int slot_ver
+int ref_count
+float last_access
+List token_sequence
}
class Task {
@ -446,15 +445,69 @@ classDiagram
+int input_tokens
+int output_tokens
+int slot
+int prefix_len
+float arrival_time
+float finish_time
+Callable stream_callback
+is_finished(stop_ids) bool
}
class TaskStatus {
+str PENDING
+str RUNNING
+str FINISHED
+str ABORTED
<<enumeration>>
PENDING
RUNNING
FINISHED
ABORTED
}
class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
}
class SlotAllocator {
+int _max_slots
+int _free_mask
+List _versions
+alloc() int
+free(idx)
+occupy(idx)
+is_free(idx) bool
+version(idx) int
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
}
class Server {
@ -462,21 +515,12 @@ classDiagram
+predict(request)
}
class GenerationRequest {
+int top_k
+float top_p
+float temperature
+int max_len
+List~Dict~ messages
+stream bool
}
class _Result {
+List~str~ tokens
+List~str~ results
+List~bool~ done_flags
+List[str] tokens
+List[str] results
+List[bool] done_flags
+append(token, idx)
+get_results() List~str~
+get_results() List[str]
}
class ChatMessage {
@ -485,13 +529,13 @@ classDiagram
}
class ChatCompletionRequest {
+List~ChatMessage~ messages
+List[ChatMessage] messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional~str~ system_prompt
+Optional[str] system_prompt
}
class CompletionResponse {
@ -499,7 +543,7 @@ classDiagram
+str object
+int created
+str model
+List~Dict~ choices
+List[Dict] choices
}
}
@ -542,7 +586,6 @@ classDiagram
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
@ -553,11 +596,19 @@ classDiagram
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> SlotAllocator : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> GenerationRequest : uses
InferenceEngine --> _Result : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
Server --> InferenceEngine : uses
Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses
@ -585,7 +636,7 @@ classDiagram
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses
PrefixCacheManager --> _RadixNode : composes
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
@ -606,7 +657,7 @@ classDiagram
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -620,6 +671,8 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
@ -630,7 +683,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, supports continuous batching with streaming/non-streaming
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors

View File

@ -4,70 +4,83 @@
### Basic Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 4 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm | 1.0 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
### Checkpoint
### Optimizer (AdamW)
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--resume_dir` | Resume training from specified path | - |
### Optimizer Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 0 |
| `--prefetch_factor` | Prefetch factor for dataloader | None |
| `--pin_memory` | Enable pin_memory | False |
| `--no_pin_memory` | Disable pin_memory | - |
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--window_size` | Max input sequence length | model config `max_len` |
| `--stride` | Stride for sliding window over sequences | None |
| `--random_seed` | Random seed for reproducibility | 3407 |
| `--num_workers` | DataLoader worker processes | 4 |
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
### Checkpoint & Resume
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
### Distributed Training
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type (cuda/cpu) | cuda |
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--device_type` | Device type | cuda |
### Other Parameters
### Strategy-specific
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--window_size` | Maximum input sequence length | model config max_len |
| `--stride` | Input sequence stride | - |
| `--dpo_beta` | DPO beta value | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
| `--grpo_group_size` | GRPO group size | 4 |
| `--label_smoothing` | Label smoothing parameter | 0.1 |
| `--start_epoch` | Starting epoch | 0 |
| `--start_batch` | Starting batch | 0 |
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
### Usage Example
```bash
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
```
---
@ -89,14 +102,14 @@
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import Tokenizer
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = Tokenizer("your_model_dir")
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(