From 78dc2bd41c47faf428f602f370d0a18eb905f696 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Fri, 8 May 2026 18:07:57 +0800
Subject: [PATCH] =?UTF-8?q?docs:=20=E4=BF=AE=E6=AD=A3=E6=96=87=E6=A1=A3?=
=?UTF-8?q?=E9=94=99=E8=AF=AF=E5=B9=B6=E8=A1=A5=E5=85=85=E8=AE=AD=E7=BB=83?=
=?UTF-8?q?=E5=8F=82=E6=95=B0=E8=AF=B4=E6=98=8E?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- README: 补充训练参数速查表,完善训练命令示例
- design.md: 同步 inference 类图(SlotAllocator、GenerationParams、采样策略等
新增类),修正参数名和类型错误,统一泛型符号
- params.md: 修正默认值(batch_size=1、num_workers=4),移除不存在参数
(grpo_*、model_type、resume_dir),补充完整示例
- dataflow.md: _RadixNode 命名修正
---
README.md | 33 ++++++--
assets/docs/README-zh-CN.md | 28 ++++++-
assets/docs/dataflow.md | 2 +-
assets/docs/design.md | 145 ++++++++++++++++++++++++------------
assets/docs/params.md | 105 ++++++++++++++------------
5 files changed, 213 insertions(+), 100 deletions(-)
diff --git a/README.md b/README.md
index b19041d..b049505 100644
--- a/README.md
+++ b/README.md
@@ -27,9 +27,6 @@
## 📖 Table of Contents
-
-English
-
- [Features](#features)
- [Quick Start](#quick-start)
- [Documentation](#documentation)
@@ -37,8 +34,6 @@
- [Community](#community)
- [License](#license)
-
-
---
@@ -75,7 +70,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
- --param_path=/path/to/param_path
+ --param_path=/path/to/model \
+ --n_epoch=3 \
+ --batch_size=4 \
+ --accumulation_steps=8 \
+ --max_lr=3e-4 \
+ --warmup_steps=2000 \
+ --ckpt_interval=5000 \
+ --ckpt_dir=./checkpoints
```
#### Generate Text
@@ -84,6 +86,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
+#### Training Parameters
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
+| `--data_root_path` | Dataset root directory | required |
+| `--param_path` | Model / checkpoint path | required |
+| `--n_epoch` | Training epochs | 1 |
+| `--batch_size` | Batch size | 1 |
+| `--accumulation_steps` | Gradient accumulation steps | 1 |
+| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
+| `--warmup_steps` | LR warmup steps | 1000 |
+| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
+| `--ckpt_dir` | Checkpoint directory | checkpoint |
+| `--num_workers` | DataLoader workers | 4 |
+| `--nprocs` | Number of GPUs | 1 |
+
+Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
+
#### Docker
Build and run with Docker (recommended for GPU environments):
diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md
index 72e3d08..ff4dc99 100644
--- a/assets/docs/README-zh-CN.md
+++ b/assets/docs/README-zh-CN.md
@@ -76,7 +76,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
- --param_path=/path/to/param_path
+ --param_path=/path/to/model \
+ --n_epoch=3 \
+ --batch_size=4 \
+ --accumulation_steps=8 \
+ --max_lr=3e-4 \
+ --warmup_steps=2000 \
+ --ckpt_interval=5000 \
+ --ckpt_dir=./checkpoints
```
#### 文本生成
@@ -85,6 +92,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
+#### 训练参数
+
+| 参数 | 说明 | 默认值 |
+|------|------|--------|
+| `--train_type` | 训练类型(`seq`, `sft`, `dpo`) | 必填 |
+| `--data_root_path` | 数据集根目录 | 必填 |
+| `--param_path` | 模型参数或断点路径 | 必填 |
+| `--n_epoch` | 训练轮数 | 1 |
+| `--batch_size` | 批次大小 | 1 |
+| `--accumulation_steps` | 梯度累积步数 | 1 |
+| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
+| `--warmup_steps` | 预热步数 | 1000 |
+| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
+| `--ckpt_dir` | 检查点保存目录 | checkpoint |
+| `--num_workers` | 数据加载线程数 | 4 |
+| `--nprocs` | GPU 数量 | 1 |
+
+完整参数列表见[参数说明](./params.md#training-parameters)。
+
#### Docker
使用 Docker 构建和运行(推荐用于 GPU 环境):
diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md
index f26b970..0937d96 100644
--- a/assets/docs/dataflow.md
+++ b/assets/docs/dataflow.md
@@ -176,7 +176,7 @@ flowchart LR
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
-- **`RadixNode`**: Tree node structure for prefix caching
+- **`_RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
#### 6.3 Server (`server.py`)
diff --git a/assets/docs/design.md b/assets/docs/design.md
index 7817e8b..67e9f2f 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -85,8 +85,8 @@ classDiagram
}
class BaseSegmentFetcher {
- +List~Tensor~ segments
- +List~int~ cum_lengths
+ +List[Tensor] segments
+ +List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
@@ -191,7 +191,7 @@ classDiagram
+int dim
+int max_len
+float base
- +forward(x, start_pos) Tuple~Tensor, Tensor~
+ +forward(x, start_pos) Tuple[Tensor, Tensor]
}
class Embedding {
@@ -202,14 +202,14 @@ classDiagram
namespace tokenize {
class AutoTokenizer {
- +List~str~ stop_ids
+ +List[str] stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int
- +encode(tokens, out_ids, add_special_tokens) List~int~
+ +encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str
- +apply_chat_template(messages, tokenize) Union~str, List[int]~
+ +apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
@@ -228,7 +228,7 @@ classDiagram
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
- +list_names() List~str~
+ +list_names() List[str]
}
class BaseFactory {
@@ -242,10 +242,10 @@ classDiagram
namespace trainer {
class Trainer {
+TrainConfig train_config
- +List~TrainCallback~ callbacks
+ +List[TrainCallback] callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
- +_get_default_callbacks() List~TrainCallback~
+ +_get_default_callbacks() List[TrainCallback]
}
class TrainContext {
@@ -308,7 +308,7 @@ classDiagram
}
class BaseScheduler {
- +get_lr() List~float~
+ +get_lr() List[float]
+step()
}
@@ -390,10 +390,8 @@ classDiagram
+InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
- +int max_prefix_len
+ +int max_prompt_len
+int cache_capacity
- +Tensor kv_cache
- +Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+get_stats() Dict
@@ -403,10 +401,10 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
- +ModelConfig config
+Tuple kv_cache
+Tensor seq_mask
+PrefixCacheManager prefix_cache
+ +SlotAllocator slot_allocator
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@@ -417,21 +415,22 @@ classDiagram
}
class PrefixCacheManager {
- +RadixNode root
+ +_RadixNode root
+int max_capacity
- +List lru
- +insert(token_ids, slot)
- +find_longest_prefix(token_ids) Tuple[int, int]
+ +OrderedDict _lru
+ +insert(token_ids, slot, slot_ver)
+ +find(token_ids) Tuple[int, int, int]
+ +pin(token_ids)
+release(token_ids)
+ +copy_kv(token_ids, target_slot, kv_cache, n_layers)
}
- class RadixNode {
+ class _RadixNode {
+Dict children
- +int hash
+int slot
+ +int slot_ver
+int ref_count
+float last_access
- +List token_sequence
}
class Task {
@@ -446,15 +445,69 @@ classDiagram
+int input_tokens
+int output_tokens
+int slot
+ +int prefix_len
+ +float arrival_time
+ +float finish_time
+Callable stream_callback
+is_finished(stop_ids) bool
}
class TaskStatus {
- +str PENDING
- +str RUNNING
- +str FINISHED
- +str ABORTED
+ <>
+ PENDING
+ RUNNING
+ FINISHED
+ ABORTED
+ }
+
+ class GenerationRequest {
+ +List[Dict] messages
+ +GenerationParams params
+ +bool stream
+ }
+
+ class GenerationParams {
+ <>
+ +int top_k
+ +float top_p
+ +float temperature
+ +int max_tokens
+ }
+
+ class SlotAllocator {
+ +int _max_slots
+ +int _free_mask
+ +List _versions
+ +alloc() int
+ +free(idx)
+ +occupy(idx)
+ +is_free(idx) bool
+ +version(idx) int
+ }
+
+ class BaseSamplingStrategy {
+ <>
+ +apply(logits, filter_value) Tensor
+ }
+
+ class TemperatureStrategy {
+ +float temperature
+ +apply(logits, filter_value) Tensor
+ }
+
+ class TopKStrategy {
+ +int top_k
+ +apply(logits, filter_value) Tensor
+ }
+
+ class TopPStrategy {
+ +float top_p
+ +apply(logits, filter_value) Tensor
+ }
+
+ class SamplingPipeline {
+ +List strategies
+ +apply(logits, filter_value) Tensor
}
class Server {
@@ -462,21 +515,12 @@ classDiagram
+predict(request)
}
- class GenerationRequest {
- +int top_k
- +float top_p
- +float temperature
- +int max_len
- +List~Dict~ messages
- +stream bool
- }
-
class _Result {
- +List~str~ tokens
- +List~str~ results
- +List~bool~ done_flags
+ +List[str] tokens
+ +List[str] results
+ +List[bool] done_flags
+append(token, idx)
- +get_results() List~str~
+ +get_results() List[str]
}
class ChatMessage {
@@ -485,13 +529,13 @@ classDiagram
}
class ChatCompletionRequest {
- +List~ChatMessage~ messages
+ +List[ChatMessage] messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
- +Optional~str~ system_prompt
+ +Optional[str] system_prompt
}
class CompletionResponse {
@@ -499,7 +543,7 @@ classDiagram
+str object
+int created
+str model
- +List~Dict~ choices
+ +List[Dict] choices
}
}
@@ -542,7 +586,6 @@ classDiagram
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
- AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
@@ -553,11 +596,19 @@ classDiagram
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses
+ InferenceEngine --> GenerationRequest : uses
+ GenerationRequest --> GenerationParams : contains
InferenceScheduler --> Task : manages
+ Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
+ InferenceScheduler --> SlotAllocator : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
- InferenceEngine --> GenerationRequest : uses
+ InferenceEngine --> _Result : uses
+ BaseSamplingStrategy <|-- TemperatureStrategy
+ BaseSamplingStrategy <|-- TopKStrategy
+ BaseSamplingStrategy <|-- TopPStrategy
+ SamplingPipeline --> BaseSamplingStrategy : composes
Server --> InferenceEngine : uses
Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses
@@ -585,7 +636,7 @@ classDiagram
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
- InferenceScheduler --> RadixNode : uses
+ PrefixCacheManager --> _RadixNode : composes
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
@@ -606,7 +657,7 @@ classDiagram
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
-| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
+| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
@@ -620,6 +671,8 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
+| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
+| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
@@ -630,7 +683,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
-4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming
+4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
diff --git a/assets/docs/params.md b/assets/docs/params.md
index 0ad17c8..26e772a 100644
--- a/assets/docs/params.md
+++ b/assets/docs/params.md
@@ -4,70 +4,83 @@
### Basic Parameters
-| Parameter | Description | Default Value |
-|-----------|-------------|---------------|
-| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
-| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
-| `--batch_size` | Batch size | 4 |
-| `--accumulation_steps` | Gradient accumulation steps | 1 |
+| `--batch_size` | Batch size | 1 |
+| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling
-| Parameter | Description | Default Value |
-|-----------|-------------|---------------|
+| Parameter | Description | Default |
+|-----------|-------------|---------|
| `--warmup_steps` | Warmup steps | 1000 |
-| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
-| `--max_grad_norm` | Maximum gradient norm | 1.0 |
+| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
+| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
-### Checkpoint
+### Optimizer (AdamW)
-| Parameter | Description | Default Value |
-|-----------|-------------|---------------|
-| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
-| `--ckpt_dir` | Checkpoint save directory | checkpoint |
-| `--resume_dir` | Resume training from specified path | - |
-
-### Optimizer Parameters
-
-| Parameter | Description | Default Value |
-|-----------|-------------|---------------|
+| Parameter | Description | Default |
+|-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading
-| Parameter | Description | Default Value |
-|-----------|-------------|---------------|
-| `--random_seed` | Random seed | 3407 |
-| `--num_workers` | DataLoader workers | 0 |
-| `--prefetch_factor` | Prefetch factor for dataloader | None |
-| `--pin_memory` | Enable pin_memory | False |
-| `--no_pin_memory` | Disable pin_memory | - |
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `--window_size` | Max input sequence length | model config `max_len` |
+| `--stride` | Stride for sliding window over sequences | None |
+| `--random_seed` | Random seed for reproducibility | 3407 |
+| `--num_workers` | DataLoader worker processes | 4 |
+| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
+
+### Checkpoint & Resume
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `--ckpt_interval` | Iterations between checkpoints | 5000 |
+| `--ckpt_dir` | Checkpoint save directory | checkpoint |
+| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
+| `--start_batch` | Resume from batch iteration | 0 |
### Distributed Training
-| Parameter | Description | Default Value |
-|-----------|-------------|---------------|
-| `--nprocs` | Number of GPUs | 1 |
-| `--device_type` | Device type (cuda/cpu) | cuda |
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `--nprocs` | Number of GPUs / processes | 1 |
+| `--device_type` | Device type | cuda |
-### Other Parameters
+### Strategy-specific
-| Parameter | Description | Default Value |
-|-----------|-------------|---------------|
-| `--window_size` | Maximum input sequence length | model config max_len |
-| `--stride` | Input sequence stride | - |
-| `--dpo_beta` | DPO beta value | 0.1 |
-| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
-| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
-| `--grpo_group_size` | GRPO group size | 4 |
-| `--label_smoothing` | Label smoothing parameter | 0.1 |
-| `--start_epoch` | Starting epoch | 0 |
-| `--start_batch` | Starting batch | 0 |
+| Parameter | Description | Default | Used by |
+|-----------|-------------|---------|---------|
+| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
+| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
+
+### Usage Example
+
+```bash
+python scripts/tools/train.py \
+ --train_type seq \
+ --data_root_path /path/to/dataset \
+ --param_path /path/to/model \
+ --n_epoch 3 \
+ --batch_size 4 \
+ --accumulation_steps 8 \
+ --max_lr 3e-4 \
+ --warmup_steps 2000 \
+ --max_grad_norm 1.0 \
+ --ckpt_interval 5000 \
+ --ckpt_dir ./checkpoints \
+ --num_workers 4 \
+ --nprocs 1 \
+ --device_type cuda
+```
---
@@ -89,14 +102,14 @@
```python
import torch
from astrai.model import AutoModel
-from astrai.tokenize import Tokenizer
+from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
-tokenizer = Tokenizer("your_model_dir")
+tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(