docs: 更新文档与代码同步(Executor/训练循环/参数)
- architecture.md: TrainConfig 移除旧 parallel_wrapper/state_dict_fn - architecture.md: 新增 ExecutorFactory/BaseExecutor/DDPExecutor 等类图 - architecture.md: MLA 新增 use_qk_norm/q_norm/k_norm - architecture.md: 新增 protocols 命名空间 - training.md: 修复训练循环 hook 名和 scheduler.step 位置 - training.md: 替换 parallel_wrapper 为 parallel_mode/executor.prepare - training.md: 修复默认回调顺序和 Callback 生命周期表 - params.md: 新增 --parallel_mode 和 --start_method
This commit is contained in:
parent
7fa69572c0
commit
82a3f2626f
|
|
@ -88,12 +88,12 @@ classDiagram
|
||||||
+str backend
|
+str backend
|
||||||
+str master_addr
|
+str master_addr
|
||||||
+str master_port
|
+str master_port
|
||||||
+Callable parallel_wrapper
|
|
||||||
+Callable state_dict_fn
|
|
||||||
+str start_method
|
+str start_method
|
||||||
+str device_type
|
+str device_type
|
||||||
+Optional[Dataset] val_dataset
|
+Optional[Dataset] val_dataset
|
||||||
+int val_step
|
+int val_step
|
||||||
|
+str parallel_mode
|
||||||
|
+dict executor_kwargs
|
||||||
+dict extra_kwargs
|
+dict extra_kwargs
|
||||||
+validate()
|
+validate()
|
||||||
}
|
}
|
||||||
|
|
@ -257,11 +257,13 @@ classDiagram
|
||||||
+int qk_rope_head_dim
|
+int qk_rope_head_dim
|
||||||
+int n_rep
|
+int n_rep
|
||||||
+int layer_id
|
+int layer_id
|
||||||
|
+bool use_qk_norm
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||||
+Linear o_proj
|
+Linear o_proj
|
||||||
+Linear gate # only if use_gated_attention
|
+Linear gate # only if use_gated_attention
|
||||||
+RMSNorm kv_norm
|
+RMSNorm kv_norm
|
||||||
|
+RMSNorm q_norm, k_norm # only if use_qk_norm
|
||||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -364,10 +366,11 @@ classDiagram
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+BaseStrategy strategy
|
+BaseStrategy strategy
|
||||||
+DataLoader dataloader
|
+DataLoader dataloader
|
||||||
+Optimizer optimizer
|
+OptimizerProtocol optimizer
|
||||||
+LRScheduler scheduler
|
+SchedulerProtocol scheduler
|
||||||
+Checkpoint checkpoint
|
+Checkpoint checkpoint
|
||||||
+TrainConfig config
|
+TrainConfig config
|
||||||
|
+BaseExecutor executor
|
||||||
+int epoch
|
+int epoch
|
||||||
+int iteration
|
+int iteration
|
||||||
+float loss
|
+float loss
|
||||||
|
|
@ -802,6 +805,24 @@ classDiagram
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace protocols {
|
||||||
|
class OptimizerProtocol {
|
||||||
|
<<protocol>>
|
||||||
|
+step(closure)
|
||||||
|
+zero_grad()
|
||||||
|
+state_dict() dict
|
||||||
|
+load_state_dict(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
class SchedulerProtocol {
|
||||||
|
<<protocol>>
|
||||||
|
+step()
|
||||||
|
+state_dict() dict
|
||||||
|
+load_state_dict(d)
|
||||||
|
+get_last_lr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class Functions {
|
class Functions {
|
||||||
<<module>>
|
<<module>>
|
||||||
|
|
@ -813,6 +834,54 @@ classDiagram
|
||||||
+only_on_rank(rank, sync) decorator
|
+only_on_rank(rank, sync) decorator
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class GradientState {
|
||||||
|
+int num_steps
|
||||||
|
+sync_gradients (property) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
class AccumOptimizer {
|
||||||
|
+Optimizer optimizer
|
||||||
|
+GradientState gradient_state
|
||||||
|
+step(closure)
|
||||||
|
+zero_grad()
|
||||||
|
+state_dict() dict
|
||||||
|
+load_state_dict(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
class AccumScheduler {
|
||||||
|
+LRScheduler scheduler
|
||||||
|
+GradientState gradient_state
|
||||||
|
+step()
|
||||||
|
+state_dict() dict
|
||||||
|
+load_state_dict(d)
|
||||||
|
+get_last_lr()
|
||||||
|
}
|
||||||
|
|
||||||
|
class BaseExecutor {
|
||||||
|
+GradientState gradient_state
|
||||||
|
+prepare(model, optimizer, dataloader, scheduler) tuple
|
||||||
|
+accumulate(model) context manager
|
||||||
|
+backward(loss)
|
||||||
|
+unwrap_model(model) nn.Module
|
||||||
|
+sync_gradients (property) bool
|
||||||
|
+grad_accum_steps (property) int
|
||||||
|
}
|
||||||
|
|
||||||
|
class NoneExecutor {
|
||||||
|
}
|
||||||
|
|
||||||
|
class DDPExecutor {
|
||||||
|
+_prepare_model(model) nn.Module
|
||||||
|
+_no_sync(model) context manager
|
||||||
|
+unwrap_model(model) nn.Module
|
||||||
|
}
|
||||||
|
|
||||||
|
class ExecutorFactory {
|
||||||
|
+Registry _registry
|
||||||
|
+register(name) decorator
|
||||||
|
+create(parallel_mode, **kwargs) BaseExecutor
|
||||||
|
}
|
||||||
|
|
||||||
class ParallelModel {
|
class ParallelModel {
|
||||||
+dist.ProcessGroup process_group
|
+dist.ProcessGroup process_group
|
||||||
+int rank
|
+int rank
|
||||||
|
|
@ -868,8 +937,10 @@ classDiagram
|
||||||
BaseFactory <|-- SchedulerFactory
|
BaseFactory <|-- SchedulerFactory
|
||||||
BaseFactory <|-- CallbackFactory
|
BaseFactory <|-- CallbackFactory
|
||||||
BaseFactory <|-- StorageFactory
|
BaseFactory <|-- StorageFactory
|
||||||
|
BaseFactory <|-- ExecutorFactory
|
||||||
BaseFactory <|-- ConfigFactory
|
BaseFactory <|-- ConfigFactory
|
||||||
TrainCallback <|-- ValidationCallback
|
BaseExecutor <|-- NoneExecutor
|
||||||
|
BaseExecutor <|-- DDPExecutor
|
||||||
ProtocolHandler <|-- OpenAIHandler
|
ProtocolHandler <|-- OpenAIHandler
|
||||||
ProtocolHandler <|-- AnthropicHandler
|
ProtocolHandler <|-- AnthropicHandler
|
||||||
|
|
||||||
|
|
@ -894,6 +965,9 @@ classDiagram
|
||||||
MessagesRequest *-- AnthropicMessage
|
MessagesRequest *-- AnthropicMessage
|
||||||
AutoTokenizer *-- ChatTemplate
|
AutoTokenizer *-- ChatTemplate
|
||||||
BaseFactory *-- Registry
|
BaseFactory *-- Registry
|
||||||
|
BaseExecutor *-- GradientState
|
||||||
|
AccumOptimizer o-- GradientState
|
||||||
|
AccumScheduler o-- GradientState
|
||||||
|
|
||||||
%% --- Aggregation (weak ownership) ---
|
%% --- Aggregation (weak ownership) ---
|
||||||
AutoModel o-- BaseModelConfig
|
AutoModel o-- BaseModelConfig
|
||||||
|
|
@ -901,6 +975,7 @@ classDiagram
|
||||||
TrainContext o-- BaseStrategy
|
TrainContext o-- BaseStrategy
|
||||||
TrainContext o-- BaseScheduler
|
TrainContext o-- BaseScheduler
|
||||||
TrainContext o-- Checkpoint
|
TrainContext o-- Checkpoint
|
||||||
|
TrainContext o-- BaseExecutor
|
||||||
KvcacheView o-- Storage
|
KvcacheView o-- Storage
|
||||||
SamplingPipeline o-- BaseSamplingStrategy
|
SamplingPipeline o-- BaseSamplingStrategy
|
||||||
BaseDataset o-- BaseStorage
|
BaseDataset o-- BaseStorage
|
||||||
|
|
@ -921,6 +996,9 @@ classDiagram
|
||||||
StorageFactory ..> JSONStorage : creates
|
StorageFactory ..> JSONStorage : creates
|
||||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||||
ConfigFactory ..> EncoderConfig : creates
|
ConfigFactory ..> EncoderConfig : creates
|
||||||
|
ExecutorFactory ..> NoneExecutor : creates
|
||||||
|
ExecutorFactory ..> DDPExecutor : creates
|
||||||
|
TrainContextBuilder ..> ExecutorFactory : creates
|
||||||
Trainer ..> TrainContextBuilder : uses
|
Trainer ..> TrainContextBuilder : uses
|
||||||
TrainContextBuilder ..> TrainContext : creates
|
TrainContextBuilder ..> TrainContext : creates
|
||||||
Trainer ..> Functions : spawns
|
Trainer ..> Functions : spawns
|
||||||
|
|
@ -963,15 +1041,16 @@ classDiagram
|
||||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, StopChecker, StreamContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||||
|
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
|
||||||
|
|
||||||
## Design Patterns
|
## Design Patterns
|
||||||
|
|
||||||
| Pattern | Classes | Purpose |
|
| Pattern | Classes | Purpose |
|
||||||
|---------|---------|---------|
|
|---------|---------|---------|
|
||||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||||
|
|
@ -980,20 +1059,23 @@ classDiagram
|
||||||
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
||||||
| **Context** | `TrainContext` | Unified training state bag |
|
| **Context** | `TrainContext` | Unified training state bag |
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||||
|
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
||||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||||
|
|
||||||
## Core Relationships
|
## Core Relationships
|
||||||
|
|
||||||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed)
|
||||||
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||||
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
|
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
|
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-24
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,9 @@
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||||
|
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
|
||||||
| `--device_type` | Device type | cuda |
|
| `--device_type` | Device type | cuda |
|
||||||
|
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
||||||
|
|
||||||
### Strategy-specific
|
### Strategy-specific
|
||||||
|
|
||||||
|
|
@ -94,4 +96,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-24
|
||||||
|
|
@ -72,17 +72,18 @@ on_train_begin
|
||||||
on_epoch_begin
|
on_epoch_begin
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
on_batch_begin
|
on_batch_begin
|
||||||
|
with executor.accumulate(model):
|
||||||
loss = strategy(batch)
|
loss = strategy(batch)
|
||||||
(loss / grad_accum_steps).backward()
|
(loss / grad_accum_steps).backward()
|
||||||
iteration += 1
|
iteration += 1
|
||||||
on_batch_end
|
on_batch_end
|
||||||
|
|
||||||
if iteration % grad_accum_steps == 0:
|
if executor.sync_gradients:
|
||||||
on_step_begin
|
on_optimizer_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
on_step_end
|
|
||||||
scheduler.step()
|
scheduler.step() # called every iteration
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
```
|
```
|
||||||
|
|
@ -92,12 +93,11 @@ on_train_end
|
||||||
| Hook | Fires | Default callback |
|
| Hook | Fires | Default callback |
|
||||||
|------|-------|-----------------|
|
|------|-------|-----------------|
|
||||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
||||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
| `on_step_end` | Every accumulation window | `ValidationCallback` |
|
|
||||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||||
|
|
||||||
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||||
|
|
||||||
## Strategies
|
## Strategies
|
||||||
|
|
||||||
|
|
@ -171,7 +171,7 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
|
||||||
|
|
||||||
```
|
```
|
||||||
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
||||||
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
|
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt
|
||||||
└── load(save_dir) broadcasts metadata from rank-0
|
└── load(save_dir) broadcasts metadata from rank-0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -190,7 +190,8 @@ context = (
|
||||||
```
|
```
|
||||||
|
|
||||||
- Loads checkpoint weights if provided
|
- Loads checkpoint weights if provided
|
||||||
- Wraps model with `parallel_wrapper` if `nprocs > 1`
|
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
|
||||||
|
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
||||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||||
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
||||||
|
|
||||||
|
|
@ -222,4 +223,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-24
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue