Compare commits
No commits in common. "d0e34646634c6daab79135a6e387afeb10565d29" and "10ebd7211fd38f0acf8ea8164dadf8316cb97634" have entirely different histories.
d0e3464663
...
10ebd7211f
|
|
@ -82,7 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--train_type=seq \
|
--train_type=pt \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
--batch_per_device=4 \
|
--batch_per_device=4 \
|
||||||
|
|
@ -90,8 +90,8 @@ nohup python scripts/tools/train.py \
|
||||||
--warmup_ratio=0.05 \
|
--warmup_ratio=0.05 \
|
||||||
--max_lr=1e-4 \
|
--max_lr=1e-4 \
|
||||||
--max_grad_norm=1.0 \
|
--max_grad_norm=1.0 \
|
||||||
--adamw_beta1=0.9 \
|
--adamw_beta1=0.95 \
|
||||||
--adamw_beta2=0.95 \
|
--adamw_beta2=0.99 \
|
||||||
--adamw_weight_decay=0.01 \
|
--adamw_weight_decay=0.01 \
|
||||||
--window_size=2048 \
|
--window_size=2048 \
|
||||||
--ckpt_interval=10000 \
|
--ckpt_interval=10000 \
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--train_type=seq \
|
--train_type=pt \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
--batch_per_device=4 \
|
--batch_per_device=4 \
|
||||||
|
|
@ -96,8 +96,8 @@ nohup python scripts/tools/train.py \
|
||||||
--warmup_ratio=0.05 \
|
--warmup_ratio=0.05 \
|
||||||
--max_lr=1e-4 \
|
--max_lr=1e-4 \
|
||||||
--max_grad_norm=1.0 \
|
--max_grad_norm=1.0 \
|
||||||
--adamw_beta1=0.9 \
|
--adamw_beta1=0.95 \
|
||||||
--adamw_beta2=0.95 \
|
--adamw_beta2=0.99 \
|
||||||
--adamw_weight_decay=0.01 \
|
--adamw_weight_decay=0.01 \
|
||||||
--window_size=2048 \
|
--window_size=2048 \
|
||||||
--ckpt_interval=10000 \
|
--ckpt_interval=10000 \
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ classDiagram
|
||||||
+to_file(config_path)
|
+to_file(config_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
class AutoRegressiveLMConfig {
|
class ModelConfig {
|
||||||
+int vocab_size
|
+int vocab_size
|
||||||
+int dim
|
+int dim
|
||||||
+int n_layers
|
+int n_layers
|
||||||
|
|
@ -25,41 +25,21 @@ classDiagram
|
||||||
+bool tie_weight
|
+bool tie_weight
|
||||||
+int max_len
|
+int max_len
|
||||||
+float rope_theta
|
+float rope_theta
|
||||||
+str attn_type
|
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+bool use_qk_norm
|
+bool use_qk_norm
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
+Optional[int] kv_lora_rank
|
+str attn_type
|
||||||
+Optional[int] qk_nope_head_dim
|
|
||||||
+Optional[int] qk_rope_head_dim
|
|
||||||
+str ffn_type
|
+str ffn_type
|
||||||
+int n_routed_experts
|
+int n_routed_experts
|
||||||
+int n_shared_experts
|
+int n_shared_experts
|
||||||
+int n_activated_experts
|
+int n_activated_experts
|
||||||
+Optional[str] topk_method
|
+str moe_topk_method
|
||||||
}
|
+Optional[int] kv_lora_rank
|
||||||
|
+Optional[int] qk_nope_head_dim
|
||||||
class EncoderConfig {
|
+Optional[int] qk_rope_head_dim
|
||||||
+int vocab_size
|
+load(config_path) ModelConfig
|
||||||
+int dim
|
+save(config_path)
|
||||||
+int n_layers
|
|
||||||
+float norm_eps
|
|
||||||
+int dim_ffn
|
|
||||||
+int max_len
|
|
||||||
+float rope_theta
|
|
||||||
+int n_heads
|
|
||||||
+int n_kv_heads
|
|
||||||
+bool use_qk_norm
|
|
||||||
+bool use_gated_attention
|
|
||||||
+Optional[str] pooling_type
|
|
||||||
+Optional[bool] normalize_embeddings
|
|
||||||
}
|
|
||||||
|
|
||||||
class ConfigFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name) decorator
|
|
||||||
+load(raw) BaseConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainConfig {
|
class TrainConfig {
|
||||||
|
|
@ -72,7 +52,6 @@ classDiagram
|
||||||
+int batch_per_device
|
+int batch_per_device
|
||||||
+int grad_accum_steps
|
+int grad_accum_steps
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
+list gradient_checkpointing_modules
|
|
||||||
+int start_epoch
|
+int start_epoch
|
||||||
+int start_batch
|
+int start_batch
|
||||||
+str ckpt_dir
|
+str ckpt_dir
|
||||||
|
|
@ -87,10 +66,7 @@ classDiagram
|
||||||
+str master_port
|
+str master_port
|
||||||
+Callable parallel_wrapper
|
+Callable parallel_wrapper
|
||||||
+Callable state_dict_fn
|
+Callable state_dict_fn
|
||||||
+str start_method
|
|
||||||
+str device_type
|
+str device_type
|
||||||
+Optional[Dataset] val_dataset
|
|
||||||
+int val_step
|
|
||||||
+dict extra_kwargs
|
+dict extra_kwargs
|
||||||
+validate()
|
+validate()
|
||||||
}
|
}
|
||||||
|
|
@ -162,17 +138,11 @@ classDiagram
|
||||||
+int iter
|
+int iter
|
||||||
}
|
}
|
||||||
|
|
||||||
class StorageFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name) decorator
|
|
||||||
+create(storage_type) BaseStorage
|
|
||||||
}
|
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
+create(train_type, window_size, stride) BaseDataset
|
||||||
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
|
+load(train_type, load_path, window_size, stride) BaseDataset
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -199,8 +169,8 @@ classDiagram
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
}
|
}
|
||||||
|
|
||||||
class AutoRegressiveLM {
|
class Transformer {
|
||||||
+AutoRegressiveLMConfig config
|
+ModelConfig config
|
||||||
+RotaryEmbedding rotary_embedding
|
+RotaryEmbedding rotary_embedding
|
||||||
+Embedding embed_tokens
|
+Embedding embed_tokens
|
||||||
+ModuleList layers
|
+ModuleList layers
|
||||||
|
|
@ -211,18 +181,6 @@ classDiagram
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
class EmbeddingEncoder {
|
|
||||||
+EncoderConfig config
|
|
||||||
+RotaryEmbedding rotary_embedding
|
|
||||||
+Embedding embed_tokens
|
|
||||||
+ModuleList layers
|
|
||||||
+RMSNorm norm
|
|
||||||
+str pooling_type
|
|
||||||
+bool normalize_embeddings
|
|
||||||
+forward(input_ids, input_mask, position_ids) Tensor
|
|
||||||
+load_state_dict(state_dict)
|
|
||||||
}
|
|
||||||
|
|
||||||
class DecoderBlock {
|
class DecoderBlock {
|
||||||
+nn.Module attention # GQA or MLA via AttnFactory
|
+nn.Module attention # GQA or MLA via AttnFactory
|
||||||
+RMSNorm input_norm
|
+RMSNorm input_norm
|
||||||
|
|
@ -364,15 +322,11 @@ classDiagram
|
||||||
+Optimizer optimizer
|
+Optimizer optimizer
|
||||||
+LRScheduler scheduler
|
+LRScheduler scheduler
|
||||||
+Checkpoint checkpoint
|
+Checkpoint checkpoint
|
||||||
+TrainConfig config
|
|
||||||
+int epoch
|
+int epoch
|
||||||
+int iteration
|
+int iteration
|
||||||
+float loss
|
+float loss
|
||||||
+DataLoader val_dataloader
|
|
||||||
+float val_loss
|
|
||||||
+int world_size
|
+int world_size
|
||||||
+int rank
|
+int rank
|
||||||
+dict kwargs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainContextBuilder {
|
class TrainContextBuilder {
|
||||||
|
|
@ -461,12 +415,6 @@ classDiagram
|
||||||
+on_step_begin(context)
|
+on_step_begin(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class GradientCheckpointingCallback {
|
|
||||||
+tuple modules
|
|
||||||
+on_train_begin(context)
|
|
||||||
+on_train_end(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
|
|
@ -490,11 +438,6 @@ classDiagram
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class ValidationCallback {
|
|
||||||
+_run_validation(context)
|
|
||||||
+on_step_end(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class CallbackFactory {
|
class CallbackFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
|
|
@ -695,7 +638,6 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatCompletionRequest {
|
class ChatCompletionRequest {
|
||||||
+str model
|
|
||||||
+List[ChatMessage] messages
|
+List[ChatMessage] messages
|
||||||
+float temperature
|
+float temperature
|
||||||
+float top_p
|
+float top_p
|
||||||
|
|
@ -704,10 +646,6 @@ classDiagram
|
||||||
+bool stream
|
+bool stream
|
||||||
+Optional[str] stop
|
+Optional[str] stop
|
||||||
+Optional[int] n
|
+Optional[int] n
|
||||||
+Optional[float] presence_penalty
|
|
||||||
+Optional[float] frequency_penalty
|
|
||||||
+Optional[Dict] logit_bias
|
|
||||||
+Optional[str] user
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class AnthropicMessage {
|
class AnthropicMessage {
|
||||||
|
|
@ -761,7 +699,6 @@ classDiagram
|
||||||
+int completion_tokens
|
+int completion_tokens
|
||||||
+str accumulated
|
+str accumulated
|
||||||
+Optional[str] stop_matched
|
+Optional[str] stop_matched
|
||||||
+str last_yield_trimmed
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class app {
|
class app {
|
||||||
|
|
@ -772,7 +709,7 @@ classDiagram
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class Functions {
|
class Functions {
|
||||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||||
+get_current_device() str
|
+get_current_device() str
|
||||||
+get_world_size() int
|
+get_world_size() int
|
||||||
|
|
@ -804,7 +741,6 @@ classDiagram
|
||||||
BaseScheduler <|-- CosineScheduler
|
BaseScheduler <|-- CosineScheduler
|
||||||
BaseScheduler <|-- SGDRScheduler
|
BaseScheduler <|-- SGDRScheduler
|
||||||
TrainCallback <|-- GradientClippingCallback
|
TrainCallback <|-- GradientClippingCallback
|
||||||
TrainCallback <|-- GradientCheckpointingCallback
|
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
|
|
@ -819,12 +755,10 @@ classDiagram
|
||||||
BaseSamplingStrategy <|-- TopPStrategy
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
ParallelModel <|-- RowParallelLinear
|
ParallelModel <|-- RowParallelLinear
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
ParallelModel <|-- ColumnParallelLinear
|
||||||
AutoModel <|-- AutoRegressiveLM
|
AutoModel <|-- Transformer
|
||||||
AutoModel <|-- EmbeddingEncoder
|
|
||||||
BaseConfig <|-- BaseModelConfig
|
BaseConfig <|-- BaseModelConfig
|
||||||
BaseConfig <|-- TrainConfig
|
BaseConfig <|-- TrainConfig
|
||||||
BaseModelConfig <|-- AutoRegressiveLMConfig
|
BaseModelConfig <|-- ModelConfig
|
||||||
BaseModelConfig <|-- EncoderConfig
|
|
||||||
BaseFactory <|-- AutoModel
|
BaseFactory <|-- AutoModel
|
||||||
BaseFactory <|-- AttnFactory
|
BaseFactory <|-- AttnFactory
|
||||||
BaseFactory <|-- FFNFactory
|
BaseFactory <|-- FFNFactory
|
||||||
|
|
@ -832,9 +766,6 @@ classDiagram
|
||||||
BaseFactory <|-- StrategyFactory
|
BaseFactory <|-- StrategyFactory
|
||||||
BaseFactory <|-- SchedulerFactory
|
BaseFactory <|-- SchedulerFactory
|
||||||
BaseFactory <|-- CallbackFactory
|
BaseFactory <|-- CallbackFactory
|
||||||
BaseFactory <|-- StorageFactory
|
|
||||||
BaseFactory <|-- ConfigFactory
|
|
||||||
TrainCallback <|-- ValidationCallback
|
|
||||||
ProtocolHandler <|-- OpenAIHandler
|
ProtocolHandler <|-- OpenAIHandler
|
||||||
ProtocolHandler <|-- AnthropicHandler
|
ProtocolHandler <|-- AnthropicHandler
|
||||||
|
|
||||||
|
|
@ -850,16 +781,16 @@ classDiagram
|
||||||
InferenceScheduler *-- TaskManager
|
InferenceScheduler *-- TaskManager
|
||||||
SamplingPipeline *-- BaseSamplingStrategy
|
SamplingPipeline *-- BaseSamplingStrategy
|
||||||
TrainContextBuilder *-- TrainContext
|
TrainContextBuilder *-- TrainContext
|
||||||
AutoRegressiveLM *-- DecoderBlock
|
Transformer *-- DecoderBlock
|
||||||
AutoRegressiveLM *-- RotaryEmbedding
|
Transformer *-- RotaryEmbedding
|
||||||
AutoRegressiveLM *-- Embedding
|
Transformer *-- Embedding
|
||||||
DecoderBlock *-- RMSNorm
|
DecoderBlock *-- RMSNorm
|
||||||
BaseDataset *-- BaseStorage
|
BaseDataset *-- BaseStorage
|
||||||
ChatCompletionRequest *-- ChatMessage
|
ChatCompletionRequest *-- ChatMessage
|
||||||
MessagesRequest *-- AnthropicMessage
|
MessagesRequest *-- AnthropicMessage
|
||||||
|
|
||||||
%% --- Aggregation (weak ownership) ---
|
%% --- Aggregation (weak ownership) ---
|
||||||
AutoModel o-- BaseModelConfig
|
AutoModel o-- ModelConfig
|
||||||
Trainer o-- TrainCallback
|
Trainer o-- TrainCallback
|
||||||
TrainContext o-- BaseStrategy
|
TrainContext o-- BaseStrategy
|
||||||
TrainContext o-- BaseScheduler
|
TrainContext o-- BaseScheduler
|
||||||
|
|
@ -880,10 +811,6 @@ classDiagram
|
||||||
FFNFactory ..> DeepSeekMoE : creates
|
FFNFactory ..> DeepSeekMoE : creates
|
||||||
DecoderBlock ..> AttnFactory : uses
|
DecoderBlock ..> AttnFactory : uses
|
||||||
DecoderBlock ..> FFNFactory : uses
|
DecoderBlock ..> FFNFactory : uses
|
||||||
StorageFactory ..> H5Storage : creates
|
|
||||||
StorageFactory ..> JSONStorage : creates
|
|
||||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
|
||||||
ConfigFactory ..> EncoderConfig : creates
|
|
||||||
Trainer ..> TrainContextBuilder : uses
|
Trainer ..> TrainContextBuilder : uses
|
||||||
Trainer ..> Functions : spawns
|
Trainer ..> Functions : spawns
|
||||||
TrainContextBuilder ..> StrategyFactory : uses
|
TrainContextBuilder ..> StrategyFactory : uses
|
||||||
|
|
@ -900,13 +827,13 @@ classDiagram
|
||||||
|
|
||||||
%% --- Association (general usage) ---
|
%% --- Association (general usage) ---
|
||||||
Trainer --> TrainConfig
|
Trainer --> TrainConfig
|
||||||
DPOStrategy --> AutoRegressiveLM
|
DPOStrategy --> Transformer
|
||||||
GRPOStrategy --> AutoRegressiveLM
|
GRPOStrategy --> Transformer
|
||||||
InferenceScheduler --> Task
|
InferenceScheduler --> Task
|
||||||
InferenceScheduler --> TaskStatus
|
InferenceScheduler --> TaskStatus
|
||||||
Task --> TaskStatus
|
Task --> TaskStatus
|
||||||
InferenceEngine --> AutoRegressiveLM
|
InferenceEngine --> Transformer
|
||||||
Executor --> AutoRegressiveLM
|
Executor --> Transformer
|
||||||
Executor --> AutoTokenizer
|
Executor --> AutoTokenizer
|
||||||
TaskManager --> AutoTokenizer
|
TaskManager --> AutoTokenizer
|
||||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||||
|
|
@ -919,12 +846,12 @@ classDiagram
|
||||||
|
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–MetricLoggerCallback, CallbackFactory | Training workflow |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
||||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||||
|
|
@ -933,7 +860,7 @@ classDiagram
|
||||||
|
|
||||||
| Pattern | Classes | Purpose |
|
| Pattern | Classes | Purpose |
|
||||||
|---------|---------|---------|
|
|---------|---------|---------|
|
||||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||||
|
|
@ -944,18 +871,18 @@ classDiagram
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading |
|
||||||
|
|
||||||
## Core Relationships
|
## Core Relationships
|
||||||
|
|
||||||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, backed by `KVCache` + `SamplingPipeline`
|
||||||
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||||
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
||||||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-16
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,8 @@ Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or
|
||||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||||
|
|
||||||
```
|
```
|
||||||
StorageFactory.create("h5") → H5Storage
|
create_storage("h5") → H5Storage
|
||||||
StorageFactory.create("json") → JSONStorage
|
create_storage("json") → JSONStorage
|
||||||
```
|
```
|
||||||
|
|
||||||
Both support shared memory via `.share_memory_()`.
|
Both support shared memory via `.share_memory_()`.
|
||||||
|
|
@ -34,7 +34,7 @@ Both support shared memory via `.share_memory_()`.
|
||||||
|
|
||||||
```
|
```
|
||||||
DatasetFactory.load(train_type, path, window_size, stride)
|
DatasetFactory.load(train_type, path, window_size, stride)
|
||||||
→ StorageFactory.create(detect_format(path))
|
→ create_storage(detect_format(path))
|
||||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||||
→ BaseDataset.__getitem__(idx)
|
→ BaseDataset.__getitem__(idx)
|
||||||
→ sliding window [begin, end) via get_index(idx)
|
→ sliding window [begin, end) via get_index(idx)
|
||||||
|
|
@ -54,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride)
|
||||||
|
|
||||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-15
|
||||||
|
|
|
||||||
|
|
@ -137,4 +137,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||||
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
||||||
```
|
```
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-15
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,8 @@
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
| `--adamw_beta1` | AdamW beta1 | 0.95 |
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
| `--adamw_beta2` | AdamW beta2 | 0.99 |
|
||||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||||
|
|
||||||
### Data Loading
|
### Data Loading
|
||||||
|
|
@ -73,7 +73,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--train_type=seq \
|
--train_type=pt \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
--batch_per_device=4 \
|
--batch_per_device=4 \
|
||||||
|
|
@ -81,8 +81,8 @@ nohup python scripts/tools/train.py \
|
||||||
--warmup_ratio=0.05 \
|
--warmup_ratio=0.05 \
|
||||||
--max_lr=1e-4 \
|
--max_lr=1e-4 \
|
||||||
--max_grad_norm=1.0 \
|
--max_grad_norm=1.0 \
|
||||||
--adamw_beta1=0.9 \
|
--adamw_beta1=0.95 \
|
||||||
--adamw_beta2=0.95 \
|
--adamw_beta2=0.99 \
|
||||||
--adamw_weight_decay=0.01 \
|
--adamw_weight_decay=0.01 \
|
||||||
--window_size=2048 \
|
--window_size=2048 \
|
||||||
--ckpt_interval=10000 \
|
--ckpt_interval=10000 \
|
||||||
|
|
@ -94,4 +94,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-16
|
||||||
|
|
@ -91,13 +91,11 @@ on_train_end
|
||||||
|
|
||||||
| Hook | Fires | Default callback |
|
| Hook | Fires | Default callback |
|
||||||
|------|-------|-----------------|
|
|------|-------|-----------------|
|
||||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
|
||||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
| `on_step_end` | Every accumulation window | `ValidationCallback` |
|
|
||||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||||
|
|
||||||
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
||||||
|
|
||||||
## Strategies
|
## Strategies
|
||||||
|
|
||||||
|
|
@ -156,17 +154,6 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
|
|
||||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
||||||
|
|
||||||
## Gradient Checkpointing
|
|
||||||
|
|
||||||
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
|
||||||
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
|
|
||||||
```
|
|
||||||
|
|
||||||
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
|
|
||||||
|
|
||||||
## Checkpoint
|
## Checkpoint
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
@ -201,7 +188,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--train_type=seq \
|
--train_type=pt \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
--batch_per_device=4 \
|
--batch_per_device=4 \
|
||||||
|
|
@ -209,8 +196,8 @@ nohup python scripts/tools/train.py \
|
||||||
--warmup_ratio=0.05 \
|
--warmup_ratio=0.05 \
|
||||||
--max_lr=1e-4 \
|
--max_lr=1e-4 \
|
||||||
--max_grad_norm=1.0 \
|
--max_grad_norm=1.0 \
|
||||||
--adamw_beta1=0.9 \
|
--adamw_beta1=0.95 \
|
||||||
--adamw_beta2=0.95 \
|
--adamw_beta2=0.99 \
|
||||||
--adamw_weight_decay=0.01 \
|
--adamw_weight_decay=0.01 \
|
||||||
--window_size=2048 \
|
--window_size=2048 \
|
||||||
--ckpt_interval=10000 \
|
--ckpt_interval=10000 \
|
||||||
|
|
@ -222,4 +209,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-16
|
||||||
|
|
|
||||||
|
|
@ -39,10 +39,6 @@ class TrainConfig(BaseConfig):
|
||||||
max_grad_norm: float = field(
|
max_grad_norm: float = field(
|
||||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||||
)
|
)
|
||||||
gradient_checkpointing_modules: list = field(
|
|
||||||
default_factory=list,
|
|
||||||
metadata={"help": "Module types to enable activation checkpointing for."},
|
|
||||||
)
|
|
||||||
|
|
||||||
# checkpoint setting
|
# checkpoint setting
|
||||||
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
@ -91,41 +90,6 @@ class GradientClippingCallback(TrainCallback):
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("gradient_checkpointing")
|
|
||||||
class GradientCheckpointingCallback(TrainCallback):
|
|
||||||
"""
|
|
||||||
Activation checkpointing callback — trades compute for memory
|
|
||||||
by recomputing specified module activations during the backward pass.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
modules: Module types to apply checkpointing to.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, modules: Optional[List[type]] = None):
|
|
||||||
self.modules = tuple(modules) if modules else ()
|
|
||||||
|
|
||||||
def _enable(self, module: nn.Module):
|
|
||||||
if self.modules and isinstance(module, self.modules):
|
|
||||||
fn = module.forward
|
|
||||||
module._original_forward = fn
|
|
||||||
module.forward = lambda *a, **kw: torch_checkpoint(
|
|
||||||
fn, *a, use_reentrant=False, **kw
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _disable(module: nn.Module):
|
|
||||||
if hasattr(module, "_original_forward"):
|
|
||||||
module.forward = module._original_forward
|
|
||||||
del module._original_forward
|
|
||||||
|
|
||||||
def on_train_begin(self, context: TrainContext):
|
|
||||||
context.model.apply(self._enable)
|
|
||||||
logger.info("Gradient checkpointing enabled")
|
|
||||||
|
|
||||||
def on_train_end(self, context: TrainContext):
|
|
||||||
context.model.apply(self._disable)
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("checkpoint")
|
@CallbackFactory.register("checkpoint")
|
||||||
class CheckpointCallback(TrainCallback):
|
class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -25,11 +25,7 @@ class Trainer:
|
||||||
|
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
cfg = self.train_config
|
cfg = self.train_config
|
||||||
callbacks = [
|
return [
|
||||||
CallbackFactory.create(
|
|
||||||
"gradient_checkpointing",
|
|
||||||
modules=cfg.gradient_checkpointing_modules,
|
|
||||||
),
|
|
||||||
CallbackFactory.create(
|
CallbackFactory.create(
|
||||||
"checkpoint",
|
"checkpoint",
|
||||||
cfg.ckpt_dir,
|
cfg.ckpt_dir,
|
||||||
|
|
@ -41,7 +37,6 @@ class Trainer:
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
CallbackFactory.create("validation"),
|
CallbackFactory.create("validation"),
|
||||||
]
|
]
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
|
|
|
||||||
|
|
@ -69,14 +69,14 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adamw_beta1",
|
"--adamw_beta1",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.9,
|
default=0.95,
|
||||||
help="Beta1 for AdamW optimizer.",
|
help="Beta values for AdamW optimizer.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adamw_beta2",
|
"--adamw_beta2",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.95,
|
default=0.99,
|
||||||
help="Beta2 for AdamW optimizer.",
|
help="Beta values for AdamW optimizer.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adamw_weight_decay",
|
"--adamw_weight_decay",
|
||||||
|
|
|
||||||
|
|
@ -1,130 +1,11 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
|
||||||
from astrai.trainer.schedule import SchedulerFactory
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
from astrai.trainer.train_callback import GradientCheckpointingCallback, TrainCallback
|
from astrai.trainer.train_callback import TrainCallback
|
||||||
from astrai.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_checkpointing_enable_disable(test_model):
|
|
||||||
"""Enable wraps forward, _disable restores it."""
|
|
||||||
model = test_model["model"]
|
|
||||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
|
||||||
|
|
||||||
originals = [layer.forward for layer in model.layers]
|
|
||||||
|
|
||||||
for layer in model.layers:
|
|
||||||
callback._enable(layer)
|
|
||||||
|
|
||||||
for layer in model.layers:
|
|
||||||
assert hasattr(layer, "_original_forward")
|
|
||||||
assert layer.forward is not originals[0]
|
|
||||||
|
|
||||||
for layer in model.layers:
|
|
||||||
callback._disable(layer)
|
|
||||||
|
|
||||||
for layer in model.layers:
|
|
||||||
assert not hasattr(layer, "_original_forward")
|
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_checkpointing_empty_modules_noop(test_model):
|
|
||||||
"""modules=None should leave forwards untouched."""
|
|
||||||
model = test_model["model"]
|
|
||||||
callback = GradientCheckpointingCallback()
|
|
||||||
|
|
||||||
originals = [layer.forward for layer in model.layers]
|
|
||||||
|
|
||||||
for layer in model.layers:
|
|
||||||
callback._enable(layer)
|
|
||||||
|
|
||||||
for layer, orig in zip(model.layers, originals):
|
|
||||||
assert layer.forward is orig
|
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_checkpointing_forward_unchanged(test_model):
|
|
||||||
"""Forward output unchanged after patching (no_grad)."""
|
|
||||||
model = test_model["model"]
|
|
||||||
device = test_model["device"]
|
|
||||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
|
||||||
|
|
||||||
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
ref = model(input_ids)["logits"].clone()
|
|
||||||
|
|
||||||
for layer in model.layers:
|
|
||||||
callback._enable(layer)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out = model(input_ids)["logits"]
|
|
||||||
|
|
||||||
assert torch.equal(ref, out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_checkpointing_backward(test_model):
|
|
||||||
"""backward passes gradients through checkpointed layers."""
|
|
||||||
model = test_model["model"]
|
|
||||||
device = test_model["device"]
|
|
||||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
|
||||||
|
|
||||||
for layer in model.layers:
|
|
||||||
callback._enable(layer)
|
|
||||||
|
|
||||||
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
|
||||||
target_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
|
||||||
|
|
||||||
logits = model(input_ids)["logits"]
|
|
||||||
loss = torch.nn.functional.cross_entropy(
|
|
||||||
logits.flatten(0, 1).float(), target_ids.flatten()
|
|
||||||
)
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if param.requires_grad:
|
|
||||||
assert param.grad is not None, f"{name} gradient is None"
|
|
||||||
|
|
||||||
for layer in model.layers:
|
|
||||||
callback._disable(layer)
|
|
||||||
|
|
||||||
model.zero_grad()
|
|
||||||
for name, p in model.named_parameters():
|
|
||||||
assert p.grad is None or p.grad.sum().item() == 0, f"{name} grad not zeroed"
|
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_checkpointing_trainer_integration(base_test_env, random_dataset):
|
|
||||||
"""Gradient checkpointing runs end-to-end via Trainer."""
|
|
||||||
|
|
||||||
def optimizer_fn(model):
|
|
||||||
return torch.optim.AdamW(model.parameters())
|
|
||||||
|
|
||||||
def scheduler_fn(optim):
|
|
||||||
return SchedulerFactory.create(
|
|
||||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
|
||||||
)
|
|
||||||
|
|
||||||
train_config = TrainConfig(
|
|
||||||
model=base_test_env["model"],
|
|
||||||
strategy="seq",
|
|
||||||
dataset=random_dataset,
|
|
||||||
optimizer_fn=optimizer_fn,
|
|
||||||
scheduler_fn=scheduler_fn,
|
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
|
||||||
n_epoch=1,
|
|
||||||
batch_per_device=2,
|
|
||||||
ckpt_interval=3,
|
|
||||||
grad_accum_steps=1,
|
|
||||||
max_grad_norm=1.0,
|
|
||||||
random_seed=42,
|
|
||||||
device_type=base_test_env["device"],
|
|
||||||
gradient_checkpointing_modules=[DecoderBlock],
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
|
||||||
trainer.train()
|
|
||||||
# no crash = callback correctly enabled/disabled
|
|
||||||
|
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
"""Test that all callbacks are properly integrated"""
|
"""Test that all callbacks are properly integrated"""
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue