Compare commits
No commits in common. "ad9f4d9cf60f35cf742509b8096c7b541252c5be" and "3d12a03909c6dedc6de112a4f53e3ecd1d1a2068" have entirely different histories.
ad9f4d9cf6
...
3d12a03909
30
README.md
30
README.md
|
|
@ -78,27 +78,15 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
|
||||||
#### Train a Model
|
#### Train a Model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||||
|
--train_type seq \
|
||||||
nohup python scripts/tools/train.py \
|
--data_root_path /path/to/dataset \
|
||||||
--nprocs=4 \
|
--param_path /path/to/model \
|
||||||
--train_type=pt \
|
--batch_size 4 \
|
||||||
--data_root_path=/path/to/dataset \
|
--accumulation_steps 8 \
|
||||||
--param_path=/path/to/model \
|
--max_lr 3e-4 \
|
||||||
--batch_per_device=4 \
|
--warmup_steps 1000 \
|
||||||
--grad_accum_steps=8 \
|
--n_epoch 1
|
||||||
--warmup_ratio=0.05 \
|
|
||||||
--max_lr=1e-4 \
|
|
||||||
--max_grad_norm=1.0 \
|
|
||||||
--adamw_beta1=0.95 \
|
|
||||||
--adamw_beta2=0.99 \
|
|
||||||
--adamw_weight_decay=0.01 \
|
|
||||||
--window_size=2048 \
|
|
||||||
--ckpt_interval=10000 \
|
|
||||||
--ckpt_dir=./checkpoint \
|
|
||||||
--random_seed=3407 \
|
|
||||||
--label_smoothing=0.05 \
|
|
||||||
> out.log 2> err.log &
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Full reference at [Parameter Guide](assets/docs/params.md).
|
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||||
|
|
|
||||||
|
|
@ -84,27 +84,15 @@ python scripts/demo/download.py
|
||||||
#### 训练模型
|
#### 训练模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||||
|
--train_type seq \
|
||||||
nohup python scripts/tools/train.py \
|
--data_root_path /path/to/dataset \
|
||||||
--nprocs=4 \
|
--param_path /path/to/model \
|
||||||
--train_type=pt \
|
--batch_size 4 \
|
||||||
--data_root_path=/path/to/dataset \
|
--accumulation_steps 8 \
|
||||||
--param_path=/path/to/model \
|
--max_lr 3e-4 \
|
||||||
--batch_per_device=4 \
|
--warmup_steps 1000 \
|
||||||
--grad_accum_steps=8 \
|
--n_epoch 1
|
||||||
--warmup_ratio=0.05 \
|
|
||||||
--max_lr=1e-4 \
|
|
||||||
--max_grad_norm=1.0 \
|
|
||||||
--adamw_beta1=0.95 \
|
|
||||||
--adamw_beta2=0.99 \
|
|
||||||
--adamw_weight_decay=0.01 \
|
|
||||||
--window_size=2048 \
|
|
||||||
--ckpt_interval=10000 \
|
|
||||||
--ckpt_dir=./checkpoint \
|
|
||||||
--random_seed=3407 \
|
|
||||||
--label_smoothing=0.05 \
|
|
||||||
> out.log 2> err.log &
|
|
||||||
```
|
```
|
||||||
|
|
||||||
完整参数列表见[参数说明](./params.md)。
|
完整参数列表见[参数说明](./params.md)。
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,10 @@
|
||||||
```mermaid
|
```mermaid
|
||||||
classDiagram
|
classDiagram
|
||||||
namespace config {
|
namespace config {
|
||||||
class BaseConfig {
|
|
||||||
+to_dict() Dict
|
|
||||||
+from_dict(d) Self
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseModelConfig {
|
class BaseModelConfig {
|
||||||
+Optional[str] model_type
|
+Optional[str] model_type
|
||||||
+from_file(config_path) Self
|
+load(config_path) Self
|
||||||
+to_file(config_path)
|
+save(config_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
class ModelConfig {
|
class ModelConfig {
|
||||||
|
|
@ -35,9 +30,6 @@ classDiagram
|
||||||
+int n_shared_experts
|
+int n_shared_experts
|
||||||
+int n_activated_experts
|
+int n_activated_experts
|
||||||
+str moe_topk_method
|
+str moe_topk_method
|
||||||
+Optional[int] kv_lora_rank
|
|
||||||
+Optional[int] qk_nope_head_dim
|
|
||||||
+Optional[int] qk_rope_head_dim
|
|
||||||
+load(config_path) ModelConfig
|
+load(config_path) ModelConfig
|
||||||
+save(config_path)
|
+save(config_path)
|
||||||
}
|
}
|
||||||
|
|
@ -49,8 +41,8 @@ classDiagram
|
||||||
+Callable optimizer_fn
|
+Callable optimizer_fn
|
||||||
+Callable scheduler_fn
|
+Callable scheduler_fn
|
||||||
+int n_epoch
|
+int n_epoch
|
||||||
+int batch_per_device
|
+int batch_size
|
||||||
+int grad_accum_steps
|
+int accumulation_steps
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
+int start_epoch
|
+int start_epoch
|
||||||
+int start_batch
|
+int start_batch
|
||||||
|
|
@ -77,7 +69,7 @@ classDiagram
|
||||||
class BaseDataset {
|
class BaseDataset {
|
||||||
+int window_size
|
+int window_size
|
||||||
+int stride
|
+int stride
|
||||||
+Optional[BaseStorage] storage
|
+BaseStorage storage
|
||||||
+load(load_path, storage_type, tokenizer)
|
+load(load_path, storage_type, tokenizer)
|
||||||
+__getitem__(index)
|
+__getitem__(index)
|
||||||
+__len__()
|
+__len__()
|
||||||
|
|
@ -134,8 +126,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
class ResumableDistributedSampler {
|
||||||
+int epoch
|
+int start_epoch
|
||||||
+int iter
|
+int start_iter
|
||||||
}
|
}
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
|
|
@ -152,7 +144,6 @@ classDiagram
|
||||||
+int epoch
|
+int epoch
|
||||||
+int iteration
|
+int iteration
|
||||||
+dict extra
|
+dict extra
|
||||||
+dict meta
|
|
||||||
+save(save_dir)
|
+save(save_dir)
|
||||||
+load(save_dir) Checkpoint
|
+load(save_dir) Checkpoint
|
||||||
}
|
}
|
||||||
|
|
@ -164,7 +155,7 @@ classDiagram
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(model_type) decorator
|
+register(model_type) decorator
|
||||||
+get_component_class(model_type) Type
|
+get_component_class(model_type) Type
|
||||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
+from_pretrained(path, disable_random_init) nn.Module
|
||||||
+save_pretrained(save_directory)
|
+save_pretrained(save_directory)
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
}
|
}
|
||||||
|
|
@ -176,7 +167,7 @@ classDiagram
|
||||||
+ModuleList layers
|
+ModuleList layers
|
||||||
+RMSNorm norm
|
+RMSNorm norm
|
||||||
+Linear lm_head
|
+Linear lm_head
|
||||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
|
||||||
+load_state_dict(state_dict)
|
+load_state_dict(state_dict)
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
@ -194,7 +185,6 @@ classDiagram
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
+int n_rep
|
+int n_rep
|
||||||
+int layer_id
|
|
||||||
+bool use_qk_norm
|
+bool use_qk_norm
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
+Linear q_proj, k_proj, v_proj, o_proj
|
+Linear q_proj, k_proj, v_proj, o_proj
|
||||||
|
|
@ -211,7 +201,6 @@ classDiagram
|
||||||
+int qk_nope_head_dim
|
+int qk_nope_head_dim
|
||||||
+int qk_rope_head_dim
|
+int qk_rope_head_dim
|
||||||
+int n_rep
|
+int n_rep
|
||||||
+int layer_id
|
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||||
+Linear o_proj
|
+Linear o_proj
|
||||||
|
|
@ -226,7 +215,6 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class DeepSeekMoE {
|
class DeepSeekMoE {
|
||||||
+int dim
|
|
||||||
+int n_routed_experts
|
+int n_routed_experts
|
||||||
+int n_shared_experts
|
+int n_shared_experts
|
||||||
+int n_activated_experts
|
+int n_activated_experts
|
||||||
|
|
@ -248,7 +236,6 @@ classDiagram
|
||||||
class RMSNorm {
|
class RMSNorm {
|
||||||
+Parameter weight
|
+Parameter weight
|
||||||
+float norm_eps
|
+float norm_eps
|
||||||
+tuple normalized_shape
|
|
||||||
+forward(x) Tensor
|
+forward(x) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -312,6 +299,7 @@ classDiagram
|
||||||
+TrainConfig train_config
|
+TrainConfig train_config
|
||||||
+List[TrainCallback] callbacks
|
+List[TrainCallback] callbacks
|
||||||
+train(checkpoint)
|
+train(checkpoint)
|
||||||
|
+_build_context(checkpoint) TrainContext
|
||||||
+_get_default_callbacks() List[TrainCallback]
|
+_get_default_callbacks() List[TrainCallback]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -336,7 +324,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseStrategy {
|
class BaseStrategy {
|
||||||
+Union[Callable, nn.Module] model
|
+nn.Module model
|
||||||
+str device
|
+str device
|
||||||
+compute_loss(batch) Tensor
|
+compute_loss(batch) Tensor
|
||||||
}
|
}
|
||||||
|
|
@ -344,7 +332,7 @@ classDiagram
|
||||||
class StrategyFactory {
|
class StrategyFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(train_type, model, device, **kwargs) BaseStrategy
|
+create(model, train_type, device, **kwargs) BaseStrategy
|
||||||
}
|
}
|
||||||
|
|
||||||
class SEQStrategy {
|
class SEQStrategy {
|
||||||
|
|
@ -412,7 +400,7 @@ classDiagram
|
||||||
|
|
||||||
class GradientClippingCallback {
|
class GradientClippingCallback {
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
+on_step_begin(context)
|
+on_step_end(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
|
|
@ -471,7 +459,10 @@ classDiagram
|
||||||
+TaskManager _task_mgr
|
+TaskManager _task_mgr
|
||||||
+bool _running
|
+bool _running
|
||||||
+Thread _loop_thread
|
+Thread _loop_thread
|
||||||
|
+int max_batch_size
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
|
+int max_prompt_len
|
||||||
|
+int page_size
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
+remove_task(task_id)
|
+remove_task(task_id)
|
||||||
+start()
|
+start()
|
||||||
|
|
@ -509,7 +500,10 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class Storage {
|
class Storage {
|
||||||
|
+int n_layers
|
||||||
+int page_size
|
+int page_size
|
||||||
|
+int head_dim
|
||||||
|
+int n_kv_heads
|
||||||
+Tensor k_cache
|
+Tensor k_cache
|
||||||
+Tensor v_cache
|
+Tensor v_cache
|
||||||
+write(layer_id, page_table, start_pos, k, v)
|
+write(layer_id, page_table, start_pos, k, v)
|
||||||
|
|
@ -681,6 +675,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class AnthropicHandler {
|
class AnthropicHandler {
|
||||||
|
+List[str] stop_sequences
|
||||||
+build_prompt() str
|
+build_prompt() str
|
||||||
+create_response_id() str
|
+create_response_id() str
|
||||||
+on_token(ctx, token, stop_checker) Optional[str]
|
+on_token(ctx, token, stop_checker) Optional[str]
|
||||||
|
|
@ -709,7 +704,7 @@ classDiagram
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class Functions {
|
class Functions {
|
||||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
|
+spawn_parallel_fn(fn, nprocs)
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||||
+get_current_device() str
|
+get_current_device() str
|
||||||
+get_world_size() int
|
+get_world_size() int
|
||||||
|
|
@ -756,8 +751,6 @@ classDiagram
|
||||||
ParallelModel <|-- RowParallelLinear
|
ParallelModel <|-- RowParallelLinear
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
ParallelModel <|-- ColumnParallelLinear
|
||||||
AutoModel <|-- Transformer
|
AutoModel <|-- Transformer
|
||||||
BaseConfig <|-- BaseModelConfig
|
|
||||||
BaseConfig <|-- TrainConfig
|
|
||||||
BaseModelConfig <|-- ModelConfig
|
BaseModelConfig <|-- ModelConfig
|
||||||
BaseFactory <|-- AutoModel
|
BaseFactory <|-- AutoModel
|
||||||
BaseFactory <|-- AttnFactory
|
BaseFactory <|-- AttnFactory
|
||||||
|
|
@ -846,7 +839,7 @@ classDiagram
|
||||||
|
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
|
|
@ -885,4 +878,4 @@ classDiagram
|
||||||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
|
|
||||||
> Document Update Time: 2026-05-16
|
> Document Update Time: 2026-05-15
|
||||||
|
|
|
||||||
|
|
@ -10,14 +10,14 @@
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
| `--param_path` | Model parameters or checkpoint path | required |
|
| `--param_path` | Model parameters or checkpoint path | required |
|
||||||
| `--n_epoch` | Total training epochs | 1 |
|
| `--n_epoch` | Total training epochs | 1 |
|
||||||
| `--batch_per_device` | Batch size per device | 1 |
|
| `--batch_size` | Batch size | 1 |
|
||||||
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||||
|
|
||||||
### Learning Rate Scheduling
|
### Learning Rate Scheduling
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
|
| `--warmup_steps` | Warmup steps | 1000 |
|
||||||
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||||
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||||
|
|
||||||
|
|
@ -25,8 +25,8 @@
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.95 |
|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.99 |
|
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||||
|
|
||||||
### Data Loading
|
### Data Loading
|
||||||
|
|
@ -60,7 +60,7 @@
|
||||||
| Parameter | Description | Default | Used by |
|
| Parameter | Description | Default | Used by |
|
||||||
|-----------|-------------|---------|---------|
|
|-----------|-------------|---------|---------|
|
||||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
|
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` |
|
||||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||||
|
|
@ -69,29 +69,90 @@
|
||||||
### Usage Example
|
### Usage Example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
python scripts/tools/train.py \
|
||||||
|
--train_type seq \
|
||||||
nohup python scripts/tools/train.py \
|
--data_root_path /path/to/dataset \
|
||||||
--nprocs=4 \
|
--param_path /path/to/model \
|
||||||
--train_type=pt \
|
--n_epoch 3 \
|
||||||
--data_root_path=/path/to/dataset \
|
--batch_size 4 \
|
||||||
--param_path=/path/to/model \
|
--accumulation_steps 8 \
|
||||||
--batch_per_device=4 \
|
--max_lr 3e-4 \
|
||||||
--grad_accum_steps=8 \
|
--warmup_steps 2000 \
|
||||||
--warmup_ratio=0.05 \
|
--max_grad_norm 1.0 \
|
||||||
--max_lr=1e-4 \
|
--ckpt_interval 5000 \
|
||||||
--max_grad_norm=1.0 \
|
--ckpt_dir ./checkpoints \
|
||||||
--adamw_beta1=0.95 \
|
--num_workers 4 \
|
||||||
--adamw_beta2=0.99 \
|
--nprocs 1 \
|
||||||
--adamw_weight_decay=0.01 \
|
--device_type cuda
|
||||||
--window_size=2048 \
|
|
||||||
--ckpt_interval=10000 \
|
|
||||||
--ckpt_dir=./checkpoint \
|
|
||||||
--random_seed=3407 \
|
|
||||||
--label_smoothing=0.05 \
|
|
||||||
> out.log 2> err.log &
|
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
> Document Update Time: 2026-05-16
|
## Generation Parameters
|
||||||
|
|
||||||
|
### GenerationRequest Parameters
|
||||||
|
|
||||||
|
| Parameter | Description | Default Value |
|
||||||
|
|-----------|-------------|---------------|
|
||||||
|
| `messages` | List of message dictionaries (role, content) | required |
|
||||||
|
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
||||||
|
| `top_p` | Nucleus sampling threshold | 1.0 |
|
||||||
|
| `top_k` | Top-k sampling count | 50 |
|
||||||
|
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
|
||||||
|
| `stream` | Whether to stream output | False |
|
||||||
|
|
||||||
|
### Usage Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
from astrai.inference import InferenceEngine, GenerationRequest
|
||||||
|
|
||||||
|
# Load model using AutoModel
|
||||||
|
model = AutoModel.from_pretrained("your_model_dir")
|
||||||
|
|
||||||
|
# Load tokenizer
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
|
||||||
|
|
||||||
|
# Create engine with separate model and tokenizer
|
||||||
|
engine = InferenceEngine(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build request with messages format
|
||||||
|
request = GenerationRequest(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
],
|
||||||
|
temperature=0.8,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=50,
|
||||||
|
max_tokens=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate (streaming)
|
||||||
|
for token in engine.generate_with_request(request):
|
||||||
|
print(token, end="", flush=True)
|
||||||
|
|
||||||
|
# Or use simple generate interface
|
||||||
|
result = engine.generate(
|
||||||
|
prompt="Hello",
|
||||||
|
stream=False,
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.8,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=50,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Generation Modes
|
||||||
|
|
||||||
|
| Mode | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `stream=True` | Streaming output, yields token by token |
|
||||||
|
| `stream=False` | Non-streaming output, returns complete result |
|
||||||
|
|
||||||
|
> Document Update Time: 2026-05-15
|
||||||
|
|
@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi
|
||||||
|
|
||||||
## Training Loop
|
## Training Loop
|
||||||
|
|
||||||
Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches.
|
Nested loop: **epoch** → **step** (accumulation window) → **batch**.
|
||||||
|
|
||||||
```
|
```
|
||||||
on_train_begin
|
on_train_begin
|
||||||
on_epoch_begin
|
on_epoch_begin
|
||||||
for batch in dataloader:
|
for steps in batched(dataloader, accumulation_steps):
|
||||||
on_batch_begin
|
on_step_begin
|
||||||
loss = strategy(batch)
|
step_batch_nums = len(steps)
|
||||||
(loss / grad_accum_steps).backward()
|
for batch in steps:
|
||||||
iteration += 1
|
on_batch_begin
|
||||||
on_batch_end
|
loss = strategy(batch)
|
||||||
|
(loss / step_batch_nums).backward()
|
||||||
if iteration % grad_accum_steps == 0:
|
iteration += 1
|
||||||
on_step_begin
|
on_batch_end
|
||||||
optimizer.step()
|
on_step_end
|
||||||
optimizer.zero_grad()
|
optimizer.step()
|
||||||
on_step_end
|
optimizer.zero_grad()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
```
|
```
|
||||||
|
|
@ -91,9 +91,9 @@ on_train_end
|
||||||
|
|
||||||
| Hook | Fires | Default callback |
|
| Hook | Fires | Default callback |
|
||||||
|------|-------|-----------------|
|
|------|-------|-----------------|
|
||||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
| `on_step_end` | Every accumulation window | `GradientClippingCallback` |
|
||||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
| `on_train_end` | Training ends | `CheckpointCallback` (final save) |
|
||||||
|
|
||||||
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
||||||
|
|
||||||
|
|
@ -157,13 +157,12 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
||||||
## Checkpoint
|
## Checkpoint
|
||||||
|
|
||||||
```
|
```
|
||||||
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
Checkpoint(state_dict, epoch, iteration, extra)
|
||||||
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
|
├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt
|
||||||
└── load(save_dir) broadcasts metadata from rank-0
|
└── load(save_dir) broadcasts metadata from rank-0
|
||||||
```
|
```
|
||||||
|
|
||||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data.
|
||||||
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
|
|
||||||
|
|
||||||
## TrainContextBuilder (Builder Pattern)
|
## TrainContextBuilder (Builder Pattern)
|
||||||
|
|
||||||
|
|
@ -184,29 +183,17 @@ context = (
|
||||||
## Training CLI
|
## Training CLI
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
python scripts/tools/train.py \
|
||||||
|
--train_type seq \
|
||||||
nohup python scripts/tools/train.py \
|
--data_root_path /path/to/data \
|
||||||
--nprocs=4 \
|
--param_path /path/to/model \
|
||||||
--train_type=pt \
|
--batch_size 4 \
|
||||||
--data_root_path=/path/to/dataset \
|
--accumulation_steps 8 \
|
||||||
--param_path=/path/to/model \
|
--max_lr 3e-4 \
|
||||||
--batch_per_device=4 \
|
--warmup_steps 1000 \
|
||||||
--grad_accum_steps=8 \
|
--n_epoch 1
|
||||||
--warmup_ratio=0.05 \
|
|
||||||
--max_lr=1e-4 \
|
|
||||||
--max_grad_norm=1.0 \
|
|
||||||
--adamw_beta1=0.95 \
|
|
||||||
--adamw_beta2=0.99 \
|
|
||||||
--adamw_weight_decay=0.01 \
|
|
||||||
--window_size=2048 \
|
|
||||||
--ckpt_interval=10000 \
|
|
||||||
--ckpt_dir=./checkpoint \
|
|
||||||
--random_seed=3407 \
|
|
||||||
--label_smoothing=0.05 \
|
|
||||||
> out.log 2> err.log &
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-16
|
> Document Update Time: 2026-05-15
|
||||||
|
|
|
||||||
|
|
@ -1,77 +0,0 @@
|
||||||
import json
|
|
||||||
from dataclasses import MISSING, dataclass, fields
|
|
||||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseConfig:
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
d = {}
|
|
||||||
for fld in fields(self):
|
|
||||||
v = getattr(self, fld.name)
|
|
||||||
if isinstance(v, (str, int, float, bool)):
|
|
||||||
d[fld.name] = v
|
|
||||||
elif v is None:
|
|
||||||
d[fld.name] = None
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
try:
|
|
||||||
json.dumps(v)
|
|
||||||
d[fld.name] = v
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
pass
|
|
||||||
return d
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, d: Dict[str, Any]) -> Self:
|
|
||||||
hints = get_type_hints(cls)
|
|
||||||
inst = cls.__new__(cls)
|
|
||||||
for fld in fields(cls):
|
|
||||||
if fld.name in d:
|
|
||||||
v = d[fld.name]
|
|
||||||
target = cls._unwrap_optional(hints.get(fld.name))
|
|
||||||
if target is not None:
|
|
||||||
try:
|
|
||||||
v = cls._coerce(v, target)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
pass
|
|
||||||
object.__setattr__(inst, fld.name, v)
|
|
||||||
elif fld.default is not MISSING:
|
|
||||||
object.__setattr__(inst, fld.name, fld.default)
|
|
||||||
elif fld.default_factory is not MISSING:
|
|
||||||
object.__setattr__(inst, fld.name, fld.default_factory())
|
|
||||||
else:
|
|
||||||
object.__setattr__(inst, fld.name, None)
|
|
||||||
return inst
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unwrap_optional(tp) -> Optional[type]:
|
|
||||||
if tp is None:
|
|
||||||
return None
|
|
||||||
origin = getattr(tp, "__origin__", None)
|
|
||||||
if origin is not None:
|
|
||||||
args = getattr(tp, "__args__", ())
|
|
||||||
non_none = [a for a in args if a is not type(None)]
|
|
||||||
return non_none[0] if non_none else None
|
|
||||||
return tp
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _coerce(value: Any, target_type: type) -> Any:
|
|
||||||
if target_type is bool and isinstance(value, bool):
|
|
||||||
return value
|
|
||||||
if (
|
|
||||||
target_type is int
|
|
||||||
and isinstance(value, (int, float))
|
|
||||||
and not isinstance(value, bool)
|
|
||||||
):
|
|
||||||
return int(value)
|
|
||||||
if (
|
|
||||||
target_type is float
|
|
||||||
and isinstance(value, (int, float))
|
|
||||||
and not isinstance(value, bool)
|
|
||||||
):
|
|
||||||
return float(value)
|
|
||||||
if target_type is str and isinstance(value, str):
|
|
||||||
return value
|
|
||||||
if isinstance(value, target_type):
|
|
||||||
return value
|
|
||||||
raise TypeError
|
|
||||||
|
|
@ -1,14 +1,12 @@
|
||||||
import json
|
import json
|
||||||
import warnings
|
import sys
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from typing import Any, Dict, Optional, Self
|
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseModelConfig(BaseConfig):
|
class BaseModelConfig:
|
||||||
"""Field-aware JSON from/to file for dataclass configs.
|
"""Field-aware JSON load/save for dataclass configs.
|
||||||
|
|
||||||
Subclass with additional fields. The base ``model_type`` field
|
Subclass with additional fields. The base ``model_type`` field
|
||||||
enables ``AutoModel`` to pick the correct subclass.
|
enables ``AutoModel`` to pick the correct subclass.
|
||||||
|
|
@ -16,25 +14,76 @@ class BaseModelConfig(BaseConfig):
|
||||||
|
|
||||||
model_type: Optional[str] = None
|
model_type: Optional[str] = None
|
||||||
|
|
||||||
@classmethod
|
def load(self, config_path: str) -> Self:
|
||||||
def from_file(cls, config_path: str) -> Self:
|
raw: Dict[str, Any] = {}
|
||||||
with open(config_path, "r") as f:
|
with open(config_path, "r") as f:
|
||||||
raw: Dict[str, Any] = json.load(f)
|
raw.update(json.load(f))
|
||||||
|
|
||||||
valid = {fld.name for fld in fields(cls)}
|
hints = get_type_hints(type(self))
|
||||||
for key in list(raw):
|
valid = {fld.name for fld in fields(self)}
|
||||||
|
for key, value in raw.items():
|
||||||
if key not in valid:
|
if key not in valid:
|
||||||
warnings.warn(f"Unknown config key '{key}'")
|
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
|
||||||
del raw[key]
|
continue
|
||||||
|
|
||||||
return cls.from_dict(raw)
|
target_type = self._unwrap_optional(hints.get(key))
|
||||||
|
if target_type is None:
|
||||||
|
continue
|
||||||
|
|
||||||
def to_file(self, config_path: str):
|
try:
|
||||||
d = self.to_dict()
|
value = self._coerce(value, target_type)
|
||||||
config_dict = {k: v for k, v in d.items() if v is not None}
|
except (TypeError, ValueError):
|
||||||
|
sys.stderr.write(
|
||||||
|
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def save(self, config_path: str):
|
||||||
|
config_dict: Dict[str, Any] = {}
|
||||||
|
for fld in fields(self):
|
||||||
|
v = getattr(self, fld.name)
|
||||||
|
if v is not None:
|
||||||
|
config_dict[fld.name] = v
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_dict, f, indent=4)
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unwrap_optional(tp: type) -> Optional[type]:
|
||||||
|
if tp is None:
|
||||||
|
return None
|
||||||
|
origin = getattr(tp, "__origin__", None)
|
||||||
|
if origin is not None:
|
||||||
|
args = getattr(tp, "__args__", ())
|
||||||
|
non_none = [a for a in args if a is not type(None)]
|
||||||
|
return non_none[0] if non_none else None
|
||||||
|
return tp
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce(value: Any, target_type: type) -> Any:
|
||||||
|
if target_type is bool and isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if (
|
||||||
|
target_type is int
|
||||||
|
and isinstance(value, (int, float))
|
||||||
|
and not isinstance(value, bool)
|
||||||
|
):
|
||||||
|
return int(value)
|
||||||
|
if (
|
||||||
|
target_type is float
|
||||||
|
and isinstance(value, (int, float))
|
||||||
|
and not isinstance(value, bool)
|
||||||
|
):
|
||||||
|
return float(value)
|
||||||
|
if target_type is str and isinstance(value, str):
|
||||||
|
return value
|
||||||
|
if isinstance(value, target_type):
|
||||||
|
return value
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig(BaseModelConfig):
|
class ModelConfig(BaseModelConfig):
|
||||||
|
|
@ -57,11 +106,6 @@ class ModelConfig(BaseModelConfig):
|
||||||
use_qk_norm: Optional[bool] = None
|
use_qk_norm: Optional[bool] = None
|
||||||
use_gated_attention: Optional[bool] = None
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
# MLA
|
|
||||||
kv_lora_rank: Optional[int] = None
|
|
||||||
qk_nope_head_dim: Optional[int] = None
|
|
||||||
qk_rope_head_dim: Optional[int] = None
|
|
||||||
|
|
||||||
# MoE
|
# MoE
|
||||||
ffn_type: str = "mlp"
|
ffn_type: str = "mlp"
|
||||||
n_routed_experts: Optional[int] = None
|
n_routed_experts: Optional[int] = None
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,9 @@ from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig(BaseConfig):
|
class TrainConfig:
|
||||||
# basic setting
|
# basic setting
|
||||||
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
||||||
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
||||||
|
|
@ -22,10 +20,8 @@ class TrainConfig(BaseConfig):
|
||||||
default=None, metadata={"help": "Scheduler factory for training."}
|
default=None, metadata={"help": "Scheduler factory for training."}
|
||||||
)
|
)
|
||||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||||
batch_per_device: int = field(
|
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
|
||||||
default=4, metadata={"help": "Batch size per device."}
|
accumulation_steps: int = field(
|
||||||
)
|
|
||||||
grad_accum_steps: int = field(
|
|
||||||
default=1, metadata={"help": "Number of iterations between steps."}
|
default=1, metadata={"help": "Number of iterations between steps."}
|
||||||
)
|
)
|
||||||
max_grad_norm: float = field(
|
max_grad_norm: float = field(
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,8 @@ from astrai.dataset.storage import (
|
||||||
H5Storage,
|
H5Storage,
|
||||||
JSONStorage,
|
JSONStorage,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
StorageFactory,
|
available_storage_types,
|
||||||
|
create_storage,
|
||||||
detect_format,
|
detect_format,
|
||||||
load_h5,
|
load_h5,
|
||||||
load_json,
|
load_json,
|
||||||
|
|
@ -25,8 +26,9 @@ __all__ = [
|
||||||
"BaseStorage",
|
"BaseStorage",
|
||||||
"H5Storage",
|
"H5Storage",
|
||||||
"JSONStorage",
|
"JSONStorage",
|
||||||
"StorageFactory",
|
"create_storage",
|
||||||
"detect_format",
|
"detect_format",
|
||||||
|
"available_storage_types",
|
||||||
"save_h5",
|
"save_h5",
|
||||||
"load_h5",
|
"load_h5",
|
||||||
"save_json",
|
"save_json",
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
BaseStorage,
|
BaseStorage,
|
||||||
StorageFactory,
|
create_storage,
|
||||||
detect_format,
|
detect_format,
|
||||||
)
|
)
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
@ -42,7 +42,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
"""
|
"""
|
||||||
if storage_type is None:
|
if storage_type is None:
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = StorageFactory.create(storage_type)
|
self.storage = create_storage(storage_type)
|
||||||
self.storage.load(load_path, tokenizer=tokenizer)
|
self.storage.load(load_path, tokenizer=tokenizer)
|
||||||
|
|
||||||
def load_json(self, load_path: str, tokenizer=None):
|
def load_json(self, load_path: str, tokenizer=None):
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,6 @@ import h5py
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
|
|
@ -260,24 +258,6 @@ class BaseStorage(ABC):
|
||||||
return self._fetcher.multi_keys
|
return self._fetcher.multi_keys
|
||||||
|
|
||||||
|
|
||||||
class StorageFactory(BaseFactory["BaseStorage"]):
|
|
||||||
"""Factory for creating storage backends by type name.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
@StorageFactory.register("custom")
|
|
||||||
class CustomStorage(BaseStorage):
|
|
||||||
...
|
|
||||||
|
|
||||||
storage = StorageFactory.create("custom")
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, storage_cls: type) -> None:
|
|
||||||
if not issubclass(storage_cls, BaseStorage):
|
|
||||||
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
|
||||||
|
|
||||||
|
|
||||||
@StorageFactory.register("h5")
|
|
||||||
class H5Storage(BaseStorage):
|
class H5Storage(BaseStorage):
|
||||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
|
|
@ -286,7 +266,6 @@ class H5Storage(BaseStorage):
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
|
|
||||||
|
|
||||||
@StorageFactory.register("json")
|
|
||||||
class JSONStorage(BaseStorage):
|
class JSONStorage(BaseStorage):
|
||||||
"""JSON-based storage backend.
|
"""JSON-based storage backend.
|
||||||
|
|
||||||
|
|
@ -299,3 +278,35 @@ class JSONStorage(BaseStorage):
|
||||||
def load(self, load_path: str, tokenizer=None) -> None:
|
def load(self, load_path: str, tokenizer=None) -> None:
|
||||||
segments = load_json(load_path, tokenizer=tokenizer)
|
segments = load_json(load_path, tokenizer=tokenizer)
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
|
|
||||||
|
|
||||||
|
_STORAGE_REGISTRY: Dict[str, type] = {
|
||||||
|
"h5": H5Storage,
|
||||||
|
"json": JSONStorage,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_storage(storage_type: str) -> BaseStorage:
|
||||||
|
"""Create a storage instance by type name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_type: Storage type name ("h5", "json")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the storage type is unknown
|
||||||
|
"""
|
||||||
|
storage_cls = _STORAGE_REGISTRY.get(storage_type)
|
||||||
|
if storage_cls is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown storage type: '{storage_type}'. "
|
||||||
|
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
|
||||||
|
)
|
||||||
|
return storage_cls()
|
||||||
|
|
||||||
|
|
||||||
|
def available_storage_types() -> List[str]:
|
||||||
|
"""Return list of registered storage type names."""
|
||||||
|
return sorted(_STORAGE_REGISTRY.keys())
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
"""Base factory class for extensible component registration."""
|
"""Base factory class for extensible component registration."""
|
||||||
|
|
||||||
import inspect
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
|
|
@ -123,10 +122,6 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
def create(cls, name: str, *args, **kwargs) -> T:
|
def create(cls, name: str, *args, **kwargs) -> T:
|
||||||
"""Create a component instance by name.
|
"""Create a component instance by name.
|
||||||
|
|
||||||
Filters kwargs to match the component's __init__ signature,
|
|
||||||
so components don't need to declare **kwargs just to absorb
|
|
||||||
parameters meant for other components.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Registered name of the component
|
name: Registered name of the component
|
||||||
*args: Positional arguments passed to component constructor
|
*args: Positional arguments passed to component constructor
|
||||||
|
|
@ -144,17 +139,6 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||||
)
|
)
|
||||||
component_cls = cls._registry.get(name)
|
component_cls = cls._registry.get(name)
|
||||||
sig = inspect.signature(component_cls.__init__)
|
|
||||||
has_var_kwargs = any(
|
|
||||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
|
||||||
)
|
|
||||||
if not has_var_kwargs:
|
|
||||||
valid = {
|
|
||||||
p.name
|
|
||||||
for p in sig.parameters.values()
|
|
||||||
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
|
|
||||||
}
|
|
||||||
kwargs = {k: v for k, v in kwargs.items() if k in valid}
|
|
||||||
return component_cls(*args, **kwargs)
|
return component_cls(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -163,5 +163,4 @@ def run_server(
|
||||||
app,
|
app,
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
reload=reload,
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -22,22 +22,14 @@ class InferenceScheduler:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prompt_len: int = 2048,
|
max_prompt_len: int = 512,
|
||||||
page_size: int = 64,
|
page_size: int = 64,
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
config = model.config
|
config = model.config
|
||||||
|
|
||||||
if max_seq_len is not None:
|
self.max_seq_len = max_seq_len or config.max_len
|
||||||
self.max_seq_len = max_seq_len
|
|
||||||
elif config.max_len is not None:
|
|
||||||
self.max_seq_len = config.max_len
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"max_seq_len must be provided either as argument "
|
|
||||||
"or in model config (config.max_len)"
|
|
||||||
)
|
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,9 +60,10 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
model_path = Path(path)
|
model_path = Path(path)
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
|
config = ModelConfig()
|
||||||
config_path = model_path / "config.json"
|
config_path = model_path / "config.json"
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
config = ModelConfig.from_file(str(config_path))
|
config.load(str(config_path))
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
|
|
@ -88,7 +89,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Save config
|
# Save config
|
||||||
self.config.to_file(str(save_path / "config.json"))
|
self.config.save(str(save_path / "config.json"))
|
||||||
|
|
||||||
# Save weights
|
# Save weights
|
||||||
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ class GQA(nn.Module):
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert dim % n_heads == 0
|
assert dim % n_heads == 0
|
||||||
|
|
@ -122,6 +123,7 @@ class MLA(nn.Module):
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
@ -141,7 +143,7 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
self.kv_b_proj = Linear(
|
self.kv_b_proj = Linear(
|
||||||
kv_lora_rank,
|
kv_lora_rank,
|
||||||
n_kv_heads * (2 * self.head_dim),
|
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = Linear(dim, dim, bias=False)
|
self.o_proj = Linear(dim, dim, bias=False)
|
||||||
|
|
@ -174,7 +176,7 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
q_nope, q_rope = (
|
q_nope, q_rope = (
|
||||||
q[..., : self.qk_nope_head_dim],
|
q[..., : self.qk_nope_head_dim],
|
||||||
q[..., self.qk_nope_head_dim :],
|
q[..., self.qk_rope_head_dim :],
|
||||||
)
|
)
|
||||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,13 @@ class DecoderBlock(nn.Module):
|
||||||
n_heads: int,
|
n_heads: int,
|
||||||
dim_ffn: int,
|
dim_ffn: int,
|
||||||
n_kv_heads: int,
|
n_kv_heads: int,
|
||||||
norm_eps: float,
|
norm_eps: int,
|
||||||
use_qk_norm: bool,
|
use_qk_norm: bool,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
attn_type: str = "gqa",
|
attn_type: str = "gqa",
|
||||||
ffn_type: str = "mlp",
|
ffn_type: str = "mlp",
|
||||||
**kwargs,
|
**moe_kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attention = AttnFactory.create(
|
self.attention = AttnFactory.create(
|
||||||
|
|
@ -34,11 +34,10 @@ class DecoderBlock(nn.Module):
|
||||||
norm_eps=norm_eps,
|
norm_eps=norm_eps,
|
||||||
use_gated_attention=use_gated_attention,
|
use_gated_attention=use_gated_attention,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self.input_norm = RMSNorm(dim, norm_eps)
|
self.input_norm = RMSNorm(dim, norm_eps)
|
||||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
|
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
|
||||||
|
|
||||||
@FFNFactory.register("mlp")
|
@FFNFactory.register("mlp")
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, dim: int, dim_ffn: int):
|
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up = Linear(dim, dim_ffn)
|
self.up = Linear(dim, dim_feed_forward)
|
||||||
self.gate = Linear(dim, dim_ffn)
|
self.gate = Linear(dim, dim_feed_forward)
|
||||||
self.down = Linear(dim_ffn, dim)
|
self.down = Linear(dim_feed_forward, dim)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
gated = self.up(x) * F.silu(self.gate(x))
|
gated = self.up(x) * F.silu(self.gate(x))
|
||||||
|
|
@ -32,11 +32,12 @@ class DeepSeekMoE(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
dim_ffn: int,
|
dim_feed_forward: int,
|
||||||
n_routed_experts: int,
|
n_routed_experts: int,
|
||||||
n_shared_experts: int = 1,
|
n_shared_experts: int = 1,
|
||||||
n_activated_experts: int = 2,
|
n_activated_experts: int = 2,
|
||||||
topk_method: str = "greedy",
|
topk_method: str = "greedy",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
@ -48,10 +49,10 @@ class DeepSeekMoE(nn.Module):
|
||||||
self.router = Linear(dim, n_routed_experts, bias=False)
|
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||||
|
|
||||||
self.shared_experts = nn.ModuleList(
|
self.shared_experts = nn.ModuleList(
|
||||||
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
|
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
|
||||||
)
|
)
|
||||||
self.routed_experts = nn.ModuleList(
|
self.routed_experts = nn.ModuleList(
|
||||||
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
|
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim: int, max_len: int, base: float = 10000):
|
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
|
|
|
||||||
|
|
@ -53,13 +53,9 @@ class Transformer(AutoModel):
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
rope_dim = (
|
self.rotary_embedding = RotaryEmbedding(
|
||||||
config.qk_rope_head_dim
|
config.dim // config.n_heads, config.max_len
|
||||||
if config.attn_type == "mla"
|
|
||||||
else config.dim // config.n_heads
|
|
||||||
)
|
)
|
||||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
|
||||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
|
|
@ -79,9 +75,6 @@ class Transformer(AutoModel):
|
||||||
n_shared_experts=config.n_shared_experts,
|
n_shared_experts=config.n_shared_experts,
|
||||||
n_activated_experts=config.n_activated_experts,
|
n_activated_experts=config.n_activated_experts,
|
||||||
topk_method=config.moe_topk_method,
|
topk_method=config.moe_topk_method,
|
||||||
kv_lora_rank=config.kv_lora_rank,
|
|
||||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
|
||||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
|
||||||
)
|
)
|
||||||
for layer_id in range(config.n_layers)
|
for layer_id in range(config.n_layers)
|
||||||
]
|
]
|
||||||
|
|
@ -90,7 +83,7 @@ class Transformer(AutoModel):
|
||||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||||
|
|
||||||
if self.config.tie_weight is True:
|
if self.config.tie_weight:
|
||||||
self.lm_head.weight = self.embed_tokens.weight
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
@ -106,7 +99,7 @@ class Transformer(AutoModel):
|
||||||
|
|
||||||
state_dict = dict(state_dict)
|
state_dict = dict(state_dict)
|
||||||
|
|
||||||
if self.config.tie_weight is True:
|
if self.config.tie_weight:
|
||||||
# same tensor for embed and lm_head
|
# same tensor for embed and lm_head
|
||||||
if embed_key in state_dict:
|
if embed_key in state_dict:
|
||||||
state_dict[lm_head_key] = state_dict[embed_key]
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
|
|
@ -122,7 +115,7 @@ class Transformer(AutoModel):
|
||||||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.tie_weight is True:
|
if self.config.tie_weight:
|
||||||
lm_head_key = prefix + "lm_head.weight"
|
lm_head_key = prefix + "lm_head.weight"
|
||||||
if lm_head_key in state_dict:
|
if lm_head_key in state_dict:
|
||||||
del state_dict[lm_head_key]
|
del state_dict[lm_head_key]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
@ -17,13 +16,11 @@ class Checkpoint:
|
||||||
epoch: int = 0,
|
epoch: int = 0,
|
||||||
iteration: int = 0,
|
iteration: int = 0,
|
||||||
extra: Optional[Dict[str, Any]] = None,
|
extra: Optional[Dict[str, Any]] = None,
|
||||||
meta: Optional[Dict[str, Any]] = None,
|
|
||||||
):
|
):
|
||||||
self.state_dict = state_dict
|
self.state_dict = state_dict
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
self.iteration = iteration
|
self.iteration = iteration
|
||||||
self.extra = extra or {}
|
self.extra = extra or {}
|
||||||
self.meta = meta or {}
|
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
|
|
@ -38,16 +35,13 @@ class Checkpoint:
|
||||||
meta = {
|
meta = {
|
||||||
"epoch": self.epoch,
|
"epoch": self.epoch,
|
||||||
"iteration": self.iteration,
|
"iteration": self.iteration,
|
||||||
"timestamp": time.time(),
|
|
||||||
}
|
}
|
||||||
meta.update(self.meta)
|
|
||||||
with open(save_path / "meta.json", "w") as f:
|
with open(save_path / "meta.json", "w") as f:
|
||||||
json.dump(meta, f, indent=2)
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||||
if self.extra:
|
if self.extra:
|
||||||
for key, value in self.extra.items():
|
torch.save(self.extra, save_path / "extra.pt")
|
||||||
torch.save(value, save_path / f"{key}.pt")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
|
|
@ -70,14 +64,14 @@ class Checkpoint:
|
||||||
|
|
||||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||||
|
|
||||||
extra = {}
|
extra = None
|
||||||
for f in save_path.iterdir():
|
extra_path = save_path / "extra.pt"
|
||||||
if f.suffix == ".pt" and f.stem not in ("meta",):
|
if extra_path.exists():
|
||||||
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=meta["epoch"],
|
epoch=meta["epoch"],
|
||||||
iteration=meta["iteration"],
|
iteration=meta["iteration"],
|
||||||
extra=extra or None,
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,8 @@ class GradientClippingCallback(TrainCallback):
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_step_begin(self, context: TrainContext):
|
def on_step_end(self, context: TrainContext):
|
||||||
|
_ = context
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -89,8 +90,6 @@ class CheckpointCallback(TrainCallback):
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
extra_keys = ("optimizer", "scheduler")
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
|
|
@ -98,14 +97,12 @@ class CheckpointCallback(TrainCallback):
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
self.state_dict_fn = state_dict_fn
|
||||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
self.save_extra_fn = save_extra_fn
|
||||||
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -119,22 +116,17 @@ class CheckpointCallback(TrainCallback):
|
||||||
else context.model.state_dict()
|
else context.model.state_dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
extra = self.save_extra_fn(context)
|
extra = self.save_extra_fn(context) if self.save_extra_fn else None
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=context.epoch,
|
epoch=context.epoch,
|
||||||
iteration=context.iteration,
|
iteration=context.iteration,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
meta=context.config.to_dict(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
def on_train_begin(self, context: TrainContext):
|
|
||||||
if context.checkpoint and context.checkpoint.extra:
|
|
||||||
self.load_extra_fn(context.checkpoint.extra, context)
|
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
@ -146,21 +138,6 @@ class CheckpointCallback(TrainCallback):
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def save_extra(context: TrainContext) -> dict:
|
|
||||||
extra = {}
|
|
||||||
for name in CheckpointCallback.extra_keys:
|
|
||||||
obj = getattr(context, name, None)
|
|
||||||
if obj:
|
|
||||||
extra[name] = obj.state_dict()
|
|
||||||
return extra
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_extra(extra: dict, context: TrainContext):
|
|
||||||
for name in CheckpointCallback.extra_keys:
|
|
||||||
if name in extra:
|
|
||||||
getattr(context, name).load_state_dict(extra[name])
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("progress_bar")
|
@CallbackFactory.register("progress_bar")
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Self
|
from typing import Callable, Optional, Self
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -21,7 +21,6 @@ class TrainContext:
|
||||||
optimizer: Optimizer = field(default=None)
|
optimizer: Optimizer = field(default=None)
|
||||||
scheduler: LRScheduler = field(default=None)
|
scheduler: LRScheduler = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
config: TrainConfig = field(default=None)
|
|
||||||
|
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
iteration: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
|
|
@ -36,9 +35,11 @@ class TrainContextBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: TrainConfig,
|
config: TrainConfig,
|
||||||
|
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._checkpoint: Optional[Checkpoint] = None
|
self._checkpoint: Optional[Checkpoint] = None
|
||||||
|
self._load_extra_fn = load_extra_fn
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
self._checkpoint = checkpoint
|
self._checkpoint = checkpoint
|
||||||
|
|
@ -49,7 +50,6 @@ class TrainContextBuilder:
|
||||||
model=self.config.model,
|
model=self.config.model,
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
config=self.config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
|
|
@ -71,8 +71,11 @@ class TrainContextBuilder:
|
||||||
context.optimizer = self.config.optimizer_fn(context.model)
|
context.optimizer = self.config.optimizer_fn(context.model)
|
||||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
|
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
|
||||||
|
self._load_extra_fn(self._checkpoint.extra, context)
|
||||||
|
|
||||||
cfg = self.config
|
cfg = self.config
|
||||||
sampler_offset = context.iteration * cfg.batch_per_device
|
sampler_offset = context.iteration * cfg.batch_size
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
data_source=cfg.dataset,
|
data_source=cfg.dataset,
|
||||||
start_epoch=context.epoch,
|
start_epoch=context.epoch,
|
||||||
|
|
@ -81,7 +84,7 @@ class TrainContextBuilder:
|
||||||
)
|
)
|
||||||
context.dataloader = DataLoader(
|
context.dataloader = DataLoader(
|
||||||
cfg.dataset,
|
cfg.dataset,
|
||||||
batch_size=cfg.batch_per_device,
|
batch_size=cfg.batch_size,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
pin_memory=cfg.pin_memory,
|
pin_memory=cfg.pin_memory,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
from itertools import batched
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from astrai.config import TrainConfig
|
from astrai.config import TrainConfig
|
||||||
|
|
@ -32,6 +33,11 @@ class Trainer:
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
|
return (
|
||||||
|
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||||
|
)
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
method = getattr(callback, method_name, None)
|
method = getattr(callback, method_name, None)
|
||||||
|
|
@ -39,47 +45,49 @@ class Trainer:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
cfg = self.train_config
|
config = self.train_config
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(
|
||||||
self._train_impl,
|
self._train_impl,
|
||||||
backend=cfg.backend,
|
backend=config.backend,
|
||||||
world_size=cfg.nprocs,
|
world_size=config.nprocs,
|
||||||
master_addr=cfg.master_addr,
|
master_addr=config.master_addr,
|
||||||
master_port=cfg.master_port,
|
master_port=config.master_port,
|
||||||
device_type=cfg.device_type,
|
device_type=config.device_type,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
|
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||||
cfg = self.train_config
|
context = self._build_context(checkpoint)
|
||||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
|
||||||
self._call_callbacks("on_train_begin", context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.model.train()
|
context.model.train()
|
||||||
grad_accum_steps = cfg.grad_accum_steps
|
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
||||||
|
|
||||||
for epoch in range(context.epoch, cfg.n_epoch):
|
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for steps in batched(context.dataloader, accumulation_steps):
|
||||||
self._call_callbacks("on_batch_begin", context)
|
self._call_callbacks("on_step_begin", context)
|
||||||
loss = context.strategy(batch)
|
|
||||||
context.loss = loss.item()
|
|
||||||
stand_loss = loss / grad_accum_steps
|
|
||||||
stand_loss.backward()
|
|
||||||
context.iteration += 1
|
|
||||||
self._call_callbacks("on_batch_end", context)
|
|
||||||
|
|
||||||
if context.iteration % grad_accum_steps == 0:
|
step_batch_nums = len(steps)
|
||||||
self._call_callbacks("on_step_begin", context)
|
for batch in steps:
|
||||||
context.optimizer.step()
|
self._call_callbacks("on_batch_begin", context)
|
||||||
context.optimizer.zero_grad()
|
loss = context.strategy(batch)
|
||||||
self._call_callbacks("on_step_end", context)
|
context.loss = loss.item()
|
||||||
|
context.iteration += 1
|
||||||
|
|
||||||
if context.scheduler:
|
stand_loss = loss / step_batch_nums
|
||||||
context.scheduler.step()
|
stand_loss.backward()
|
||||||
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
|
self._call_callbacks("on_step_end", context)
|
||||||
|
context.optimizer.step()
|
||||||
|
context.optimizer.zero_grad()
|
||||||
|
|
||||||
|
if context.scheduler:
|
||||||
|
context.scheduler.step()
|
||||||
|
|
||||||
self._call_callbacks("on_epoch_end", context)
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
def generate_text():
|
def generate_text():
|
||||||
|
# Load model from pretrained
|
||||||
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
@ -21,15 +22,16 @@ def generate_text():
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
for token in engine.generate(
|
response = engine.generate(
|
||||||
prompt=query,
|
prompt=query,
|
||||||
stream=True,
|
stream=False,
|
||||||
max_tokens=2048,
|
max_tokens=2048,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
):
|
)
|
||||||
print(token, end="", flush=True)
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -42,20 +42,18 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
|
"--accumulation_steps",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--grad_accum_steps",
|
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="Number of iterations between each optimizer step.",
|
help="Number of iterations between each optimizer step.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--warmup_ratio",
|
"--warmup_steps",
|
||||||
type=float,
|
type=int,
|
||||||
default=0.05,
|
default=1000,
|
||||||
help="Fraction of total steps used for LR warmup.",
|
help="Number of warmup steps for LR scheduler.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||||
|
|
@ -69,13 +67,13 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adamw_beta1",
|
"--adamw_beta1",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.95,
|
default=0.9,
|
||||||
help="Beta values for AdamW optimizer.",
|
help="Beta values for AdamW optimizer.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adamw_beta2",
|
"--adamw_beta2",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.99,
|
default=0.95,
|
||||||
help="Beta values for AdamW optimizer.",
|
help="Beta values for AdamW optimizer.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -116,7 +114,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--label_smoothing",
|
"--label_smoothing",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.05,
|
default=0.1,
|
||||||
help="cross_entropy function label smoothing parameter",
|
help="cross_entropy function label smoothing parameter",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -183,34 +181,17 @@ def prepare_checkpoint(model: nn.Module) -> dict:
|
||||||
return model.module.state_dict()
|
return model.module.state_dict()
|
||||||
|
|
||||||
|
|
||||||
def compute_total_steps(
|
|
||||||
dataset_len: int,
|
|
||||||
n_epoch: int,
|
|
||||||
batch_per_device: int,
|
|
||||||
nprocs: int,
|
|
||||||
grad_accum_steps: int,
|
|
||||||
) -> int:
|
|
||||||
|
|
||||||
def ceil_div(a: int, b: int) -> int:
|
|
||||||
return (a + b - 1) // b
|
|
||||||
|
|
||||||
samples_per_replica = ceil_div(dataset_len, nprocs)
|
|
||||||
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
|
|
||||||
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
|
|
||||||
return total_steps
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_type: str,
|
train_type: str,
|
||||||
param_path: str,
|
param_path: str,
|
||||||
data_root_path: str,
|
data_root_path: str,
|
||||||
max_lr: float,
|
max_lr: float,
|
||||||
n_epoch: int,
|
n_epoch: int,
|
||||||
batch_per_device: int,
|
batch_size: int,
|
||||||
start_epoch: int,
|
start_epoch: int,
|
||||||
start_batch: int,
|
start_batch: int,
|
||||||
grad_accum_steps: int,
|
accumulation_steps: int,
|
||||||
warmup_ratio: float,
|
warmup_steps: int,
|
||||||
ckpt_interval: int,
|
ckpt_interval: int,
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
|
|
@ -235,8 +216,10 @@ def train(
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
|
config = ModelConfig()
|
||||||
config_path = os.path.join(param_path, "config.json")
|
config_path = os.path.join(param_path, "config.json")
|
||||||
config = ModelConfig.from_file(config_path)
|
if os.path.exists(config_path):
|
||||||
|
config.load(config_path)
|
||||||
|
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = config.max_len
|
window_size = config.max_len
|
||||||
|
|
@ -277,17 +260,13 @@ def train(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
total_steps = compute_total_steps(
|
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
|
||||||
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
|
|
||||||
)
|
|
||||||
warmup_steps = int(warmup_ratio * total_steps)
|
|
||||||
|
|
||||||
scheduler_fn = partial(
|
scheduler_fn = partial(
|
||||||
create_scheduler,
|
create_scheduler,
|
||||||
**{
|
**{
|
||||||
"schedule_type": "cosine",
|
"schedule_type": "cosine",
|
||||||
"warmup_steps": min(warmup_steps, total_steps),
|
"warmup_steps": warmup_steps,
|
||||||
"lr_decay_steps": total_steps - min(warmup_steps, total_steps),
|
"lr_decay_steps": total_steps - warmup_steps,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -299,11 +278,11 @@ def train(
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=ckpt_dir,
|
ckpt_dir=ckpt_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_per_device=batch_per_device,
|
batch_size=batch_size,
|
||||||
start_epoch=start_epoch,
|
start_epoch=start_epoch,
|
||||||
start_batch=start_batch,
|
start_batch=start_batch,
|
||||||
ckpt_interval=ckpt_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
grad_accum_steps=grad_accum_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
|
|
|
||||||
|
|
@ -107,12 +107,12 @@ def test_model():
|
||||||
"""Session-scoped small Transformer model, created once."""
|
"""Session-scoped small Transformer model, created once."""
|
||||||
config = ModelConfig(
|
config = ModelConfig(
|
||||||
vocab_size=1000,
|
vocab_size=1000,
|
||||||
dim=8,
|
dim=16,
|
||||||
n_heads=2,
|
n_heads=4,
|
||||||
n_kv_heads=1,
|
n_kv_heads=2,
|
||||||
dim_ffn=16,
|
dim_ffn=32,
|
||||||
max_len=64,
|
max_len=1024,
|
||||||
n_layers=2,
|
n_layers=4,
|
||||||
norm_eps=1e-5,
|
norm_eps=1e-5,
|
||||||
)
|
)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"dim": 8,
|
"dim": 16,
|
||||||
"n_heads": 2,
|
"n_heads": 4,
|
||||||
"n_kv_heads": 1,
|
"n_kv_heads": 2,
|
||||||
"dim_ffn": 16,
|
"dim_ffn": 32,
|
||||||
"max_len": 64,
|
"max_len": 1024,
|
||||||
"n_layers": 2,
|
"n_layers": 4,
|
||||||
"norm_eps": 1e-5,
|
"norm_eps": 1e-5,
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
|
|
|
||||||
|
|
@ -35,33 +35,6 @@ def test_single_process():
|
||||||
assert loaded_checkpoint.iteration == 30
|
assert loaded_checkpoint.iteration == 30
|
||||||
|
|
||||||
|
|
||||||
def test_checkpoint_with_extra():
|
|
||||||
"""Verify extra keys are saved as individual .pt files and loaded back."""
|
|
||||||
model = torch.nn.Linear(10, 5)
|
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
extra = {
|
|
||||||
"optimizer": optimizer.state_dict(),
|
|
||||||
"scheduler": {"last_epoch": 5},
|
|
||||||
}
|
|
||||||
checkpoint = Checkpoint(
|
|
||||||
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
|
|
||||||
)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
checkpoint.save(tmpdir)
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
|
||||||
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
|
||||||
|
|
||||||
loaded = Checkpoint.load(tmpdir)
|
|
||||||
assert loaded.extra["scheduler"]["last_epoch"] == 5
|
|
||||||
assert "state" in loaded.extra["optimizer"]
|
|
||||||
|
|
||||||
|
|
||||||
def simple_training():
|
def simple_training():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from astrai.dataset.storage import (
|
||||||
BaseSegmentFetcher,
|
BaseSegmentFetcher,
|
||||||
H5Storage,
|
H5Storage,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
StorageFactory,
|
create_storage,
|
||||||
detect_format,
|
detect_format,
|
||||||
load_json,
|
load_json,
|
||||||
save_h5,
|
save_h5,
|
||||||
|
|
@ -368,9 +368,9 @@ def test_detect_format_unsupported_file(base_test_env):
|
||||||
|
|
||||||
|
|
||||||
def test_create_storage_invalid_type():
|
def test_create_storage_invalid_type():
|
||||||
"""StorageFactory.create raises ValueError for unknown type"""
|
"""create_storage raises ValueError for unknown type"""
|
||||||
with pytest.raises(ValueError, match="Unknown component"):
|
with pytest.raises(ValueError, match="Unknown storage type"):
|
||||||
StorageFactory.create("parquet")
|
create_storage("parquet")
|
||||||
|
|
||||||
|
|
||||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||||
|
|
|
||||||
|
|
@ -1,108 +0,0 @@
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
|
||||||
from astrai.model.transformer import Transformer
|
|
||||||
|
|
||||||
TINY_CONFIG = dict(
|
|
||||||
vocab_size=128,
|
|
||||||
dim=8,
|
|
||||||
n_heads=2,
|
|
||||||
n_kv_heads=1,
|
|
||||||
dim_ffn=16,
|
|
||||||
max_len=64,
|
|
||||||
n_layers=2,
|
|
||||||
norm_eps=1e-5,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
CONFIGS = [
|
|
||||||
pytest.param(
|
|
||||||
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
|
|
||||||
id="gqa_mlp",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
**TINY_CONFIG,
|
|
||||||
"attn_type": "mla",
|
|
||||||
"ffn_type": "mlp",
|
|
||||||
"kv_lora_rank": 4,
|
|
||||||
"qk_nope_head_dim": 2,
|
|
||||||
"qk_rope_head_dim": 2,
|
|
||||||
},
|
|
||||||
id="mla_mlp",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
**TINY_CONFIG,
|
|
||||||
"attn_type": "gqa",
|
|
||||||
"ffn_type": "moe",
|
|
||||||
"n_routed_experts": 4,
|
|
||||||
"n_shared_experts": 1,
|
|
||||||
"n_activated_experts": 2,
|
|
||||||
"moe_topk_method": "greedy",
|
|
||||||
},
|
|
||||||
id="gqa_moe",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{
|
|
||||||
**TINY_CONFIG,
|
|
||||||
"attn_type": "gqa",
|
|
||||||
"ffn_type": "mlp",
|
|
||||||
"rope_theta": 100000.0,
|
|
||||||
},
|
|
||||||
id="gqa_rope_theta",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
|
|
||||||
id="gqa_qk_norm",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
|
|
||||||
id="gqa_tie_weight",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
|
||||||
def test_model_forward(config_kwargs):
|
|
||||||
config = ModelConfig(**config_kwargs)
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
model = Transformer(config).to(device=device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
batch_size, seq_len = 2, 8
|
|
||||||
input_ids = torch.randint(
|
|
||||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
output = model(input_ids)
|
|
||||||
|
|
||||||
assert "logits" in output
|
|
||||||
assert "hidden_states" in output
|
|
||||||
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
|
|
||||||
assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
|
|
||||||
assert not torch.isnan(output["logits"]).any()
|
|
||||||
assert not torch.isnan(output["hidden_states"]).any()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
|
||||||
def test_model_forward_with_padding(config_kwargs):
|
|
||||||
config = ModelConfig(**config_kwargs)
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
model = Transformer(config).to(device=device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
batch_size, seq_len = 2, 8
|
|
||||||
input_ids = torch.randint(
|
|
||||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
|
||||||
)
|
|
||||||
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
|
||||||
input_mask[:, 4:] = False
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
output = model(input_ids, input_mask=input_mask)
|
|
||||||
|
|
||||||
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
|
|
||||||
assert not torch.isnan(output["logits"]).any()
|
|
||||||
|
|
@ -17,10 +17,10 @@ def transformer_test_env():
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"dim": 8,
|
"dim": 128,
|
||||||
"n_heads": 2,
|
"n_heads": 4,
|
||||||
"n_kv_heads": 1,
|
"n_kv_heads": 2,
|
||||||
"dim_ffn": 16,
|
"dim_ffn": 256,
|
||||||
"max_len": 64,
|
"max_len": 64,
|
||||||
"n_layers": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5,
|
"norm_eps": 1e-5,
|
||||||
|
|
@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig.from_file(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
model = Transformer(config)
|
model = Transformer(config)
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
|
|
@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig.from_file(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
model = Transformer(config)
|
model = Transformer(config)
|
||||||
|
|
||||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
|
|
@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig.from_file(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
original_model = Transformer(config)
|
original_model = Transformer(config)
|
||||||
|
|
||||||
st.save_file(original_model.state_dict(), model_path)
|
st.save_file(original_model.state_dict(), model_path)
|
||||||
|
|
||||||
loaded_config = ModelConfig.from_file(config_path)
|
loaded_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = Transformer(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
||||||
|
|
@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
loaded_config = ModelConfig.from_file(config_path)
|
loaded_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = Transformer(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,8 @@ def create_train_config(
|
||||||
device: str,
|
device: str,
|
||||||
strategy: str = "seq",
|
strategy: str = "seq",
|
||||||
n_epoch: int = 1,
|
n_epoch: int = 1,
|
||||||
batch_per_device: int = 2,
|
batch_size: int = 2,
|
||||||
grad_accum_steps: int = 1,
|
accumulation_steps: int = 1,
|
||||||
max_grad_norm: float = 1.0,
|
max_grad_norm: float = 1.0,
|
||||||
ckpt_interval: int = 5,
|
ckpt_interval: int = 5,
|
||||||
random_seed: int = 42,
|
random_seed: int = 42,
|
||||||
|
|
@ -47,8 +47,8 @@ def create_train_config(
|
||||||
device: Device type ("cuda" or "cpu")
|
device: Device type ("cuda" or "cpu")
|
||||||
strategy: Training strategy type (default: "seq")
|
strategy: Training strategy type (default: "seq")
|
||||||
n_epoch: Number of epochs (default: 1)
|
n_epoch: Number of epochs (default: 1)
|
||||||
batch_per_device: Batch size per device (default: 2)
|
batch_size: Batch size (default: 2)
|
||||||
grad_accum_steps: Gradient accumulation steps (default: 1)
|
accumulation_steps: Gradient accumulation steps (default: 1)
|
||||||
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
|
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
|
||||||
ckpt_interval: Checkpoint save interval in iterations (default: 5)
|
ckpt_interval: Checkpoint save interval in iterations (default: 5)
|
||||||
random_seed: Random seed for reproducibility (default: 42)
|
random_seed: Random seed for reproducibility (default: 42)
|
||||||
|
|
@ -74,9 +74,9 @@ def create_train_config(
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=test_dir,
|
ckpt_dir=test_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_per_device=batch_per_device,
|
batch_size=batch_size,
|
||||||
ckpt_interval=ckpt_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
grad_accum_steps=grad_accum_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
device_type=device,
|
device_type=device,
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_per_device=2,
|
batch_size=2,
|
||||||
ckpt_interval=3,
|
ckpt_interval=3,
|
||||||
grad_accum_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"],
|
device_type=base_test_env["device"],
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_per_device=2,
|
batch_size=2,
|
||||||
ckpt_interval=1,
|
ckpt_interval=1,
|
||||||
grad_accum_steps=2,
|
accumulation_steps=2,
|
||||||
random_seed=np.random.randint(1e4),
|
random_seed=np.random.randint(1e4),
|
||||||
device_type=base_test_env["device"],
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
||||||
"""Test training with different batch sizes"""
|
"""Test training with different batch sizes"""
|
||||||
batch_sizes = [1, 2, 4, 8]
|
batch_sizes = [1, 2, 4, 8]
|
||||||
|
|
||||||
for batch_per_device in batch_sizes:
|
for batch_size in batch_sizes:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_per_device=batch_per_device,
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.batch_per_device == batch_per_device
|
assert train_config.batch_size == batch_size
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
||||||
"""Test training with different gradient accumulation steps"""
|
"""Test training with different gradient accumulation steps"""
|
||||||
grad_accum_steps_list = [1, 2, 4]
|
accumulation_steps_list = [1, 2, 4]
|
||||||
|
|
||||||
for grad_accum_steps in grad_accum_steps_list:
|
for accumulation_steps in accumulation_steps_list:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_per_device=2,
|
batch_size=2,
|
||||||
grad_accum_steps=grad_accum_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
assert train_config.grad_accum_steps == grad_accum_steps
|
assert train_config.accumulation_steps == accumulation_steps
|
||||||
|
|
||||||
|
|
||||||
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
||||||
"""Test training with memory-efficient configurations"""
|
"""Test training with memory-efficient configurations"""
|
||||||
# Test with smaller batch sizes and gradient checkpointing
|
# Test with smaller batch sizes and gradient checkpointing
|
||||||
small_batch_configs = [
|
small_batch_configs = [
|
||||||
{"batch_per_device": 1, "grad_accum_steps": 8},
|
{"batch_size": 1, "accumulation_steps": 8},
|
||||||
{"batch_per_device": 2, "grad_accum_steps": 4},
|
{"batch_size": 2, "accumulation_steps": 4},
|
||||||
{"batch_per_device": 4, "grad_accum_steps": 2},
|
{"batch_size": 4, "accumulation_steps": 2},
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in small_batch_configs:
|
for config in small_batch_configs:
|
||||||
|
|
@ -54,9 +54,8 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_per_device=config["batch_per_device"],
|
batch_size=config["batch_size"],
|
||||||
grad_accum_steps=config["grad_accum_steps"],
|
accumulation_steps=config["accumulation_steps"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.grad_accum_steps == config["grad_accum_steps"]
|
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||||
assert train_config.batch_per_device == config["batch_per_device"]
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue