docs: 修正文档中类名/字段名与代码不一致之处
- ModelConfig → AutoRegressiveLMConfig, Transformer → AutoRegressiveLM
- 新增缺失类: EncoderConfig, EmbeddingEncoder, ConfigFactory, StorageFactory, ValidationCallback
- TrainConfig/TrainContext/ChatCompletionRequest 补充缺失字段
- dataflow.md 中 create_storage → StorageFactory.create
- 示例 --train_type=pt → seq 与代码一致
This commit is contained in:
parent
2c2697390d
commit
d0e3464663
|
|
@ -82,7 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=pt \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=pt \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ classDiagram
|
|||
+to_file(config_path)
|
||||
}
|
||||
|
||||
class ModelConfig {
|
||||
class AutoRegressiveLMConfig {
|
||||
+int vocab_size
|
||||
+int dim
|
||||
+int n_layers
|
||||
|
|
@ -25,21 +25,41 @@ classDiagram
|
|||
+bool tie_weight
|
||||
+int max_len
|
||||
+float rope_theta
|
||||
+str attn_type
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+str attn_type
|
||||
+Optional[int] kv_lora_rank
|
||||
+Optional[int] qk_nope_head_dim
|
||||
+Optional[int] qk_rope_head_dim
|
||||
+str ffn_type
|
||||
+int n_routed_experts
|
||||
+int n_shared_experts
|
||||
+int n_activated_experts
|
||||
+str moe_topk_method
|
||||
+Optional[int] kv_lora_rank
|
||||
+Optional[int] qk_nope_head_dim
|
||||
+Optional[int] qk_rope_head_dim
|
||||
+load(config_path) ModelConfig
|
||||
+save(config_path)
|
||||
+Optional[str] topk_method
|
||||
}
|
||||
|
||||
class EncoderConfig {
|
||||
+int vocab_size
|
||||
+int dim
|
||||
+int n_layers
|
||||
+float norm_eps
|
||||
+int dim_ffn
|
||||
+int max_len
|
||||
+float rope_theta
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Optional[str] pooling_type
|
||||
+Optional[bool] normalize_embeddings
|
||||
}
|
||||
|
||||
class ConfigFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+load(raw) BaseConfig
|
||||
}
|
||||
|
||||
class TrainConfig {
|
||||
|
|
@ -52,6 +72,7 @@ classDiagram
|
|||
+int batch_per_device
|
||||
+int grad_accum_steps
|
||||
+float max_grad_norm
|
||||
+list gradient_checkpointing_modules
|
||||
+int start_epoch
|
||||
+int start_batch
|
||||
+str ckpt_dir
|
||||
|
|
@ -66,7 +87,10 @@ classDiagram
|
|||
+str master_port
|
||||
+Callable parallel_wrapper
|
||||
+Callable state_dict_fn
|
||||
+str start_method
|
||||
+str device_type
|
||||
+Optional[Dataset] val_dataset
|
||||
+int val_step
|
||||
+dict extra_kwargs
|
||||
+validate()
|
||||
}
|
||||
|
|
@ -138,11 +162,17 @@ classDiagram
|
|||
+int iter
|
||||
}
|
||||
|
||||
class StorageFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(storage_type) BaseStorage
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(train_type, window_size, stride) BaseDataset
|
||||
+load(train_type, load_path, window_size, stride) BaseDataset
|
||||
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -169,8 +199,8 @@ classDiagram
|
|||
+to(*args, **kwargs) Self
|
||||
}
|
||||
|
||||
class Transformer {
|
||||
+ModelConfig config
|
||||
class AutoRegressiveLM {
|
||||
+AutoRegressiveLMConfig config
|
||||
+RotaryEmbedding rotary_embedding
|
||||
+Embedding embed_tokens
|
||||
+ModuleList layers
|
||||
|
|
@ -181,6 +211,18 @@ classDiagram
|
|||
+state_dict()
|
||||
}
|
||||
|
||||
class EmbeddingEncoder {
|
||||
+EncoderConfig config
|
||||
+RotaryEmbedding rotary_embedding
|
||||
+Embedding embed_tokens
|
||||
+ModuleList layers
|
||||
+RMSNorm norm
|
||||
+str pooling_type
|
||||
+bool normalize_embeddings
|
||||
+forward(input_ids, input_mask, position_ids) Tensor
|
||||
+load_state_dict(state_dict)
|
||||
}
|
||||
|
||||
class DecoderBlock {
|
||||
+nn.Module attention # GQA or MLA via AttnFactory
|
||||
+RMSNorm input_norm
|
||||
|
|
@ -322,11 +364,15 @@ classDiagram
|
|||
+Optimizer optimizer
|
||||
+LRScheduler scheduler
|
||||
+Checkpoint checkpoint
|
||||
+TrainConfig config
|
||||
+int epoch
|
||||
+int iteration
|
||||
+float loss
|
||||
+DataLoader val_dataloader
|
||||
+float val_loss
|
||||
+int world_size
|
||||
+int rank
|
||||
+dict kwargs
|
||||
}
|
||||
|
||||
class TrainContextBuilder {
|
||||
|
|
@ -415,6 +461,12 @@ classDiagram
|
|||
+on_step_begin(context)
|
||||
}
|
||||
|
||||
class GradientCheckpointingCallback {
|
||||
+tuple modules
|
||||
+on_train_begin(context)
|
||||
+on_train_end(context)
|
||||
}
|
||||
|
||||
class CheckpointCallback {
|
||||
+str save_dir
|
||||
+int interval
|
||||
|
|
@ -438,6 +490,11 @@ classDiagram
|
|||
+on_train_end(context)
|
||||
}
|
||||
|
||||
class ValidationCallback {
|
||||
+_run_validation(context)
|
||||
+on_step_end(context)
|
||||
}
|
||||
|
||||
class CallbackFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
|
|
@ -638,6 +695,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class ChatCompletionRequest {
|
||||
+str model
|
||||
+List[ChatMessage] messages
|
||||
+float temperature
|
||||
+float top_p
|
||||
|
|
@ -646,6 +704,10 @@ classDiagram
|
|||
+bool stream
|
||||
+Optional[str] stop
|
||||
+Optional[int] n
|
||||
+Optional[float] presence_penalty
|
||||
+Optional[float] frequency_penalty
|
||||
+Optional[Dict] logit_bias
|
||||
+Optional[str] user
|
||||
}
|
||||
|
||||
class AnthropicMessage {
|
||||
|
|
@ -699,6 +761,7 @@ classDiagram
|
|||
+int completion_tokens
|
||||
+str accumulated
|
||||
+Optional[str] stop_matched
|
||||
+str last_yield_trimmed
|
||||
}
|
||||
|
||||
class app {
|
||||
|
|
@ -709,7 +772,7 @@ classDiagram
|
|||
|
||||
namespace parallel {
|
||||
class Functions {
|
||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
|
||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||
+get_current_device() str
|
||||
+get_world_size() int
|
||||
|
|
@ -741,6 +804,7 @@ classDiagram
|
|||
BaseScheduler <|-- CosineScheduler
|
||||
BaseScheduler <|-- SGDRScheduler
|
||||
TrainCallback <|-- GradientClippingCallback
|
||||
TrainCallback <|-- GradientCheckpointingCallback
|
||||
TrainCallback <|-- CheckpointCallback
|
||||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
|
|
@ -755,10 +819,12 @@ classDiagram
|
|||
BaseSamplingStrategy <|-- TopPStrategy
|
||||
ParallelModel <|-- RowParallelLinear
|
||||
ParallelModel <|-- ColumnParallelLinear
|
||||
AutoModel <|-- Transformer
|
||||
AutoModel <|-- AutoRegressiveLM
|
||||
AutoModel <|-- EmbeddingEncoder
|
||||
BaseConfig <|-- BaseModelConfig
|
||||
BaseConfig <|-- TrainConfig
|
||||
BaseModelConfig <|-- ModelConfig
|
||||
BaseModelConfig <|-- AutoRegressiveLMConfig
|
||||
BaseModelConfig <|-- EncoderConfig
|
||||
BaseFactory <|-- AutoModel
|
||||
BaseFactory <|-- AttnFactory
|
||||
BaseFactory <|-- FFNFactory
|
||||
|
|
@ -766,6 +832,9 @@ classDiagram
|
|||
BaseFactory <|-- StrategyFactory
|
||||
BaseFactory <|-- SchedulerFactory
|
||||
BaseFactory <|-- CallbackFactory
|
||||
BaseFactory <|-- StorageFactory
|
||||
BaseFactory <|-- ConfigFactory
|
||||
TrainCallback <|-- ValidationCallback
|
||||
ProtocolHandler <|-- OpenAIHandler
|
||||
ProtocolHandler <|-- AnthropicHandler
|
||||
|
||||
|
|
@ -781,16 +850,16 @@ classDiagram
|
|||
InferenceScheduler *-- TaskManager
|
||||
SamplingPipeline *-- BaseSamplingStrategy
|
||||
TrainContextBuilder *-- TrainContext
|
||||
Transformer *-- DecoderBlock
|
||||
Transformer *-- RotaryEmbedding
|
||||
Transformer *-- Embedding
|
||||
AutoRegressiveLM *-- DecoderBlock
|
||||
AutoRegressiveLM *-- RotaryEmbedding
|
||||
AutoRegressiveLM *-- Embedding
|
||||
DecoderBlock *-- RMSNorm
|
||||
BaseDataset *-- BaseStorage
|
||||
ChatCompletionRequest *-- ChatMessage
|
||||
MessagesRequest *-- AnthropicMessage
|
||||
|
||||
%% --- Aggregation (weak ownership) ---
|
||||
AutoModel o-- ModelConfig
|
||||
AutoModel o-- BaseModelConfig
|
||||
Trainer o-- TrainCallback
|
||||
TrainContext o-- BaseStrategy
|
||||
TrainContext o-- BaseScheduler
|
||||
|
|
@ -811,6 +880,10 @@ classDiagram
|
|||
FFNFactory ..> DeepSeekMoE : creates
|
||||
DecoderBlock ..> AttnFactory : uses
|
||||
DecoderBlock ..> FFNFactory : uses
|
||||
StorageFactory ..> H5Storage : creates
|
||||
StorageFactory ..> JSONStorage : creates
|
||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||
ConfigFactory ..> EncoderConfig : creates
|
||||
Trainer ..> TrainContextBuilder : uses
|
||||
Trainer ..> Functions : spawns
|
||||
TrainContextBuilder ..> StrategyFactory : uses
|
||||
|
|
@ -827,13 +900,13 @@ classDiagram
|
|||
|
||||
%% --- Association (general usage) ---
|
||||
Trainer --> TrainConfig
|
||||
DPOStrategy --> Transformer
|
||||
GRPOStrategy --> Transformer
|
||||
DPOStrategy --> AutoRegressiveLM
|
||||
GRPOStrategy --> AutoRegressiveLM
|
||||
InferenceScheduler --> Task
|
||||
InferenceScheduler --> TaskStatus
|
||||
Task --> TaskStatus
|
||||
InferenceEngine --> Transformer
|
||||
Executor --> Transformer
|
||||
InferenceEngine --> AutoRegressiveLM
|
||||
Executor --> AutoRegressiveLM
|
||||
Executor --> AutoTokenizer
|
||||
TaskManager --> AutoTokenizer
|
||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||
|
|
@ -846,12 +919,12 @@ classDiagram
|
|||
|
||||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–MetricLoggerCallback, CallbackFactory | Training workflow |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||
|
|
@ -860,7 +933,7 @@ classDiagram
|
|||
|
||||
| Pattern | Classes | Purpose |
|
||||
|---------|---------|---------|
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation |
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||
|
|
@ -871,18 +944,18 @@ classDiagram
|
|||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading |
|
||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||
|
||||
## Core Relationships
|
||||
|
||||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, backed by `KVCache` + `SamplingPipeline`
|
||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
||||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
|
||||
> Document Update Time: 2026-05-16
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or
|
|||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||
|
||||
```
|
||||
create_storage("h5") → H5Storage
|
||||
create_storage("json") → JSONStorage
|
||||
StorageFactory.create("h5") → H5Storage
|
||||
StorageFactory.create("json") → JSONStorage
|
||||
```
|
||||
|
||||
Both support shared memory via `.share_memory_()`.
|
||||
|
|
@ -34,7 +34,7 @@ Both support shared memory via `.share_memory_()`.
|
|||
|
||||
```
|
||||
DatasetFactory.load(train_type, path, window_size, stride)
|
||||
→ create_storage(detect_format(path))
|
||||
→ StorageFactory.create(detect_format(path))
|
||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||
→ BaseDataset.__getitem__(idx)
|
||||
→ sliding window [begin, end) via get_index(idx)
|
||||
|
|
@ -54,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride)
|
|||
|
||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||
|
||||
> Document Update Time: 2026-05-15
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
|
|||
|
|
@ -137,4 +137,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
|||
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
||||
```
|
||||
|
||||
> Document Update Time: 2026-05-15
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=pt \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
|
|
@ -94,4 +94,4 @@ nohup python scripts/tools/train.py \
|
|||
|
||||
---
|
||||
|
||||
> Document Update Time: 2026-05-16
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
@ -91,11 +91,13 @@ on_train_end
|
|||
|
||||
| Hook | Fires | Default callback |
|
||||
|------|-------|-----------------|
|
||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||
| `on_step_end` | Every accumulation window | `ValidationCallback` |
|
||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||
|
||||
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
||||
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||
|
||||
## Strategies
|
||||
|
||||
|
|
@ -154,6 +156,17 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
|||
|
||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
||||
|
||||
## Gradient Checkpointing
|
||||
|
||||
Gradient checkpointing trades compute for memory by recomputing activations during the backward pass. Specify the module types to checkpoint via `gradient_checkpointing_modules`:
|
||||
|
||||
```python
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
|
||||
```
|
||||
|
||||
The callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)` and is compatible with `torch.compile`. It uses `nn.Module.apply()` for traversal, so it works through DDP wrappers without manual unwrapping. An empty list (the default) makes the callback a no-op.
|
||||
|
||||
## Checkpoint
|
||||
|
||||
```
|
||||
|
|
@ -188,7 +201,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=pt \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
|
|
@ -209,4 +222,4 @@ nohup python scripts/tools/train.py \
|
|||
|
||||
Full parameter reference at [params.md](params.md).
|
||||
|
||||
> Document Update Time: 2026-05-16
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
|
|||
Loading…
Reference in New Issue