docs: 修正文档中类名/字段名与代码不一致之处

- ModelConfig → AutoRegressiveLMConfig, Transformer → AutoRegressiveLM
- 新增缺失类: EncoderConfig, EmbeddingEncoder, ConfigFactory, StorageFactory, ValidationCallback
- TrainConfig/TrainContext/ChatCompletionRequest 补充缺失字段
- dataflow.md 中 create_storage → StorageFactory.create
- 示例 --train_type=pt → seq 与代码一致
This commit is contained in:
ViperEkura 2026-05-17 20:23:12 +08:00
parent 2c2697390d
commit 6c8533f1d2
7 changed files with 161 additions and 48 deletions

View File

@ -82,7 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \ nohup python scripts/tools/train.py \
--nprocs=4 \ --nprocs=4 \
--train_type=pt \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/model \
--batch_per_device=4 \ --batch_per_device=4 \

View File

@ -88,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \ nohup python scripts/tools/train.py \
--nprocs=4 \ --nprocs=4 \
--train_type=pt \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/model \
--batch_per_device=4 \ --batch_per_device=4 \

View File

@ -16,7 +16,7 @@ classDiagram
+to_file(config_path) +to_file(config_path)
} }
class ModelConfig { class AutoRegressiveLMConfig {
+int vocab_size +int vocab_size
+int dim +int dim
+int n_layers +int n_layers
@ -25,21 +25,41 @@ classDiagram
+bool tie_weight +bool tie_weight
+int max_len +int max_len
+float rope_theta +float rope_theta
+str attn_type
+int n_heads +int n_heads
+int n_kv_heads +int n_kv_heads
+bool use_qk_norm +bool use_qk_norm
+bool use_gated_attention +bool use_gated_attention
+str attn_type +Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+str ffn_type +str ffn_type
+int n_routed_experts +int n_routed_experts
+int n_shared_experts +int n_shared_experts
+int n_activated_experts +int n_activated_experts
+str moe_topk_method +Optional[str] topk_method
+Optional[int] kv_lora_rank }
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim class EncoderConfig {
+load(config_path) ModelConfig +int vocab_size
+save(config_path) +int dim
+int n_layers
+float norm_eps
+int dim_ffn
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[str] pooling_type
+Optional[bool] normalize_embeddings
}
class ConfigFactory {
+Registry _registry
+register(name) decorator
+load(raw) BaseConfig
} }
class TrainConfig { class TrainConfig {
@ -52,6 +72,7 @@ classDiagram
+int batch_per_device +int batch_per_device
+int grad_accum_steps +int grad_accum_steps
+float max_grad_norm +float max_grad_norm
+list gradient_checkpointing_modules
+int start_epoch +int start_epoch
+int start_batch +int start_batch
+str ckpt_dir +str ckpt_dir
@ -66,7 +87,10 @@ classDiagram
+str master_port +str master_port
+Callable parallel_wrapper +Callable parallel_wrapper
+Callable state_dict_fn +Callable state_dict_fn
+str start_method
+str device_type +str device_type
+Optional[Dataset] val_dataset
+int val_step
+dict extra_kwargs +dict extra_kwargs
+validate() +validate()
} }
@ -138,11 +162,17 @@ classDiagram
+int iter +int iter
} }
class StorageFactory {
+Registry _registry
+register(name) decorator
+create(storage_type) BaseStorage
}
class DatasetFactory { class DatasetFactory {
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
+create(train_type, window_size, stride) BaseDataset +create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset +load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
} }
} }
@ -160,7 +190,7 @@ classDiagram
namespace model { namespace model {
class AutoModel { class AutoModel {
+ModelConfig config +BaseModelConfig config
+Registry _registry +Registry _registry
+register(model_type) decorator +register(model_type) decorator
+get_component_class(model_type) Type +get_component_class(model_type) Type
@ -169,8 +199,8 @@ classDiagram
+to(*args, **kwargs) Self +to(*args, **kwargs) Self
} }
class Transformer { class AutoRegressiveLM {
+ModelConfig config +AutoRegressiveLMConfig config
+RotaryEmbedding rotary_embedding +RotaryEmbedding rotary_embedding
+Embedding embed_tokens +Embedding embed_tokens
+ModuleList layers +ModuleList layers
@ -181,6 +211,18 @@ classDiagram
+state_dict() +state_dict()
} }
class EmbeddingEncoder {
+EncoderConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+str pooling_type
+bool normalize_embeddings
+forward(input_ids, input_mask, position_ids) Tensor
+load_state_dict(state_dict)
}
class DecoderBlock { class DecoderBlock {
+nn.Module attention # GQA or MLA via AttnFactory +nn.Module attention # GQA or MLA via AttnFactory
+RMSNorm input_norm +RMSNorm input_norm
@ -322,11 +364,15 @@ classDiagram
+Optimizer optimizer +Optimizer optimizer
+LRScheduler scheduler +LRScheduler scheduler
+Checkpoint checkpoint +Checkpoint checkpoint
+TrainConfig config
+int epoch +int epoch
+int iteration +int iteration
+float loss +float loss
+DataLoader val_dataloader
+float val_loss
+int world_size +int world_size
+int rank +int rank
+dict kwargs
} }
class TrainContextBuilder { class TrainContextBuilder {
@ -372,6 +418,7 @@ classDiagram
+str reduction +str reduction
+int sync_interval +int sync_interval
+compute_loss(batch) Tensor +compute_loss(batch) Tensor
+sync_ref_model()
} }
class BaseScheduler { class BaseScheduler {
@ -399,6 +446,7 @@ classDiagram
} }
class TrainCallback { class TrainCallback {
<<protocol>>
+on_train_begin(context) +on_train_begin(context)
+on_train_end(context) +on_train_end(context)
+on_epoch_begin(context) +on_epoch_begin(context)
@ -415,13 +463,22 @@ classDiagram
+on_step_begin(context) +on_step_begin(context)
} }
class GradientCheckpointingCallback {
+tuple modules
+on_train_begin(context)
+on_train_end(context)
}
class CheckpointCallback { class CheckpointCallback {
+str save_dir +str save_dir
+int interval +int interval
+_save_checkpoint(context) +_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_train_end(context) +on_train_end(context)
+on_error(context) +on_error(context)
+save_extra(context)$
+load_extra(extra, context)$
} }
class ProgressBarCallback { class ProgressBarCallback {
@ -436,6 +493,12 @@ classDiagram
+int save_interval +int save_interval
+on_batch_end(context) +on_batch_end(context)
+on_train_end(context) +on_train_end(context)
+on_error(context)
}
class ValidationCallback {
+_run_validation(context)
+on_step_end(context)
} }
class CallbackFactory { class CallbackFactory {
@ -443,6 +506,14 @@ classDiagram
+register(name) decorator +register(name) decorator
+create(name, **kwargs) TrainCallback +create(name, **kwargs) TrainCallback
} }
class Muon {
+float lr
+float momentum
+float weight_decay
+int ns_steps
+step(closure) Optional[float]
}
} }
namespace inference { namespace inference {
@ -638,14 +709,19 @@ classDiagram
} }
class ChatCompletionRequest { class ChatCompletionRequest {
+str model
+List[ChatMessage] messages +List[ChatMessage] messages
+float temperature +float temperature
+float top_p +float top_p
+int top_k +int top_k
+int max_tokens +int max_tokens
+bool stream +bool stream
+Optional[str] stop +Optional[Union[str, List[str]]] stop
+Optional[int] n +Optional[int] n
+Optional[float] presence_penalty
+Optional[float] frequency_penalty
+Optional[Dict] logit_bias
+Optional[str] user
} }
class AnthropicMessage { class AnthropicMessage {
@ -654,6 +730,7 @@ classDiagram
} }
class MessagesRequest { class MessagesRequest {
+str model
+List[AnthropicMessage] messages +List[AnthropicMessage] messages
+Optional[str] system +Optional[str] system
+float temperature +float temperature
@ -666,8 +743,13 @@ classDiagram
class ProtocolHandler { class ProtocolHandler {
<<abstract>> <<abstract>>
+request
+engine
+build_prompt() str +build_prompt() str
+create_response_id() str +create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str] +format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str +format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str] +format_stream_end(ctx) List[str]
@ -687,6 +769,7 @@ classDiagram
} }
class StopChecker { class StopChecker {
+has_sequences (property) bool
+check(text) Optional[str] +check(text) Optional[str]
+trim(text, matched) str +trim(text, matched) str
} }
@ -699,6 +782,7 @@ classDiagram
+int completion_tokens +int completion_tokens
+str accumulated +str accumulated
+Optional[str] stop_matched +Optional[str] stop_matched
+str last_yield_trimmed
} }
class app { class app {
@ -709,11 +793,13 @@ classDiagram
namespace parallel { namespace parallel {
class Functions { class Functions {
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs) <<module>>
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str +get_current_device() str
+get_world_size() int +get_world_size() int
+get_rank() int +get_rank() int
+only_on_rank(rank, sync) decorator
} }
class ParallelModel { class ParallelModel {
@ -741,6 +827,7 @@ classDiagram
BaseScheduler <|-- CosineScheduler BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler BaseScheduler <|-- SGDRScheduler
TrainCallback <|-- GradientClippingCallback TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- GradientCheckpointingCallback
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
@ -753,12 +840,15 @@ classDiagram
BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy BaseSamplingStrategy <|-- TopPStrategy
BaseSamplingStrategy <|-- SamplingPipeline
ParallelModel <|-- RowParallelLinear ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- Transformer AutoModel <|-- AutoRegressiveLM
AutoModel <|-- EmbeddingEncoder
BaseConfig <|-- BaseModelConfig BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig BaseConfig <|-- TrainConfig
BaseModelConfig <|-- ModelConfig BaseModelConfig <|-- AutoRegressiveLMConfig
BaseModelConfig <|-- EncoderConfig
BaseFactory <|-- AutoModel BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory BaseFactory <|-- AttnFactory
BaseFactory <|-- FFNFactory BaseFactory <|-- FFNFactory
@ -766,6 +856,9 @@ classDiagram
BaseFactory <|-- StrategyFactory BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- ConfigFactory
TrainCallback <|-- ValidationCallback
ProtocolHandler <|-- OpenAIHandler ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler ProtocolHandler <|-- AnthropicHandler
@ -773,24 +866,26 @@ classDiagram
KVCache *-- PagePool KVCache *-- PagePool
KVCache *-- Storage KVCache *-- Storage
KVCache *-- TaskTable KVCache *-- TaskTable
KVCache *-- Allocator PagePool *-- Allocator
KVCache *-- PrefixCache PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy SamplingPipeline *-- BaseSamplingStrategy
TrainContextBuilder *-- TrainContext AutoRegressiveLM *-- DecoderBlock
Transformer *-- DecoderBlock AutoRegressiveLM *-- RotaryEmbedding
Transformer *-- RotaryEmbedding AutoRegressiveLM *-- Embedding
Transformer *-- Embedding EmbeddingEncoder *-- DecoderBlock
EmbeddingEncoder *-- RotaryEmbedding
EmbeddingEncoder *-- Embedding
DecoderBlock *-- RMSNorm DecoderBlock *-- RMSNorm
BaseDataset *-- BaseStorage BaseDataset o-- BaseStorage
ChatCompletionRequest *-- ChatMessage ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage MessagesRequest *-- AnthropicMessage
%% --- Aggregation (weak ownership) --- %% --- Aggregation (weak ownership) ---
AutoModel o-- ModelConfig AutoModel o-- BaseModelConfig
Trainer o-- TrainCallback Trainer o-- TrainCallback
TrainContext o-- BaseStrategy TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler TrainContext o-- BaseScheduler
@ -811,7 +906,12 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses DecoderBlock ..> FFNFactory : uses
StorageFactory ..> H5Storage : creates
StorageFactory ..> JSONStorage : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
Trainer ..> TrainContextBuilder : uses Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates
Trainer ..> Functions : spawns Trainer ..> Functions : spawns
TrainContextBuilder ..> StrategyFactory : uses TrainContextBuilder ..> StrategyFactory : uses
TrainContextBuilder ..> ResumableDistributedSampler : creates TrainContextBuilder ..> ResumableDistributedSampler : creates
@ -827,13 +927,13 @@ classDiagram
%% --- Association (general usage) --- %% --- Association (general usage) ---
Trainer --> TrainConfig Trainer --> TrainConfig
DPOStrategy --> Transformer DPOStrategy --> AutoModel
GRPOStrategy --> Transformer GRPOStrategy --> AutoModel
InferenceScheduler --> Task InferenceScheduler --> Task
InferenceScheduler --> TaskStatus InferenceScheduler --> TaskStatus
Task --> TaskStatus Task --> TaskStatus
InferenceEngine --> Transformer InferenceEngine --> AutoModel
Executor --> Transformer Executor --> AutoModel
Executor --> AutoTokenizer Executor --> AutoTokenizer
TaskManager --> AutoTokenizer TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher MultiSegmentFetcher --> BaseSegmentFetcher
@ -846,12 +946,12 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | | **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization | | **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallbackMetricLoggerCallback, CallbackFactory | Training workflow | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, ChatMessageMessagesRequest, app | Inference service | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration | | **astrai.factory** | Registry, BaseFactory[T] | Component registration |
@ -860,7 +960,7 @@ classDiagram
| Pattern | Classes | Purpose | | Pattern | Classes | Purpose |
|---------|---------|---------| |---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation | | **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
@ -871,18 +971,18 @@ classDiagram
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access | | **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
## Core Relationships ## Core Relationships
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn 1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``Transformer`, backed by `KVCache` + `SamplingPipeline` 4. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` 6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only) 7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
> Document Update Time: 2026-05-16 > Document Update Time: 2026-05-17

View File

@ -15,8 +15,8 @@ Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
``` ```
create_storage("h5") → H5Storage StorageFactory.create("h5") → H5Storage
create_storage("json") → JSONStorage StorageFactory.create("json") → JSONStorage
``` ```
Both support shared memory via `.share_memory_()`. Both support shared memory via `.share_memory_()`.
@ -34,7 +34,7 @@ Both support shared memory via `.share_memory_()`.
``` ```
DatasetFactory.load(train_type, path, window_size, stride) DatasetFactory.load(train_type, path, window_size, stride)
create_storage(detect_format(path)) StorageFactory.create(detect_format(path))
→ MultiSegmentFetcher(BaseSegmentFetcher per key) → MultiSegmentFetcher(BaseSegmentFetcher per key)
→ BaseDataset.__getitem__(idx) → BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx) → sliding window [begin, end) via get_index(idx)
@ -54,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride)
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-15 > Document Update Time: 2026-05-17

View File

@ -137,4 +137,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
``` ```
> Document Update Time: 2026-05-15 > Document Update Time: 2026-05-17

View File

@ -73,7 +73,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \ nohup python scripts/tools/train.py \
--nprocs=4 \ --nprocs=4 \
--train_type=pt \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/model \
--batch_per_device=4 \ --batch_per_device=4 \
@ -94,4 +94,4 @@ nohup python scripts/tools/train.py \
--- ---
> Document Update Time: 2026-05-16 > Document Update Time: 2026-05-17

View File

@ -91,11 +91,13 @@ on_train_end
| Hook | Fires | Default callback | | Hook | Fires | Default callback |
|------|-------|-----------------| |------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` | | `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | | `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_step_end` | Every accumulation window | `ValidationCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) | | `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`. Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
## Strategies ## Strategies
@ -154,6 +156,17 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Gradient Checkpointing
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
```python
from astrai.model.components.decoder_block import DecoderBlock
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
```
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
## Checkpoint ## Checkpoint
``` ```
@ -188,7 +201,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \ nohup python scripts/tools/train.py \
--nprocs=4 \ --nprocs=4 \
--train_type=pt \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/model \
--batch_per_device=4 \ --batch_per_device=4 \
@ -209,4 +222,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md). Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-16 > Document Update Time: 2026-05-17