docs : 三轮深度验证修复文档与代码不一致

- architecture.md: 修正 unwrap_model 返回类型、Config Optional 标注、方法签名错误、类名错误
- training.md: 补充 on_error 回调、修正训练循环顺序、补全策略参数、model.safetensors
- inference.md: 修正 GenerationRequest 参数顺序、async 语法、KVCache 描述、temperature 约束
- dataflow.md: 补充 Store.load/fetch 流程、修正可选参数默认值
- README/params: 多 GPU 示例补全 --parallel_mode、文档表补充 preprocessing.md
- preprocessing.md: Chat 模式算法补全 BOS token 步骤
This commit is contained in:
ViperEkura 2026-05-30 21:40:25 +08:00
parent 31ae2deeba
commit 1c2ff05a6d
8 changed files with 219 additions and 91 deletions

View File

@ -82,6 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \ nohup python scripts/tools/train.py \
--nprocs=4 \ --nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/model \
@ -108,8 +109,8 @@ Full reference at [Parameter Guide](assets/docs/params.md).
```bash ```bash
python scripts/tools/generate.py \ python scripts/tools/generate.py \
--param_path /path/to/model \ --param_path /path/to/model \
--input_json_file /path/to/input.json \ --input_json_file /path/to/input.jsonl \
--output_json_file /path/to/output.json --output_json_file /path/to/output.jsonl
``` ```
#### Docker #### Docker
@ -224,6 +225,7 @@ Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6y
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas | | [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API | | [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture | | [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
### Contributing ### Contributing

View File

@ -88,6 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \ nohup python scripts/tools/train.py \
--nprocs=4 \ --nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/model \
@ -114,8 +115,8 @@ nohup python scripts/tools/train.py \
```bash ```bash
python scripts/tools/generate.py \ python scripts/tools/generate.py \
--param_path /path/to/model \ --param_path /path/to/model \
--input_json_file /path/to/input.json \ --input_json_file /path/to/input.jsonl \
--output_json_file /path/to/output.json --output_json_file /path/to/output.jsonl
``` ```
#### Docker #### Docker
@ -230,6 +231,7 @@ python scripts/demo/generate_ar.py
| [训练文档](./training.md) | 训练循环、策略与公式 | | [训练文档](./training.md) | 训练循环、策略与公式 |
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API | | [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 | | [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
### 贡献 ### 贡献

View File

@ -8,6 +8,8 @@ classDiagram
class BaseConfig { class BaseConfig {
+to_dict() Dict +to_dict() Dict
+from_dict(d) Self +from_dict(d) Self
+from_json(path) Self
+to_json(path)
} }
class BaseModelConfig { class BaseModelConfig {
@ -17,42 +19,42 @@ classDiagram
} }
class AutoRegressiveLMConfig { class AutoRegressiveLMConfig {
+int vocab_size +Optional[int] vocab_size
+int dim +Optional[int] dim
+int n_layers +Optional[int] n_layers
+float norm_eps +Optional[float] norm_eps
+int dim_ffn +Optional[int] dim_ffn
+Optional[bool] tie_weight +Optional[bool] tie_weight
+Optional[dict] rope_scaling +Optional[dict] rope_scaling
+int max_len +Optional[int] max_len
+float rope_theta +Optional[float] rope_theta
+str attn_type +str attn_type
+int n_heads +Optional[int] n_heads
+int n_kv_heads +Optional[int] n_kv_heads
+bool use_qk_norm +Optional[bool] use_qk_norm
+bool use_gated_attention +Optional[bool] use_gated_attention
+Optional[int] kv_lora_rank +Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim +Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim +Optional[int] qk_rope_head_dim
+str ffn_type +str ffn_type
+int n_routed_experts +Optional[int] n_routed_experts
+int n_shared_experts +Optional[int] n_shared_experts
+int n_activated_experts +Optional[int] n_activated_experts
+Optional[str] topk_method +Optional[str] topk_method
} }
class EncoderConfig { class EncoderConfig {
+int vocab_size +Optional[int] vocab_size
+int dim +Optional[int] dim
+int n_layers +Optional[int] n_layers
+float norm_eps +Optional[float] norm_eps
+int dim_ffn +Optional[int] dim_ffn
+int max_len +Optional[int] max_len
+float rope_theta +Optional[float] rope_theta
+int n_heads +Optional[int] n_heads
+int n_kv_heads +Optional[int] n_kv_heads
+bool use_qk_norm +Optional[bool] use_qk_norm
+bool use_gated_attention +Optional[bool] use_gated_attention
+Optional[dict] rope_scaling +Optional[dict] rope_scaling
+Optional[str] pooling_type +Optional[str] pooling_type
+Optional[bool] normalize_embeddings +Optional[bool] normalize_embeddings
@ -64,6 +66,38 @@ classDiagram
+load(raw) BaseConfig +load(raw) BaseConfig
} }
class InputConfig {
+str type
+str messages_key
+str prompt_key
+str response_key
+str text_key
}
class ProcessingConfig {
+int max_seq_len
+int min_chars
+int max_chars
+bool deduplicate
+Optional[int] max_items
}
class OutputConfig {
+Optional[str] domain_key
+str storage_format
+int max_tokens_per_shard
}
class PipelineConfig {
+int version
+InputConfig input
+dict mask
+str mask_default
+ProcessingConfig preprocessing
+OutputConfig output
+from_dict(d) Self
}
class TrainConfig { class TrainConfig {
+Callable[[], nn.Module] model_fn +Callable[[], nn.Module] model_fn
+str strategy +str strategy
@ -312,10 +346,29 @@ classDiagram
} }
} }
namespace preprocessing {
class BaseMaskBuilder {
<<abstract>>
+build(item, config, tokenizer) Optional[dict]
}
class ChatMaskBuilder {
+build(item, config, tokenizer) Optional[dict]
}
class InstructionMaskBuilder {
+build(item, config, tokenizer) Optional[dict]
}
class TextMaskBuilder {
+build(item, config, tokenizer) Optional[dict]
}
}
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+vocab_size int +vocab_size int
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int] +encode(tokens, out_ids, is_pretokenized, add_special_tokens) List
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids) +__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]] +apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
@ -346,14 +399,20 @@ classDiagram
+create(name, *args, **kwargs) T +create(name, *args, **kwargs) T
+list_registered() list +list_registered() list
} }
class MaskBuilderFactory {
+Registry _registry
+register(name) decorator
+create(input_type, config, tokenizer) BaseMaskBuilder
}
} }
namespace trainer { namespace trainer {
class Trainer { class Trainer {
+TrainConfig train_config +TrainConfig train_config
+List[TrainCallback] callbacks +List[TrainCallback] callbacks
+train(checkpoint) +train(resume_dir)
+_get_default_callbacks() List[TrainCallback] -_get_default_callbacks() List[TrainCallback]
} }
class TrainContext { class TrainContext {
@ -383,8 +442,12 @@ classDiagram
} }
class BaseStrategy { class BaseStrategy {
+Union[Callable, nn.Module] model +Callable model
+Optional[BaseExecutor] executor
+Optional[Callable] model_fn
+dict extra_kwargs
+str device +str device
+__call__(batch) Tensor
+compute_loss(batch) Tensor +compute_loss(batch) Tensor
} }
@ -425,6 +488,8 @@ classDiagram
class BaseScheduler { class BaseScheduler {
+get_lr() List[float] +get_lr() List[float]
+step() +step()
+state_dict() dict
+load_state_dict(d)
} }
class SchedulerFactory { class SchedulerFactory {
@ -436,6 +501,7 @@ classDiagram
class CosineScheduler { class CosineScheduler {
+int warmup_steps +int warmup_steps
+int lr_decay_steps +int lr_decay_steps
+int total_steps
+float min_rate +float min_rate
} }
@ -474,11 +540,11 @@ classDiagram
+int interval +int interval
+bool weight_only +bool weight_only
+Callable save_extra_fn +Callable save_extra_fn
+_save_checkpoint(context) -_save_checkpoint(context)
+on_batch_end(context) +on_batch_end(context)
+on_train_end(context) +on_train_end(context)
+on_error(context) +on_error(context)
+save_extra(context)$ +save_extra(context) dict$
} }
class ProgressBarCallback { class ProgressBarCallback {
@ -491,7 +557,7 @@ classDiagram
} }
class MetricLoggerCallback { class MetricLoggerCallback {
+str log_dir +Path log_dir
+int save_interval +int save_interval
+int log_interval +int log_interval
+List[str] metrics +List[str] metrics
@ -501,7 +567,7 @@ classDiagram
} }
class ValidationCallback { class ValidationCallback {
+_run_validation(context) -_run_validation(context)
+on_optimizer_step(context) +on_optimizer_step(context)
} }
@ -517,7 +583,7 @@ classDiagram
+float weight_decay +float weight_decay
+bool nesterov +bool nesterov
+int ns_steps +int ns_steps
+float adamw_lr +Optional[float] adamw_lr
+tuple adamw_betas +tuple adamw_betas
+float adamw_eps +float adamw_eps
+float adamw_wd +float adamw_wd
@ -634,7 +700,7 @@ classDiagram
class Task { class Task {
+str task_id +str task_id
+List prompt_ids +List prompt_ids
+int max_tokens +Optional[int] max_tokens
+float temperature +float temperature
+float top_p +float top_p
+int top_k +int top_k
@ -643,8 +709,8 @@ classDiagram
+int input_tokens +int input_tokens
+int output_tokens +int output_tokens
+float arrival_time +float arrival_time
+float finish_time +Optional[float] finish_time
+Callable stream_callback +Optional[Callable] stream_callback
+int next_pos +int next_pos
+is_finished(stop_ids) bool +is_finished(stop_ids) bool
} }
@ -671,6 +737,11 @@ classDiagram
+activate(task) +activate(task)
+return_to_waiting(tasks) +return_to_waiting(tasks)
+get_active_tasks() List[Task] +get_active_tasks() List[Task]
+has_work() bool
+wait_for_tasks(timeout)
+get_waiting_tasks() List[Task]
+clear_queues()
+wake()
+get_stats() Dict +get_stats() Dict
} }
@ -760,7 +831,7 @@ classDiagram
class ResponseBuilder { class ResponseBuilder {
<<abstract>> <<abstract>>
+prepare(request, engine) Tuple[str, GenContext, List[str]] +prepare(request, tokenizer) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str] +format_stream_start(ctx) List[str]
+format_chunk(token) str +format_chunk(token) str
+format_stream_end(ctx, stop) List[str] +format_stream_end(ctx, stop) List[str]
@ -768,7 +839,7 @@ classDiagram
} }
class OpenAIResponseBuilder { class OpenAIResponseBuilder {
+prepare(request, engine) Tuple +prepare(request, tokenizer) Tuple
+format_stream_start(ctx) List[str] +format_stream_start(ctx) List[str]
+format_chunk(token) str +format_chunk(token) str
+format_stream_end(ctx, stop) List[str] +format_stream_end(ctx, stop) List[str]
@ -776,7 +847,7 @@ classDiagram
} }
class AnthropicResponseBuilder { class AnthropicResponseBuilder {
+prepare(request, engine) Tuple +prepare(request, tokenizer) Tuple
+format_stream_start(ctx) List[str] +format_stream_start(ctx) List[str]
+format_chunk(token) str +format_chunk(token) str
+format_stream_end(ctx, stop) List[str] +format_stream_end(ctx, stop) List[str]
@ -787,12 +858,13 @@ classDiagram
+request +request
+engine +engine
+builder: ResponseBuilder +builder: ResponseBuilder
+handle() Union[StreamingResponse, Dict] +async handle() Union[StreamingResponse, Dict]
-_handle_stream(agen, ctx, stops) StreamingResponse -_handle_stream(agen, ctx, stop_sequences) StreamingResponse
-_handle_non_stream(agen, ctx, stops) Dict -async _handle_non_stream(agen, ctx, stop_sequences) Dict
} }
class StopChecker { class StopChecker {
+__init__(sequences)
+check(text) Optional[str] +check(text) Optional[str]
} }
@ -804,6 +876,12 @@ classDiagram
+int completion_tokens +int completion_tokens
} }
class StopInfo {
+Optional[str] matched
+str body
+str yielded
}
class app { class app {
<<singleton>> <<singleton>>
+FastAPI app +FastAPI app
@ -829,14 +907,14 @@ classDiagram
} }
namespace parallel { namespace parallel {
class Functions { class setup {
<<module>> <<module>>
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs) +spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) contextmanager
+get_current_device() str +get_current_device() str
+get_world_size() int +get_world_size() int
+get_rank() int +get_rank() int
+only_on_rank(rank, sync) decorator +only_on_rank(rank, sync=False) decorator
} }
class GradientState { class GradientState {
@ -847,6 +925,7 @@ classDiagram
class AccumOptimizer { class AccumOptimizer {
+Optimizer optimizer +Optimizer optimizer
+GradientState gradient_state +GradientState gradient_state
+param_groups (property)
+step(closure) +step(closure)
+zero_grad() +zero_grad()
+state_dict() dict +state_dict() dict
@ -867,7 +946,7 @@ classDiagram
+prepare(model, optimizer, dataloader, scheduler) tuple +prepare(model, optimizer, dataloader, scheduler) tuple
+accumulate(model) context manager +accumulate(model) context manager
+backward(loss) +backward(loss)
+unwrap_model(model) nn.Module +unwrap_model(model) dict
+sync_gradients (property) bool +sync_gradients (property) bool
+grad_accum_steps (property) int +grad_accum_steps (property) int
} }
@ -876,14 +955,14 @@ classDiagram
} }
class DDPExecutor { class DDPExecutor {
+_prepare_model(model) nn.Module -_prepare_model(model) nn.Module
+_no_sync(model) context manager -_no_sync(model) context manager
+unwrap_model(model) nn.Module +unwrap_model(model) dict
} }
class FSDPExecutor { class FSDPExecutor {
+_prepare_model(model) nn.Module -_prepare_model(model) nn.Module
+unwrap_model(model) nn.Module +unwrap_model(model) dict
} }
class ExecutorFactory { class ExecutorFactory {
@ -899,11 +978,25 @@ classDiagram
} }
class ColumnParallelLinear { class ColumnParallelLinear {
+int in_features
+int out_features
+int out_features_per_rank
+bool gather_results
+Parameter weight
+Optional[Parameter] bias
+forward(x) Tensor +forward(x) Tensor
+load_state_dict(state_dict)
} }
class RowParallelLinear { class RowParallelLinear {
+int in_features
+int out_features
+int in_features_per_rank
+bool reduce_results
+Parameter weight
+Optional[Parameter] bias
+forward(x) Tensor +forward(x) Tensor
+load_state_dict(state_dict)
} }
} }
@ -938,6 +1031,10 @@ classDiagram
AutoModel <|-- EmbeddingEncoder AutoModel <|-- EmbeddingEncoder
BaseConfig <|-- BaseModelConfig BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig BaseConfig <|-- TrainConfig
BaseConfig <|-- InputConfig
BaseConfig <|-- ProcessingConfig
BaseConfig <|-- OutputConfig
BaseConfig <|-- PipelineConfig
BaseModelConfig <|-- AutoRegressiveLMConfig BaseModelConfig <|-- AutoRegressiveLMConfig
BaseModelConfig <|-- EncoderConfig BaseModelConfig <|-- EncoderConfig
BaseFactory <|-- AutoModel BaseFactory <|-- AutoModel
@ -950,11 +1047,15 @@ classDiagram
BaseFactory <|-- StoreFactory BaseFactory <|-- StoreFactory
BaseFactory <|-- ExecutorFactory BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory BaseFactory <|-- ConfigFactory
BaseFactory <|-- MaskBuilderFactory
BaseExecutor <|-- NoneExecutor BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor BaseExecutor <|-- DDPExecutor
BaseExecutor <|-- FSDPExecutor BaseExecutor <|-- FSDPExecutor
ResponseBuilder <|-- OpenAIResponseBuilder ResponseBuilder <|-- OpenAIResponseBuilder
ResponseBuilder <|-- AnthropicResponseBuilder ResponseBuilder <|-- AnthropicResponseBuilder
BaseMaskBuilder <|-- ChatMaskBuilder
BaseMaskBuilder <|-- InstructionMaskBuilder
BaseMaskBuilder <|-- TextMaskBuilder
%% --- Composition (strong ownership, part destroyed with whole) --- %% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool KVCache *-- PagePool
@ -994,6 +1095,8 @@ classDiagram
%% --- Dependency (uses temporarily) --- %% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects TrainConfig ..> BaseStrategy : selects
PipelineConfig ..> MaskBuilderFactory : selects
MaskBuilderFactory ..> BaseMaskBuilder : creates
StrategyFactory ..> BaseStrategy : creates StrategyFactory ..> BaseStrategy : creates
SchedulerFactory ..> BaseScheduler : creates SchedulerFactory ..> BaseScheduler : creates
DatasetFactory ..> BaseDataset : creates DatasetFactory ..> BaseDataset : creates
@ -1046,7 +1149,8 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | | **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) |
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, ChatMaskBuilder, InstructionMaskBuilder, TextMaskBuilder, Pipeline, filter_by_length, dedup_signature | Declarative JSON-driven data preprocessing |
| **astrai.dataset** | BaseDatasetGRPODataset, StoreMmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.dataset** | BaseDatasetGRPODataset, StoreMmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization | | **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
@ -1070,14 +1174,14 @@ classDiagram
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag | | **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor`, `FSDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support | | **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
## Core Relationships ## Core Relationships
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs` 1. **Config → Training**: `TrainConfig` holds `model_fn`, `dataset`, `optimizer_fn`, `scheduler_fn`, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)``NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)``NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
@ -1089,4 +1193,4 @@ classDiagram
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers 11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
> Document Update Time: 2026-05-28 > Document Update Time: 2026-05-30

View File

@ -5,7 +5,7 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview ## Overview
``` ```
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
``` ```
## Data Preparation ## Data Preparation
@ -33,14 +33,21 @@ H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS pag
## Dataset Architecture ## Dataset Architecture
``` ```
DatasetFactory.load(train_type, load_path, window_size, stride, storage_type) DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
→ StoreFactory.create(detect_format(path)) → BaseDataset.load(load_path, storage_type=None)
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] → detect_format(load_path)
→ BaseDataset.__getitem__(idx) → StoreFactory.create(storage_type)
→ sliding window [begin, end) via get_index(idx) → Store.load(load_path)
→ H5Store._normalize() / MmapStore._normalize()
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx)
→ get_index(idx) → [begin, end)
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
``` ```
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`). `window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
## Sampler ## Sampler
@ -54,4 +61,4 @@ DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-28 > Document Update Time: 2026-05-30

View File

@ -12,7 +12,7 @@ RoPE is applied **before** KV cache write, not after — otherwise position enco
## KVCache System ## KVCache System
Six classes working together: Six classes (plus two helpers) working together:
``` ```
KVCache (facade) KVCache (facade)
@ -43,7 +43,8 @@ KVCache (facade)
BaseSamplingStrategy (ABC) BaseSamplingStrategy (ABC)
├── TemperatureStrategy ├── TemperatureStrategy
├── TopKStrategy ├── TopKStrategy
└── TopPStrategy ├── TopPStrategy
└── SamplingPipeline
``` ```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
@ -73,7 +74,9 @@ Adding a protocol = one builder file, no handler subclassing needed.
InferenceEngine InferenceEngine
├── generate(prompt, stream, ...) → str | List[str] | Generator ├── generate(prompt, stream, ...) → str | List[str] | Generator
├── generate_with_request(req) → same ├── generate_with_request(req) → same
└── generate_async(prompt, ...) → AsyncGenerator ├── generate_async(prompt, ...) → AsyncGenerator
├── get_stats() → Dict
└── shutdown()
``` ```
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`. `GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
@ -124,9 +127,9 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| Param | Type | Default | Description | | Param | Type | Default | Description |
|-------|------|---------|-------------| |-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) | | `messages` | List[dict] | required | Chat messages (role, content) |
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count | | `top_k` | int | 50 | Top-k count |
| `top_p` | float | 1.0 | Nucleus threshold |
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
| `max_tokens` | Optional[int] | None | Max generation length | | `max_tokens` | Optional[int] | None | Max generation length |
| `stream` | bool | False | Stream output | | `stream` | bool | False | Stream output |
@ -142,7 +145,8 @@ engine.generate("Hello", stream=True) # -> Generator[str]
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]] engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
# Async # Async
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
print(token)
``` ```
> Document Update Time: 2026-05-28 > Document Update Time: 2026-05-30

View File

@ -75,6 +75,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \ nohup python scripts/tools/train.py \
--nprocs=4 \ --nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/model \

View File

@ -147,10 +147,11 @@ For instruction mode, keys are `"prompt"` and `"response"`.
For each message in the `messages` array: For each message in the `messages` array:
1. Render through the chat template for that single message 1. Prepend BOS token (position 0, always masked)
2. Encode the rendered text, record token span `(start, end, role)` 2. Render through the chat template for that single message
3. Concatenate all spans -- special tokens from the chat template naturally prevent BPE merging across message boundaries 3. Encode the rendered text, record token span `(start, end, role)`
4. Fill `loss_mask` from the mask rules 4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries
5. Fill `loss_mask` from the mask rules
**Multi-turn example**: **Multi-turn example**:

View File

@ -36,14 +36,16 @@ Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_
``` ```
on_train_begin on_train_begin
model.train()
on_epoch_begin on_epoch_begin
for batch in dataloader: for batch in dataloader:
on_batch_begin on_batch_begin
with executor.accumulate(model): with executor.accumulate(model):
loss = strategy(batch) loss = strategy.compute_loss(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss) executor.backward(stand_loss)
iteration += 1 context.iteration += 1
on_batch_end on_batch_end
if executor.sync_gradients: if executor.sync_gradients:
@ -61,9 +63,13 @@ on_train_end
| Hook | Fires | Default callback | | Hook | Fires | Default callback |
|------|-------|-----------------| |------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` | | `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
| `on_batch_begin` | Every batch | — |
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` | | `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | | `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) | | `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset). Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
@ -77,7 +83,7 @@ $$
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta) L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$ $$
Keys: `input_ids`, `target_ids` Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
### SFT (Supervised Fine-Tuning) ### SFT (Supervised Fine-Tuning)
@ -87,7 +93,7 @@ $$
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta) L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$ $$
Keys: `input_ids`, `target_ids`, `loss_mask` Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
### DPO (Direct Preference Optimization) ### DPO (Direct Preference Optimization)
@ -97,7 +103,7 @@ $$
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right] L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
$$ $$
Parameters: `beta=0.1`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`. Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
### GRPO (Group Relative Policy Optimization) ### GRPO (Group Relative Policy Optimization)
@ -111,7 +117,7 @@ $$
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right] L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
$$ $$
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`. Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
Keys: `prompts`, `responses`, `masks`, `rewards`. Keys: `prompts`, `responses`, `masks`, `rewards`.
@ -122,7 +128,7 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` | | Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) | | SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
## Gradient Checkpointing ## Gradient Checkpointing
@ -139,8 +145,8 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
``` ```
Checkpoint(state_dict, epoch, iteration, extra, meta, config) Checkpoint(state_dict, epoch, iteration, extra, meta, config)
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt) ├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
└── load(save_dir) broadcasts metadata from rank-0 └── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
``` ```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`. Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
@ -161,7 +167,7 @@ context = (
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` - Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers - Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume - Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)` - Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
## Training CLI ## Training CLI
@ -170,6 +176,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \ nohup python scripts/tools/train.py \
--nprocs=4 \ --nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/model \
@ -191,4 +198,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md). Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-28 > Document Update Time: 2026-05-30