diff --git a/README.md b/README.md index 064c3f6..676722c 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python scripts/tools/train.py \ --nprocs=4 \ + --parallel_mode=ddp \ --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/model \ @@ -108,8 +109,8 @@ Full reference at [Parameter Guide](assets/docs/params.md). ```bash python scripts/tools/generate.py \ --param_path /path/to/model \ - --input_json_file /path/to/input.json \ - --output_json_file /path/to/output.json + --input_json_file /path/to/input.jsonl \ + --output_json_file /path/to/output.jsonl ``` #### Docker @@ -224,6 +225,7 @@ Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6y | [Training](./assets/docs/training.md) | Training loop, strategies & formulas | | [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API | | [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture | +| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing | ### Contributing diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index 0fb7789..5d336eb 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -88,6 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python scripts/tools/train.py \ --nprocs=4 \ + --parallel_mode=ddp \ --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/model \ @@ -114,8 +115,8 @@ nohup python scripts/tools/train.py \ ```bash python scripts/tools/generate.py \ --param_path /path/to/model \ - --input_json_file /path/to/input.json \ - --output_json_file /path/to/output.json + --input_json_file /path/to/input.jsonl \ + --output_json_file /path/to/output.jsonl ``` #### Docker @@ -230,6 +231,7 @@ python scripts/demo/generate_ar.py | [训练文档](./training.md) | 训练循环、策略与公式 | | [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API | | [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 | +| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 | ### 贡献 diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index d154518..e42ac2e 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -8,6 +8,8 @@ classDiagram class BaseConfig { +to_dict() Dict +from_dict(d) Self + +from_json(path) Self + +to_json(path) } class BaseModelConfig { @@ -17,42 +19,42 @@ classDiagram } class AutoRegressiveLMConfig { - +int vocab_size - +int dim - +int n_layers - +float norm_eps - +int dim_ffn + +Optional[int] vocab_size + +Optional[int] dim + +Optional[int] n_layers + +Optional[float] norm_eps + +Optional[int] dim_ffn +Optional[bool] tie_weight +Optional[dict] rope_scaling - +int max_len - +float rope_theta + +Optional[int] max_len + +Optional[float] rope_theta +str attn_type - +int n_heads - +int n_kv_heads - +bool use_qk_norm - +bool use_gated_attention + +Optional[int] n_heads + +Optional[int] n_kv_heads + +Optional[bool] use_qk_norm + +Optional[bool] use_gated_attention +Optional[int] kv_lora_rank +Optional[int] qk_nope_head_dim +Optional[int] qk_rope_head_dim +str ffn_type - +int n_routed_experts - +int n_shared_experts - +int n_activated_experts + +Optional[int] n_routed_experts + +Optional[int] n_shared_experts + +Optional[int] n_activated_experts +Optional[str] topk_method } class EncoderConfig { - +int vocab_size - +int dim - +int n_layers - +float norm_eps - +int dim_ffn - +int max_len - +float rope_theta - +int n_heads - +int n_kv_heads - +bool use_qk_norm - +bool use_gated_attention + +Optional[int] vocab_size + +Optional[int] dim + +Optional[int] n_layers + +Optional[float] norm_eps + +Optional[int] dim_ffn + +Optional[int] max_len + +Optional[float] rope_theta + +Optional[int] n_heads + +Optional[int] n_kv_heads + +Optional[bool] use_qk_norm + +Optional[bool] use_gated_attention +Optional[dict] rope_scaling +Optional[str] pooling_type +Optional[bool] normalize_embeddings @@ -64,6 +66,38 @@ classDiagram +load(raw) BaseConfig } + class InputConfig { + +str type + +str messages_key + +str prompt_key + +str response_key + +str text_key + } + + class ProcessingConfig { + +int max_seq_len + +int min_chars + +int max_chars + +bool deduplicate + +Optional[int] max_items + } + + class OutputConfig { + +Optional[str] domain_key + +str storage_format + +int max_tokens_per_shard + } + + class PipelineConfig { + +int version + +InputConfig input + +dict mask + +str mask_default + +ProcessingConfig preprocessing + +OutputConfig output + +from_dict(d) Self + } + class TrainConfig { +Callable[[], nn.Module] model_fn +str strategy @@ -312,10 +346,29 @@ classDiagram } } + namespace preprocessing { + class BaseMaskBuilder { + <> + +build(item, config, tokenizer) Optional[dict] + } + + class ChatMaskBuilder { + +build(item, config, tokenizer) Optional[dict] + } + + class InstructionMaskBuilder { + +build(item, config, tokenizer) Optional[dict] + } + + class TextMaskBuilder { + +build(item, config, tokenizer) Optional[dict] + } + } + namespace tokenize { class AutoTokenizer { +vocab_size int - +encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int] + +encode(tokens, out_ids, is_pretokenized, add_special_tokens) List +decode(tokens, skip_special_tokens) str +__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids) +apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]] @@ -346,14 +399,20 @@ classDiagram +create(name, *args, **kwargs) T +list_registered() list } + + class MaskBuilderFactory { + +Registry _registry + +register(name) decorator + +create(input_type, config, tokenizer) BaseMaskBuilder + } } namespace trainer { class Trainer { +TrainConfig train_config +List[TrainCallback] callbacks - +train(checkpoint) - +_get_default_callbacks() List[TrainCallback] + +train(resume_dir) + -_get_default_callbacks() List[TrainCallback] } class TrainContext { @@ -383,8 +442,12 @@ classDiagram } class BaseStrategy { - +Union[Callable, nn.Module] model + +Callable model + +Optional[BaseExecutor] executor + +Optional[Callable] model_fn + +dict extra_kwargs +str device + +__call__(batch) Tensor +compute_loss(batch) Tensor } @@ -425,6 +488,8 @@ classDiagram class BaseScheduler { +get_lr() List[float] +step() + +state_dict() dict + +load_state_dict(d) } class SchedulerFactory { @@ -436,6 +501,7 @@ classDiagram class CosineScheduler { +int warmup_steps +int lr_decay_steps + +int total_steps +float min_rate } @@ -474,11 +540,11 @@ classDiagram +int interval +bool weight_only +Callable save_extra_fn - +_save_checkpoint(context) + -_save_checkpoint(context) +on_batch_end(context) +on_train_end(context) +on_error(context) - +save_extra(context)$ + +save_extra(context) dict$ } class ProgressBarCallback { @@ -491,7 +557,7 @@ classDiagram } class MetricLoggerCallback { - +str log_dir + +Path log_dir +int save_interval +int log_interval +List[str] metrics @@ -501,7 +567,7 @@ classDiagram } class ValidationCallback { - +_run_validation(context) + -_run_validation(context) +on_optimizer_step(context) } @@ -517,7 +583,7 @@ classDiagram +float weight_decay +bool nesterov +int ns_steps - +float adamw_lr + +Optional[float] adamw_lr +tuple adamw_betas +float adamw_eps +float adamw_wd @@ -634,7 +700,7 @@ classDiagram class Task { +str task_id +List prompt_ids - +int max_tokens + +Optional[int] max_tokens +float temperature +float top_p +int top_k @@ -643,8 +709,8 @@ classDiagram +int input_tokens +int output_tokens +float arrival_time - +float finish_time - +Callable stream_callback + +Optional[float] finish_time + +Optional[Callable] stream_callback +int next_pos +is_finished(stop_ids) bool } @@ -671,6 +737,11 @@ classDiagram +activate(task) +return_to_waiting(tasks) +get_active_tasks() List[Task] + +has_work() bool + +wait_for_tasks(timeout) + +get_waiting_tasks() List[Task] + +clear_queues() + +wake() +get_stats() Dict } @@ -760,7 +831,7 @@ classDiagram class ResponseBuilder { <> - +prepare(request, engine) Tuple[str, GenContext, List[str]] + +prepare(request, tokenizer) Tuple[str, GenContext, List[str]] +format_stream_start(ctx) List[str] +format_chunk(token) str +format_stream_end(ctx, stop) List[str] @@ -768,7 +839,7 @@ classDiagram } class OpenAIResponseBuilder { - +prepare(request, engine) Tuple + +prepare(request, tokenizer) Tuple +format_stream_start(ctx) List[str] +format_chunk(token) str +format_stream_end(ctx, stop) List[str] @@ -776,7 +847,7 @@ classDiagram } class AnthropicResponseBuilder { - +prepare(request, engine) Tuple + +prepare(request, tokenizer) Tuple +format_stream_start(ctx) List[str] +format_chunk(token) str +format_stream_end(ctx, stop) List[str] @@ -787,12 +858,13 @@ classDiagram +request +engine +builder: ResponseBuilder - +handle() Union[StreamingResponse, Dict] - -_handle_stream(agen, ctx, stops) StreamingResponse - -_handle_non_stream(agen, ctx, stops) Dict + +async handle() Union[StreamingResponse, Dict] + -_handle_stream(agen, ctx, stop_sequences) StreamingResponse + -async _handle_non_stream(agen, ctx, stop_sequences) Dict } class StopChecker { + +__init__(sequences) +check(text) Optional[str] } @@ -804,6 +876,12 @@ classDiagram +int completion_tokens } + class StopInfo { + +Optional[str] matched + +str body + +str yielded + } + class app { <> +FastAPI app @@ -829,14 +907,14 @@ classDiagram } namespace parallel { - class Functions { + class setup { <> +spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs) - +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) + +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) contextmanager +get_current_device() str +get_world_size() int +get_rank() int - +only_on_rank(rank, sync) decorator + +only_on_rank(rank, sync=False) decorator } class GradientState { @@ -847,6 +925,7 @@ classDiagram class AccumOptimizer { +Optimizer optimizer +GradientState gradient_state + +param_groups (property) +step(closure) +zero_grad() +state_dict() dict @@ -867,7 +946,7 @@ classDiagram +prepare(model, optimizer, dataloader, scheduler) tuple +accumulate(model) context manager +backward(loss) - +unwrap_model(model) nn.Module + +unwrap_model(model) dict +sync_gradients (property) bool +grad_accum_steps (property) int } @@ -876,14 +955,14 @@ classDiagram } class DDPExecutor { - +_prepare_model(model) nn.Module - +_no_sync(model) context manager - +unwrap_model(model) nn.Module + -_prepare_model(model) nn.Module + -_no_sync(model) context manager + +unwrap_model(model) dict } class FSDPExecutor { - +_prepare_model(model) nn.Module - +unwrap_model(model) nn.Module + -_prepare_model(model) nn.Module + +unwrap_model(model) dict } class ExecutorFactory { @@ -899,11 +978,25 @@ classDiagram } class ColumnParallelLinear { + +int in_features + +int out_features + +int out_features_per_rank + +bool gather_results + +Parameter weight + +Optional[Parameter] bias +forward(x) Tensor + +load_state_dict(state_dict) } class RowParallelLinear { + +int in_features + +int out_features + +int in_features_per_rank + +bool reduce_results + +Parameter weight + +Optional[Parameter] bias +forward(x) Tensor + +load_state_dict(state_dict) } } @@ -938,6 +1031,10 @@ classDiagram AutoModel <|-- EmbeddingEncoder BaseConfig <|-- BaseModelConfig BaseConfig <|-- TrainConfig + BaseConfig <|-- InputConfig + BaseConfig <|-- ProcessingConfig + BaseConfig <|-- OutputConfig + BaseConfig <|-- PipelineConfig BaseModelConfig <|-- AutoRegressiveLMConfig BaseModelConfig <|-- EncoderConfig BaseFactory <|-- AutoModel @@ -950,11 +1047,15 @@ classDiagram BaseFactory <|-- StoreFactory BaseFactory <|-- ExecutorFactory BaseFactory <|-- ConfigFactory + BaseFactory <|-- MaskBuilderFactory BaseExecutor <|-- NoneExecutor BaseExecutor <|-- DDPExecutor BaseExecutor <|-- FSDPExecutor ResponseBuilder <|-- OpenAIResponseBuilder ResponseBuilder <|-- AnthropicResponseBuilder + BaseMaskBuilder <|-- ChatMaskBuilder + BaseMaskBuilder <|-- InstructionMaskBuilder + BaseMaskBuilder <|-- TextMaskBuilder %% --- Composition (strong ownership, part destroyed with whole) --- KVCache *-- PagePool @@ -994,6 +1095,8 @@ classDiagram %% --- Dependency (uses temporarily) --- TrainConfig ..> BaseStrategy : selects + PipelineConfig ..> MaskBuilderFactory : selects + MaskBuilderFactory ..> BaseMaskBuilder : creates StrategyFactory ..> BaseStrategy : creates SchedulerFactory ..> BaseScheduler : creates DatasetFactory ..> BaseDataset : creates @@ -1046,7 +1149,8 @@ classDiagram | Module | Components | Description | |--------|------------|-------------| -| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | +| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) | +| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, ChatMaskBuilder, InstructionMaskBuilder, TextMaskBuilder, Pipeline, filter_by_length, dedup_signature | Declarative JSON-driven data preprocessing | | **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.serialization** | Checkpoint | Model serialization | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | @@ -1070,14 +1174,14 @@ classDiagram | **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Context** | `TrainContext` | Unified training state bag | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | -| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | +| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor`, `FSDPExecutor` | Gradient accumulation & model distribution | | **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | ## Core Relationships -1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs` +1. **Config → Training**: `TrainConfig` holds `model_fn`, `dataset`, `optimizer_fn`, `scheduler_fn`, `parallel_mode`, `executor_kwargs` 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` @@ -1089,4 +1193,4 @@ classDiagram 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers -> Document Update Time: 2026-05-28 +> Document Update Time: 2026-05-30 diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index fd1f53d..38facbd 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -5,7 +5,7 @@ This document describes the data pipeline: from raw text to model input tensors. ## Overview ``` -Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference +Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference ``` ## Data Preparation @@ -33,14 +33,21 @@ H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS pag ## Dataset Architecture ``` -DatasetFactory.load(train_type, load_path, window_size, stride, storage_type) - → StoreFactory.create(detect_format(path)) - → Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] - → BaseDataset.__getitem__(idx) - → sliding window [begin, end) via get_index(idx) +DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None) + → BaseDataset.load(load_path, storage_type=None) + → detect_format(load_path) + → StoreFactory.create(storage_type) + → Store.load(load_path) + → H5Store._normalize() / MmapStore._normalize() + → Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] + → BaseDataset.__getitem__(idx) + → get_index(idx) → [begin, end) + → Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor] ``` -`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`). +`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`). + +`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`. ## Sampler @@ -54,4 +61,4 @@ DatasetFactory.load(train_type, load_path, window_size, stride, storage_type) Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. -> Document Update Time: 2026-05-28 +> Document Update Time: 2026-05-30 diff --git a/assets/docs/inference.md b/assets/docs/inference.md index 0a1568e..54435c5 100644 --- a/assets/docs/inference.md +++ b/assets/docs/inference.md @@ -12,7 +12,7 @@ RoPE is applied **before** KV cache write, not after — otherwise position enco ## KVCache System -Six classes working together: +Six classes (plus two helpers) working together: ``` KVCache (facade) @@ -43,7 +43,8 @@ KVCache (facade) BaseSamplingStrategy (ABC) ├── TemperatureStrategy ├── TopKStrategy - └── TopPStrategy + ├── TopPStrategy + └── SamplingPipeline ``` `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. @@ -73,7 +74,9 @@ Adding a protocol = one builder file, no handler subclassing needed. InferenceEngine ├── generate(prompt, stream, ...) → str | List[str] | Generator ├── generate_with_request(req) → same - └── generate_async(prompt, ...) → AsyncGenerator + ├── generate_async(prompt, ...) → AsyncGenerator + ├── get_stats() → Dict + └── shutdown() ``` `GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`. @@ -124,9 +127,9 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`. | Param | Type | Default | Description | |-------|------|---------|-------------| | `messages` | List[dict] | required | Chat messages (role, content) | -| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) | -| `top_p` | float | 1.0 | Nucleus threshold | | `top_k` | int | 50 | Top-k count | +| `top_p` | float | 1.0 | Nucleus threshold | +| `temperature` | float | 1.0 | Sampling temperature (> 0.0) | | `max_tokens` | Optional[int] | None | Max generation length | | `stream` | bool | False | Stream output | @@ -142,7 +145,8 @@ engine.generate("Hello", stream=True) # -> Generator[str] engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]] # Async -await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] +async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str] + print(token) ``` -> Document Update Time: 2026-05-28 +> Document Update Time: 2026-05-30 diff --git a/assets/docs/params.md b/assets/docs/params.md index 218bafa..e3bf04f 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -75,6 +75,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python scripts/tools/train.py \ --nprocs=4 \ + --parallel_mode=ddp \ --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/model \ diff --git a/assets/docs/preprocessing.md b/assets/docs/preprocessing.md index 2e3008d..ff983a9 100644 --- a/assets/docs/preprocessing.md +++ b/assets/docs/preprocessing.md @@ -147,10 +147,11 @@ For instruction mode, keys are `"prompt"` and `"response"`. For each message in the `messages` array: -1. Render through the chat template for that single message -2. Encode the rendered text, record token span `(start, end, role)` -3. Concatenate all spans -- special tokens from the chat template naturally prevent BPE merging across message boundaries -4. Fill `loss_mask` from the mask rules +1. Prepend BOS token (position 0, always masked) +2. Render through the chat template for that single message +3. Encode the rendered text, record token span `(start, end, role)` +4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries +5. Fill `loss_mask` from the mask rules **Multi-turn example**: diff --git a/assets/docs/training.md b/assets/docs/training.md index edffacc..3dbaf0d 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -36,14 +36,16 @@ Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_ ``` on_train_begin + model.train() on_epoch_begin for batch in dataloader: on_batch_begin with executor.accumulate(model): - loss = strategy(batch) + loss = strategy.compute_loss(batch) + context.loss = loss.item() stand_loss = loss / executor.grad_accum_steps executor.backward(stand_loss) - iteration += 1 + context.iteration += 1 on_batch_end if executor.sync_gradients: @@ -61,9 +63,13 @@ on_train_end | Hook | Fires | Default callback | |------|-------|-----------------| | `on_train_begin` | Before training starts | `GradientCheckpointingCallback` | +| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` | +| `on_batch_begin` | Every batch | — | | `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` | | `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | -| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) | +| `on_epoch_end` | End of each epoch | `ProgressBarCallback` | +| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` | +| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` | Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset). @@ -77,7 +83,7 @@ $$ L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta) $$ -Keys: `input_ids`, `target_ids` +Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`. ### SFT (Supervised Fine-Tuning) @@ -87,7 +93,7 @@ $$ L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta) $$ -Keys: `input_ids`, `target_ids`, `loss_mask` +Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`. ### DPO (Direct Preference Optimization) @@ -97,7 +103,7 @@ $$ L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right] $$ -Parameters: `beta=0.1`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`. +Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`. ### GRPO (Group Relative Policy Optimization) @@ -111,7 +117,7 @@ $$ L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right] $$ -Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`. +Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`. Keys: `prompts`, `responses`, `masks`, `rewards`. @@ -122,7 +128,7 @@ Keys: `prompts`, `responses`, `masks`, `rewards`. | Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` | | SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) | -Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. +Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler. ## Gradient Checkpointing @@ -139,8 +145,8 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi ``` Checkpoint(state_dict, epoch, iteration, extra, meta, config) - ├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt) - └── load(save_dir) broadcasts metadata from rank-0 + ├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt) + └── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0 ``` Optimizer/scheduler state persisted by default via `Checkpoint.extra`. @@ -161,7 +167,7 @@ context = ( - Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` - Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers - Creates `ResumableDistributedSampler` for shuffle+resume -- Builds strategy via `StrategyFactory.create(train_type, ...)` +- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)` ## Training CLI @@ -170,6 +176,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python scripts/tools/train.py \ --nprocs=4 \ + --parallel_mode=ddp \ --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/model \ @@ -191,4 +198,4 @@ nohup python scripts/tools/train.py \ Full parameter reference at [params.md](params.md). -> Document Update Time: 2026-05-28 +> Document Update Time: 2026-05-30