docs: 修正文档中类名/字段名与代码不一致之处
- ModelConfig → AutoRegressiveLMConfig, Transformer → AutoRegressiveLM - 新增缺失类: EncoderConfig, EmbeddingEncoder, ConfigFactory, StorageFactory, ValidationCallback - TrainConfig/TrainContext/ChatCompletionRequest 补充缺失字段 - dataflow.md 中 create_storage → StorageFactory.create - 示例 --train_type=pt → seq 与代码一致
This commit is contained in:
parent
2c2697390d
commit
d0e3464663
|
|
@ -82,7 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--train_type=pt \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
--batch_per_device=4 \
|
--batch_per_device=4 \
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--train_type=pt \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
--batch_per_device=4 \
|
--batch_per_device=4 \
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ classDiagram
|
||||||
+to_file(config_path)
|
+to_file(config_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
class ModelConfig {
|
class AutoRegressiveLMConfig {
|
||||||
+int vocab_size
|
+int vocab_size
|
||||||
+int dim
|
+int dim
|
||||||
+int n_layers
|
+int n_layers
|
||||||
|
|
@ -25,21 +25,41 @@ classDiagram
|
||||||
+bool tie_weight
|
+bool tie_weight
|
||||||
+int max_len
|
+int max_len
|
||||||
+float rope_theta
|
+float rope_theta
|
||||||
|
+str attn_type
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+bool use_qk_norm
|
+bool use_qk_norm
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
+str attn_type
|
+Optional[int] kv_lora_rank
|
||||||
|
+Optional[int] qk_nope_head_dim
|
||||||
|
+Optional[int] qk_rope_head_dim
|
||||||
+str ffn_type
|
+str ffn_type
|
||||||
+int n_routed_experts
|
+int n_routed_experts
|
||||||
+int n_shared_experts
|
+int n_shared_experts
|
||||||
+int n_activated_experts
|
+int n_activated_experts
|
||||||
+str moe_topk_method
|
+Optional[str] topk_method
|
||||||
+Optional[int] kv_lora_rank
|
}
|
||||||
+Optional[int] qk_nope_head_dim
|
|
||||||
+Optional[int] qk_rope_head_dim
|
class EncoderConfig {
|
||||||
+load(config_path) ModelConfig
|
+int vocab_size
|
||||||
+save(config_path)
|
+int dim
|
||||||
|
+int n_layers
|
||||||
|
+float norm_eps
|
||||||
|
+int dim_ffn
|
||||||
|
+int max_len
|
||||||
|
+float rope_theta
|
||||||
|
+int n_heads
|
||||||
|
+int n_kv_heads
|
||||||
|
+bool use_qk_norm
|
||||||
|
+bool use_gated_attention
|
||||||
|
+Optional[str] pooling_type
|
||||||
|
+Optional[bool] normalize_embeddings
|
||||||
|
}
|
||||||
|
|
||||||
|
class ConfigFactory {
|
||||||
|
+Registry _registry
|
||||||
|
+register(name) decorator
|
||||||
|
+load(raw) BaseConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainConfig {
|
class TrainConfig {
|
||||||
|
|
@ -52,6 +72,7 @@ classDiagram
|
||||||
+int batch_per_device
|
+int batch_per_device
|
||||||
+int grad_accum_steps
|
+int grad_accum_steps
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
|
+list gradient_checkpointing_modules
|
||||||
+int start_epoch
|
+int start_epoch
|
||||||
+int start_batch
|
+int start_batch
|
||||||
+str ckpt_dir
|
+str ckpt_dir
|
||||||
|
|
@ -66,7 +87,10 @@ classDiagram
|
||||||
+str master_port
|
+str master_port
|
||||||
+Callable parallel_wrapper
|
+Callable parallel_wrapper
|
||||||
+Callable state_dict_fn
|
+Callable state_dict_fn
|
||||||
|
+str start_method
|
||||||
+str device_type
|
+str device_type
|
||||||
|
+Optional[Dataset] val_dataset
|
||||||
|
+int val_step
|
||||||
+dict extra_kwargs
|
+dict extra_kwargs
|
||||||
+validate()
|
+validate()
|
||||||
}
|
}
|
||||||
|
|
@ -138,11 +162,17 @@ classDiagram
|
||||||
+int iter
|
+int iter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class StorageFactory {
|
||||||
|
+Registry _registry
|
||||||
|
+register(name) decorator
|
||||||
|
+create(storage_type) BaseStorage
|
||||||
|
}
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
+create(train_type, window_size, stride) BaseDataset
|
||||||
+load(train_type, load_path, window_size, stride) BaseDataset
|
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -169,8 +199,8 @@ classDiagram
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
}
|
}
|
||||||
|
|
||||||
class Transformer {
|
class AutoRegressiveLM {
|
||||||
+ModelConfig config
|
+AutoRegressiveLMConfig config
|
||||||
+RotaryEmbedding rotary_embedding
|
+RotaryEmbedding rotary_embedding
|
||||||
+Embedding embed_tokens
|
+Embedding embed_tokens
|
||||||
+ModuleList layers
|
+ModuleList layers
|
||||||
|
|
@ -181,6 +211,18 @@ classDiagram
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class EmbeddingEncoder {
|
||||||
|
+EncoderConfig config
|
||||||
|
+RotaryEmbedding rotary_embedding
|
||||||
|
+Embedding embed_tokens
|
||||||
|
+ModuleList layers
|
||||||
|
+RMSNorm norm
|
||||||
|
+str pooling_type
|
||||||
|
+bool normalize_embeddings
|
||||||
|
+forward(input_ids, input_mask, position_ids) Tensor
|
||||||
|
+load_state_dict(state_dict)
|
||||||
|
}
|
||||||
|
|
||||||
class DecoderBlock {
|
class DecoderBlock {
|
||||||
+nn.Module attention # GQA or MLA via AttnFactory
|
+nn.Module attention # GQA or MLA via AttnFactory
|
||||||
+RMSNorm input_norm
|
+RMSNorm input_norm
|
||||||
|
|
@ -322,11 +364,15 @@ classDiagram
|
||||||
+Optimizer optimizer
|
+Optimizer optimizer
|
||||||
+LRScheduler scheduler
|
+LRScheduler scheduler
|
||||||
+Checkpoint checkpoint
|
+Checkpoint checkpoint
|
||||||
|
+TrainConfig config
|
||||||
+int epoch
|
+int epoch
|
||||||
+int iteration
|
+int iteration
|
||||||
+float loss
|
+float loss
|
||||||
|
+DataLoader val_dataloader
|
||||||
|
+float val_loss
|
||||||
+int world_size
|
+int world_size
|
||||||
+int rank
|
+int rank
|
||||||
|
+dict kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainContextBuilder {
|
class TrainContextBuilder {
|
||||||
|
|
@ -415,6 +461,12 @@ classDiagram
|
||||||
+on_step_begin(context)
|
+on_step_begin(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class GradientCheckpointingCallback {
|
||||||
|
+tuple modules
|
||||||
|
+on_train_begin(context)
|
||||||
|
+on_train_end(context)
|
||||||
|
}
|
||||||
|
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
|
|
@ -438,6 +490,11 @@ classDiagram
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ValidationCallback {
|
||||||
|
+_run_validation(context)
|
||||||
|
+on_step_end(context)
|
||||||
|
}
|
||||||
|
|
||||||
class CallbackFactory {
|
class CallbackFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
|
|
@ -638,6 +695,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatCompletionRequest {
|
class ChatCompletionRequest {
|
||||||
|
+str model
|
||||||
+List[ChatMessage] messages
|
+List[ChatMessage] messages
|
||||||
+float temperature
|
+float temperature
|
||||||
+float top_p
|
+float top_p
|
||||||
|
|
@ -646,6 +704,10 @@ classDiagram
|
||||||
+bool stream
|
+bool stream
|
||||||
+Optional[str] stop
|
+Optional[str] stop
|
||||||
+Optional[int] n
|
+Optional[int] n
|
||||||
|
+Optional[float] presence_penalty
|
||||||
|
+Optional[float] frequency_penalty
|
||||||
|
+Optional[Dict] logit_bias
|
||||||
|
+Optional[str] user
|
||||||
}
|
}
|
||||||
|
|
||||||
class AnthropicMessage {
|
class AnthropicMessage {
|
||||||
|
|
@ -699,6 +761,7 @@ classDiagram
|
||||||
+int completion_tokens
|
+int completion_tokens
|
||||||
+str accumulated
|
+str accumulated
|
||||||
+Optional[str] stop_matched
|
+Optional[str] stop_matched
|
||||||
|
+str last_yield_trimmed
|
||||||
}
|
}
|
||||||
|
|
||||||
class app {
|
class app {
|
||||||
|
|
@ -709,7 +772,7 @@ classDiagram
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class Functions {
|
class Functions {
|
||||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
|
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||||
+get_current_device() str
|
+get_current_device() str
|
||||||
+get_world_size() int
|
+get_world_size() int
|
||||||
|
|
@ -741,6 +804,7 @@ classDiagram
|
||||||
BaseScheduler <|-- CosineScheduler
|
BaseScheduler <|-- CosineScheduler
|
||||||
BaseScheduler <|-- SGDRScheduler
|
BaseScheduler <|-- SGDRScheduler
|
||||||
TrainCallback <|-- GradientClippingCallback
|
TrainCallback <|-- GradientClippingCallback
|
||||||
|
TrainCallback <|-- GradientCheckpointingCallback
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
|
|
@ -755,10 +819,12 @@ classDiagram
|
||||||
BaseSamplingStrategy <|-- TopPStrategy
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
ParallelModel <|-- RowParallelLinear
|
ParallelModel <|-- RowParallelLinear
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
ParallelModel <|-- ColumnParallelLinear
|
||||||
AutoModel <|-- Transformer
|
AutoModel <|-- AutoRegressiveLM
|
||||||
|
AutoModel <|-- EmbeddingEncoder
|
||||||
BaseConfig <|-- BaseModelConfig
|
BaseConfig <|-- BaseModelConfig
|
||||||
BaseConfig <|-- TrainConfig
|
BaseConfig <|-- TrainConfig
|
||||||
BaseModelConfig <|-- ModelConfig
|
BaseModelConfig <|-- AutoRegressiveLMConfig
|
||||||
|
BaseModelConfig <|-- EncoderConfig
|
||||||
BaseFactory <|-- AutoModel
|
BaseFactory <|-- AutoModel
|
||||||
BaseFactory <|-- AttnFactory
|
BaseFactory <|-- AttnFactory
|
||||||
BaseFactory <|-- FFNFactory
|
BaseFactory <|-- FFNFactory
|
||||||
|
|
@ -766,6 +832,9 @@ classDiagram
|
||||||
BaseFactory <|-- StrategyFactory
|
BaseFactory <|-- StrategyFactory
|
||||||
BaseFactory <|-- SchedulerFactory
|
BaseFactory <|-- SchedulerFactory
|
||||||
BaseFactory <|-- CallbackFactory
|
BaseFactory <|-- CallbackFactory
|
||||||
|
BaseFactory <|-- StorageFactory
|
||||||
|
BaseFactory <|-- ConfigFactory
|
||||||
|
TrainCallback <|-- ValidationCallback
|
||||||
ProtocolHandler <|-- OpenAIHandler
|
ProtocolHandler <|-- OpenAIHandler
|
||||||
ProtocolHandler <|-- AnthropicHandler
|
ProtocolHandler <|-- AnthropicHandler
|
||||||
|
|
||||||
|
|
@ -781,16 +850,16 @@ classDiagram
|
||||||
InferenceScheduler *-- TaskManager
|
InferenceScheduler *-- TaskManager
|
||||||
SamplingPipeline *-- BaseSamplingStrategy
|
SamplingPipeline *-- BaseSamplingStrategy
|
||||||
TrainContextBuilder *-- TrainContext
|
TrainContextBuilder *-- TrainContext
|
||||||
Transformer *-- DecoderBlock
|
AutoRegressiveLM *-- DecoderBlock
|
||||||
Transformer *-- RotaryEmbedding
|
AutoRegressiveLM *-- RotaryEmbedding
|
||||||
Transformer *-- Embedding
|
AutoRegressiveLM *-- Embedding
|
||||||
DecoderBlock *-- RMSNorm
|
DecoderBlock *-- RMSNorm
|
||||||
BaseDataset *-- BaseStorage
|
BaseDataset *-- BaseStorage
|
||||||
ChatCompletionRequest *-- ChatMessage
|
ChatCompletionRequest *-- ChatMessage
|
||||||
MessagesRequest *-- AnthropicMessage
|
MessagesRequest *-- AnthropicMessage
|
||||||
|
|
||||||
%% --- Aggregation (weak ownership) ---
|
%% --- Aggregation (weak ownership) ---
|
||||||
AutoModel o-- ModelConfig
|
AutoModel o-- BaseModelConfig
|
||||||
Trainer o-- TrainCallback
|
Trainer o-- TrainCallback
|
||||||
TrainContext o-- BaseStrategy
|
TrainContext o-- BaseStrategy
|
||||||
TrainContext o-- BaseScheduler
|
TrainContext o-- BaseScheduler
|
||||||
|
|
@ -811,6 +880,10 @@ classDiagram
|
||||||
FFNFactory ..> DeepSeekMoE : creates
|
FFNFactory ..> DeepSeekMoE : creates
|
||||||
DecoderBlock ..> AttnFactory : uses
|
DecoderBlock ..> AttnFactory : uses
|
||||||
DecoderBlock ..> FFNFactory : uses
|
DecoderBlock ..> FFNFactory : uses
|
||||||
|
StorageFactory ..> H5Storage : creates
|
||||||
|
StorageFactory ..> JSONStorage : creates
|
||||||
|
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||||
|
ConfigFactory ..> EncoderConfig : creates
|
||||||
Trainer ..> TrainContextBuilder : uses
|
Trainer ..> TrainContextBuilder : uses
|
||||||
Trainer ..> Functions : spawns
|
Trainer ..> Functions : spawns
|
||||||
TrainContextBuilder ..> StrategyFactory : uses
|
TrainContextBuilder ..> StrategyFactory : uses
|
||||||
|
|
@ -827,13 +900,13 @@ classDiagram
|
||||||
|
|
||||||
%% --- Association (general usage) ---
|
%% --- Association (general usage) ---
|
||||||
Trainer --> TrainConfig
|
Trainer --> TrainConfig
|
||||||
DPOStrategy --> Transformer
|
DPOStrategy --> AutoRegressiveLM
|
||||||
GRPOStrategy --> Transformer
|
GRPOStrategy --> AutoRegressiveLM
|
||||||
InferenceScheduler --> Task
|
InferenceScheduler --> Task
|
||||||
InferenceScheduler --> TaskStatus
|
InferenceScheduler --> TaskStatus
|
||||||
Task --> TaskStatus
|
Task --> TaskStatus
|
||||||
InferenceEngine --> Transformer
|
InferenceEngine --> AutoRegressiveLM
|
||||||
Executor --> Transformer
|
Executor --> AutoRegressiveLM
|
||||||
Executor --> AutoTokenizer
|
Executor --> AutoTokenizer
|
||||||
TaskManager --> AutoTokenizer
|
TaskManager --> AutoTokenizer
|
||||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||||
|
|
@ -846,12 +919,12 @@ classDiagram
|
||||||
|
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–MetricLoggerCallback, CallbackFactory | Training workflow |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
||||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||||
|
|
@ -860,7 +933,7 @@ classDiagram
|
||||||
|
|
||||||
| Pattern | Classes | Purpose |
|
| Pattern | Classes | Purpose |
|
||||||
|---------|---------|---------|
|
|---------|---------|---------|
|
||||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation |
|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||||
|
|
@ -871,18 +944,18 @@ classDiagram
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading |
|
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||||
|
|
||||||
## Core Relationships
|
## Core Relationships
|
||||||
|
|
||||||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, backed by `KVCache` + `SamplingPipeline`
|
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||||
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||||
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
||||||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
|
|
||||||
> Document Update Time: 2026-05-16
|
> Document Update Time: 2026-05-17
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,8 @@ Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or
|
||||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||||
|
|
||||||
```
|
```
|
||||||
create_storage("h5") → H5Storage
|
StorageFactory.create("h5") → H5Storage
|
||||||
create_storage("json") → JSONStorage
|
StorageFactory.create("json") → JSONStorage
|
||||||
```
|
```
|
||||||
|
|
||||||
Both support shared memory via `.share_memory_()`.
|
Both support shared memory via `.share_memory_()`.
|
||||||
|
|
@ -34,7 +34,7 @@ Both support shared memory via `.share_memory_()`.
|
||||||
|
|
||||||
```
|
```
|
||||||
DatasetFactory.load(train_type, path, window_size, stride)
|
DatasetFactory.load(train_type, path, window_size, stride)
|
||||||
→ create_storage(detect_format(path))
|
→ StorageFactory.create(detect_format(path))
|
||||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||||
→ BaseDataset.__getitem__(idx)
|
→ BaseDataset.__getitem__(idx)
|
||||||
→ sliding window [begin, end) via get_index(idx)
|
→ sliding window [begin, end) via get_index(idx)
|
||||||
|
|
@ -54,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride)
|
||||||
|
|
||||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||||
|
|
||||||
> Document Update Time: 2026-05-15
|
> Document Update Time: 2026-05-17
|
||||||
|
|
|
||||||
|
|
@ -137,4 +137,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||||
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
||||||
```
|
```
|
||||||
|
|
||||||
> Document Update Time: 2026-05-15
|
> Document Update Time: 2026-05-17
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--train_type=pt \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
--batch_per_device=4 \
|
--batch_per_device=4 \
|
||||||
|
|
@ -94,4 +94,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
> Document Update Time: 2026-05-16
|
> Document Update Time: 2026-05-17
|
||||||
|
|
@ -91,11 +91,13 @@ on_train_end
|
||||||
|
|
||||||
| Hook | Fires | Default callback |
|
| Hook | Fires | Default callback |
|
||||||
|------|-------|-----------------|
|
|------|-------|-----------------|
|
||||||
|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
|
| `on_step_end` | Every accumulation window | `ValidationCallback` |
|
||||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||||
|
|
||||||
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||||
|
|
||||||
## Strategies
|
## Strategies
|
||||||
|
|
||||||
|
|
@ -154,6 +156,17 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
|
|
||||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
||||||
|
|
||||||
|
## Gradient Checkpointing
|
||||||
|
|
||||||
|
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
|
||||||
|
```
|
||||||
|
|
||||||
|
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
|
||||||
|
|
||||||
## Checkpoint
|
## Checkpoint
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
@ -188,7 +201,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--train_type=pt \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
--batch_per_device=4 \
|
--batch_per_device=4 \
|
||||||
|
|
@ -209,4 +222,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-16
|
> Document Update Time: 2026-05-17
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue