Compare commits

..

No commits in common. "ad9f4d9cf60f35cf742509b8096c7b541252c5be" and "3d12a03909c6dedc6de112a4f53e3ecd1d1a2068" have entirely different histories.

35 changed files with 419 additions and 628 deletions

View File

@ -78,27 +78,15 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
#### Train a Model #### Train a Model
```bash ```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3 CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
nohup python scripts/tools/train.py \ --data_root_path /path/to/dataset \
--nprocs=4 \ --param_path /path/to/model \
--train_type=pt \ --batch_size 4 \
--data_root_path=/path/to/dataset \ --accumulation_steps 8 \
--param_path=/path/to/model \ --max_lr 3e-4 \
--batch_per_device=4 \ --warmup_steps 1000 \
--grad_accum_steps=8 \ --n_epoch 1
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
``` ```
Full reference at [Parameter Guide](assets/docs/params.md). Full reference at [Parameter Guide](assets/docs/params.md).

View File

@ -84,27 +84,15 @@ python scripts/demo/download.py
#### 训练模型 #### 训练模型
```bash ```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3 CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
nohup python scripts/tools/train.py \ --data_root_path /path/to/dataset \
--nprocs=4 \ --param_path /path/to/model \
--train_type=pt \ --batch_size 4 \
--data_root_path=/path/to/dataset \ --accumulation_steps 8 \
--param_path=/path/to/model \ --max_lr 3e-4 \
--batch_per_device=4 \ --warmup_steps 1000 \
--grad_accum_steps=8 \ --n_epoch 1
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
``` ```
完整参数列表见[参数说明](./params.md)。 完整参数列表见[参数说明](./params.md)。

View File

@ -5,15 +5,10 @@
```mermaid ```mermaid
classDiagram classDiagram
namespace config { namespace config {
class BaseConfig {
+to_dict() Dict
+from_dict(d) Self
}
class BaseModelConfig { class BaseModelConfig {
+Optional[str] model_type +Optional[str] model_type
+from_file(config_path) Self +load(config_path) Self
+to_file(config_path) +save(config_path)
} }
class ModelConfig { class ModelConfig {
@ -35,9 +30,6 @@ classDiagram
+int n_shared_experts +int n_shared_experts
+int n_activated_experts +int n_activated_experts
+str moe_topk_method +str moe_topk_method
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+load(config_path) ModelConfig +load(config_path) ModelConfig
+save(config_path) +save(config_path)
} }
@ -49,8 +41,8 @@ classDiagram
+Callable optimizer_fn +Callable optimizer_fn
+Callable scheduler_fn +Callable scheduler_fn
+int n_epoch +int n_epoch
+int batch_per_device +int batch_size
+int grad_accum_steps +int accumulation_steps
+float max_grad_norm +float max_grad_norm
+int start_epoch +int start_epoch
+int start_batch +int start_batch
@ -77,7 +69,7 @@ classDiagram
class BaseDataset { class BaseDataset {
+int window_size +int window_size
+int stride +int stride
+Optional[BaseStorage] storage +BaseStorage storage
+load(load_path, storage_type, tokenizer) +load(load_path, storage_type, tokenizer)
+__getitem__(index) +__getitem__(index)
+__len__() +__len__()
@ -134,8 +126,8 @@ classDiagram
} }
class ResumableDistributedSampler { class ResumableDistributedSampler {
+int epoch +int start_epoch
+int iter +int start_iter
} }
class DatasetFactory { class DatasetFactory {
@ -152,7 +144,6 @@ classDiagram
+int epoch +int epoch
+int iteration +int iteration
+dict extra +dict extra
+dict meta
+save(save_dir) +save(save_dir)
+load(save_dir) Checkpoint +load(save_dir) Checkpoint
} }
@ -164,7 +155,7 @@ classDiagram
+Registry _registry +Registry _registry
+register(model_type) decorator +register(model_type) decorator
+get_component_class(model_type) Type +get_component_class(model_type) Type
+from_pretrained(path, disable_random_init, strict) nn.Module +from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory) +save_pretrained(save_directory)
+to(*args, **kwargs) Self +to(*args, **kwargs) Self
} }
@ -176,7 +167,7 @@ classDiagram
+ModuleList layers +ModuleList layers
+RMSNorm norm +RMSNorm norm
+Linear lm_head +Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor] +forward(input_ids, input_mask, paged_cache, position_ids) Dict
+load_state_dict(state_dict) +load_state_dict(state_dict)
+state_dict() +state_dict()
} }
@ -194,7 +185,6 @@ classDiagram
+int n_kv_heads +int n_kv_heads
+int head_dim +int head_dim
+int n_rep +int n_rep
+int layer_id
+bool use_qk_norm +bool use_qk_norm
+bool use_gated_attention +bool use_gated_attention
+Linear q_proj, k_proj, v_proj, o_proj +Linear q_proj, k_proj, v_proj, o_proj
@ -211,7 +201,6 @@ classDiagram
+int qk_nope_head_dim +int qk_nope_head_dim
+int qk_rope_head_dim +int qk_rope_head_dim
+int n_rep +int n_rep
+int layer_id
+bool use_gated_attention +bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj +Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj +Linear o_proj
@ -226,7 +215,6 @@ classDiagram
} }
class DeepSeekMoE { class DeepSeekMoE {
+int dim
+int n_routed_experts +int n_routed_experts
+int n_shared_experts +int n_shared_experts
+int n_activated_experts +int n_activated_experts
@ -248,7 +236,6 @@ classDiagram
class RMSNorm { class RMSNorm {
+Parameter weight +Parameter weight
+float norm_eps +float norm_eps
+tuple normalized_shape
+forward(x) Tensor +forward(x) Tensor
} }
@ -312,6 +299,7 @@ classDiagram
+TrainConfig train_config +TrainConfig train_config
+List[TrainCallback] callbacks +List[TrainCallback] callbacks
+train(checkpoint) +train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List[TrainCallback] +_get_default_callbacks() List[TrainCallback]
} }
@ -336,7 +324,7 @@ classDiagram
} }
class BaseStrategy { class BaseStrategy {
+Union[Callable, nn.Module] model +nn.Module model
+str device +str device
+compute_loss(batch) Tensor +compute_loss(batch) Tensor
} }
@ -344,7 +332,7 @@ classDiagram
class StrategyFactory { class StrategyFactory {
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
+create(train_type, model, device, **kwargs) BaseStrategy +create(model, train_type, device, **kwargs) BaseStrategy
} }
class SEQStrategy { class SEQStrategy {
@ -412,7 +400,7 @@ classDiagram
class GradientClippingCallback { class GradientClippingCallback {
+float max_grad_norm +float max_grad_norm
+on_step_begin(context) +on_step_end(context)
} }
class CheckpointCallback { class CheckpointCallback {
@ -471,7 +459,10 @@ classDiagram
+TaskManager _task_mgr +TaskManager _task_mgr
+bool _running +bool _running
+Thread _loop_thread +Thread _loop_thread
+int max_batch_size
+int max_seq_len +int max_seq_len
+int max_prompt_len
+int page_size
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id) +remove_task(task_id)
+start() +start()
@ -509,7 +500,10 @@ classDiagram
} }
class Storage { class Storage {
+int n_layers
+int page_size +int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache +Tensor k_cache
+Tensor v_cache +Tensor v_cache
+write(layer_id, page_table, start_pos, k, v) +write(layer_id, page_table, start_pos, k, v)
@ -681,6 +675,7 @@ classDiagram
} }
class AnthropicHandler { class AnthropicHandler {
+List[str] stop_sequences
+build_prompt() str +build_prompt() str
+create_response_id() str +create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str] +on_token(ctx, token, stop_checker) Optional[str]
@ -709,7 +704,7 @@ classDiagram
namespace parallel { namespace parallel {
class Functions { class Functions {
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs) +spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str +get_current_device() str
+get_world_size() int +get_world_size() int
@ -756,8 +751,6 @@ classDiagram
ParallelModel <|-- RowParallelLinear ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- Transformer AutoModel <|-- Transformer
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseModelConfig <|-- ModelConfig BaseModelConfig <|-- ModelConfig
BaseFactory <|-- AutoModel BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory BaseFactory <|-- AttnFactory
@ -846,7 +839,7 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | | **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization | | **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
@ -885,4 +878,4 @@ classDiagram
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
> Document Update Time: 2026-05-16 > Document Update Time: 2026-05-15

View File

@ -10,14 +10,14 @@
| `--data_root_path` | Dataset root directory | required | | `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required | | `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 | | `--n_epoch` | Total training epochs | 1 |
| `--batch_per_device` | Batch size per device | 1 | | `--batch_size` | Batch size | 1 |
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 | | `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling ### Learning Rate Scheduling
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 | | `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 | | `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 | | `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
@ -25,8 +25,8 @@
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 | 0.95 | | `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.99 | | `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 | | `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading ### Data Loading
@ -60,7 +60,7 @@
| Parameter | Description | Default | Used by | | Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------| |-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` | | `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` | | `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` |
| `--group_size` | GRPO group size | 4 | `grpo` | | `--group_size` | GRPO group size | 4 | `grpo` |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` | | `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` | | `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
@ -69,29 +69,90 @@
### Usage Example ### Usage Example
```bash ```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
nohup python scripts/tools/train.py \ --data_root_path /path/to/dataset \
--nprocs=4 \ --param_path /path/to/model \
--train_type=pt \ --n_epoch 3 \
--data_root_path=/path/to/dataset \ --batch_size 4 \
--param_path=/path/to/model \ --accumulation_steps 8 \
--batch_per_device=4 \ --max_lr 3e-4 \
--grad_accum_steps=8 \ --warmup_steps 2000 \
--warmup_ratio=0.05 \ --max_grad_norm 1.0 \
--max_lr=1e-4 \ --ckpt_interval 5000 \
--max_grad_norm=1.0 \ --ckpt_dir ./checkpoints \
--adamw_beta1=0.95 \ --num_workers 4 \
--adamw_beta2=0.99 \ --nprocs 1 \
--adamw_weight_decay=0.01 \ --device_type cuda
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
``` ```
--- ---
> Document Update Time: 2026-05-16 ## Generation Parameters
### GenerationRequest Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `messages` | List of message dictionaries (role, content) | required |
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
| `stream` | Whether to stream output | False |
### Usage Example
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
)
# Build request with messages format
request = GenerationRequest(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_tokens=None,
)
# Generate (streaming)
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
# Or use simple generate interface
result = engine.generate(
prompt="Hello",
stream=False,
max_tokens=1024,
temperature=0.8,
top_p=0.95,
top_k=50,
)
```
### Generation Modes
| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-05-15

View File

@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi
## Training Loop ## Training Loop
Two-level loop: **epoch****batch**. Optimizer step fires every `grad_accum_steps` batches. Nested loop: **epoch****step** (accumulation window) → **batch**.
``` ```
on_train_begin on_train_begin
on_epoch_begin on_epoch_begin
for batch in dataloader: for steps in batched(dataloader, accumulation_steps):
on_batch_begin on_step_begin
loss = strategy(batch) step_batch_nums = len(steps)
(loss / grad_accum_steps).backward() for batch in steps:
iteration += 1 on_batch_begin
on_batch_end loss = strategy(batch)
(loss / step_batch_nums).backward()
if iteration % grad_accum_steps == 0: iteration += 1
on_step_begin on_batch_end
optimizer.step() on_step_end
optimizer.zero_grad() optimizer.step()
on_step_end optimizer.zero_grad()
scheduler.step() scheduler.step()
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
@ -91,9 +91,9 @@ on_train_end
| Hook | Fires | Default callback | | Hook | Fires | Default callback |
|------|-------|-----------------| |------|-------|-----------------|
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` | | `on_step_end` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | | `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) | | `on_train_end` | Training ends | `CheckpointCallback` (final save) |
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`. Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
@ -157,13 +157,12 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Checkpoint ## Checkpoint
``` ```
Checkpoint(state_dict, epoch, iteration, extra, meta) Checkpoint(state_dict, epoch, iteration, extra)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt ├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt
└── load(save_dir) broadcasts metadata from rank-0 └── load(save_dir) broadcasts metadata from rank-0
``` ```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`. Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern) ## TrainContextBuilder (Builder Pattern)
@ -184,29 +183,17 @@ context = (
## Training CLI ## Training CLI
```bash ```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
nohup python scripts/tools/train.py \ --data_root_path /path/to/data \
--nprocs=4 \ --param_path /path/to/model \
--train_type=pt \ --batch_size 4 \
--data_root_path=/path/to/dataset \ --accumulation_steps 8 \
--param_path=/path/to/model \ --max_lr 3e-4 \
--batch_per_device=4 \ --warmup_steps 1000 \
--grad_accum_steps=8 \ --n_epoch 1
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
``` ```
Full parameter reference at [params.md](params.md). Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-16 > Document Update Time: 2026-05-15

View File

@ -1,77 +0,0 @@
import json
from dataclasses import MISSING, dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass
class BaseConfig:
def to_dict(self) -> Dict[str, Any]:
d = {}
for fld in fields(self):
v = getattr(self, fld.name)
if isinstance(v, (str, int, float, bool)):
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, dict):
try:
json.dumps(v)
d[fld.name] = v
except (TypeError, ValueError):
pass
return d
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
hints = get_type_hints(cls)
inst = cls.__new__(cls)
for fld in fields(cls):
if fld.name in d:
v = d[fld.name]
target = cls._unwrap_optional(hints.get(fld.name))
if target is not None:
try:
v = cls._coerce(v, target)
except (TypeError, ValueError):
pass
object.__setattr__(inst, fld.name, v)
elif fld.default is not MISSING:
object.__setattr__(inst, fld.name, fld.default)
elif fld.default_factory is not MISSING:
object.__setattr__(inst, fld.name, fld.default_factory())
else:
object.__setattr__(inst, fld.name, None)
return inst
@staticmethod
def _unwrap_optional(tp) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError

View File

@ -1,14 +1,12 @@
import json import json
import warnings import sys
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self from typing import Any, Dict, Optional, Self, get_type_hints
from astrai.config.base import BaseConfig
@dataclass @dataclass
class BaseModelConfig(BaseConfig): class BaseModelConfig:
"""Field-aware JSON from/to file for dataclass configs. """Field-aware JSON load/save for dataclass configs.
Subclass with additional fields. The base ``model_type`` field Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass. enables ``AutoModel`` to pick the correct subclass.
@ -16,25 +14,76 @@ class BaseModelConfig(BaseConfig):
model_type: Optional[str] = None model_type: Optional[str] = None
@classmethod def load(self, config_path: str) -> Self:
def from_file(cls, config_path: str) -> Self: raw: Dict[str, Any] = {}
with open(config_path, "r") as f: with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f) raw.update(json.load(f))
valid = {fld.name for fld in fields(cls)} hints = get_type_hints(type(self))
for key in list(raw): valid = {fld.name for fld in fields(self)}
for key, value in raw.items():
if key not in valid: if key not in valid:
warnings.warn(f"Unknown config key '{key}'") sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
del raw[key] continue
return cls.from_dict(raw) target_type = self._unwrap_optional(hints.get(key))
if target_type is None:
continue
def to_file(self, config_path: str): try:
d = self.to_dict() value = self._coerce(value, target_type)
config_dict = {k: v for k, v in d.items() if v is not None} except (TypeError, ValueError):
sys.stderr.write(
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
)
continue
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict: Dict[str, Any] = {}
for fld in fields(self):
v = getattr(self, fld.name)
if v is not None:
config_dict[fld.name] = v
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4) json.dump(config_dict, f, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass @dataclass
class ModelConfig(BaseModelConfig): class ModelConfig(BaseModelConfig):
@ -57,11 +106,6 @@ class ModelConfig(BaseModelConfig):
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
# MLA
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
# MoE # MoE
ffn_type: str = "mlp" ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None n_routed_experts: Optional[int] = None

View File

@ -6,11 +6,9 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
@dataclass @dataclass
class TrainConfig(BaseConfig): class TrainConfig:
# basic setting # basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."}) model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."}) strategy: str = field(default=None, metadata={"help": "Training strategy."})
@ -22,10 +20,8 @@ class TrainConfig(BaseConfig):
default=None, metadata={"help": "Scheduler factory for training."} default=None, metadata={"help": "Scheduler factory for training."}
) )
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."}) n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_per_device: int = field( batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
default=4, metadata={"help": "Batch size per device."} accumulation_steps: int = field(
)
grad_accum_steps: int = field(
default=1, metadata={"help": "Number of iterations between steps."} default=1, metadata={"help": "Number of iterations between steps."}
) )
max_grad_norm: float = field( max_grad_norm: float = field(

View File

@ -9,7 +9,8 @@ from astrai.dataset.storage import (
H5Storage, H5Storage,
JSONStorage, JSONStorage,
MultiSegmentFetcher, MultiSegmentFetcher,
StorageFactory, available_storage_types,
create_storage,
detect_format, detect_format,
load_h5, load_h5,
load_json, load_json,
@ -25,8 +26,9 @@ __all__ = [
"BaseStorage", "BaseStorage",
"H5Storage", "H5Storage",
"JSONStorage", "JSONStorage",
"StorageFactory", "create_storage",
"detect_format", "detect_format",
"available_storage_types",
"save_h5", "save_h5",
"load_h5", "load_h5",
"save_json", "save_json",

View File

@ -9,7 +9,7 @@ from torch.utils.data import Dataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
BaseStorage, BaseStorage,
StorageFactory, create_storage,
detect_format, detect_format,
) )
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -42,7 +42,7 @@ class BaseDataset(Dataset, ABC):
""" """
if storage_type is None: if storage_type is None:
storage_type = detect_format(load_path) storage_type = detect_format(load_path)
self.storage = StorageFactory.create(storage_type) self.storage = create_storage(storage_type)
self.storage.load(load_path, tokenizer=tokenizer) self.storage.load(load_path, tokenizer=tokenizer)
def load_json(self, load_path: str, tokenizer=None): def load_json(self, load_path: str, tokenizer=None):

View File

@ -15,8 +15,6 @@ import h5py
import torch import torch
from torch import Tensor from torch import Tensor
from astrai.factory import BaseFactory
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
@ -260,24 +258,6 @@ class BaseStorage(ABC):
return self._fetcher.multi_keys return self._fetcher.multi_keys
class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating storage backends by type name.
Example:
@StorageFactory.register("custom")
class CustomStorage(BaseStorage):
...
storage = StorageFactory.create("custom")
"""
@classmethod
def _validate_component(cls, storage_cls: type) -> None:
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@StorageFactory.register("h5")
class H5Storage(BaseStorage): class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data).""" """HDF5-based storage backend (pre-tokenized data)."""
@ -286,7 +266,6 @@ class H5Storage(BaseStorage):
self._fetcher = MultiSegmentFetcher(segments) self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("json")
class JSONStorage(BaseStorage): class JSONStorage(BaseStorage):
"""JSON-based storage backend. """JSON-based storage backend.
@ -299,3 +278,35 @@ class JSONStorage(BaseStorage):
def load(self, load_path: str, tokenizer=None) -> None: def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer) segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments) self._fetcher = MultiSegmentFetcher(segments)
_STORAGE_REGISTRY: Dict[str, type] = {
"h5": H5Storage,
"json": JSONStorage,
}
def create_storage(storage_type: str) -> BaseStorage:
"""Create a storage instance by type name.
Args:
storage_type: Storage type name ("h5", "json")
Returns:
Storage instance
Raises:
ValueError: If the storage type is unknown
"""
storage_cls = _STORAGE_REGISTRY.get(storage_type)
if storage_cls is None:
raise ValueError(
f"Unknown storage type: '{storage_type}'. "
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
)
return storage_cls()
def available_storage_types() -> List[str]:
"""Return list of registered storage type names."""
return sorted(_STORAGE_REGISTRY.keys())

View File

@ -1,6 +1,5 @@
"""Base factory class for extensible component registration.""" """Base factory class for extensible component registration."""
import inspect
from abc import ABC from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
@ -123,10 +122,6 @@ class BaseFactory(ABC, Generic[T]):
def create(cls, name: str, *args, **kwargs) -> T: def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name. """Create a component instance by name.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args: Args:
name: Registered name of the component name: Registered name of the component
*args: Positional arguments passed to component constructor *args: Positional arguments passed to component constructor
@ -144,17 +139,6 @@ class BaseFactory(ABC, Generic[T]):
f"Supported types: {sorted(cls._registry.list_names())}" f"Supported types: {sorted(cls._registry.list_names())}"
) )
component_cls = cls._registry.get(name) component_cls = cls._registry.get(name)
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_kwargs:
valid = {
p.name
for p in sig.parameters.values()
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
}
kwargs = {k: v for k, v in kwargs.items() if k in valid}
return component_cls(*args, **kwargs) return component_cls(*args, **kwargs)
@classmethod @classmethod

View File

@ -163,5 +163,4 @@ def run_server(
app, app,
host=host, host=host,
port=port, port=port,
reload=reload,
) )

View File

@ -22,22 +22,14 @@ class InferenceScheduler:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 16, max_batch_size: int = 16,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048, max_prompt_len: int = 512,
page_size: int = 64, page_size: int = 64,
device: Optional[str] = None, device: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
): ):
config = model.config config = model.config
if max_seq_len is not None: self.max_seq_len = max_seq_len or config.max_len
self.max_seq_len = max_seq_len
elif config.max_len is not None:
self.max_seq_len = config.max_len
else:
raise ValueError(
"max_seq_len must be provided either as argument "
"or in model config (config.max_len)"
)
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype

View File

@ -60,9 +60,10 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
model_path = Path(path) model_path = Path(path)
# Load config # Load config
config = ModelConfig()
config_path = model_path / "config.json" config_path = model_path / "config.json"
if config_path.exists(): if config_path.exists():
config = ModelConfig.from_file(str(config_path)) config.load(str(config_path))
else: else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
@ -88,7 +89,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)
# Save config # Save config
self.config.to_file(str(save_path / "config.json")) self.config.save(str(save_path / "config.json"))
# Save weights # Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors")) st.save_file(self.state_dict(), str(save_path / "model.safetensors"))

View File

@ -40,6 +40,7 @@ class GQA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
assert dim % n_heads == 0 assert dim % n_heads == 0
@ -122,6 +123,7 @@ class MLA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -141,7 +143,7 @@ class MLA(nn.Module):
self.kv_b_proj = Linear( self.kv_b_proj = Linear(
kv_lora_rank, kv_lora_rank,
n_kv_heads * (2 * self.head_dim), n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
) )
self.o_proj = Linear(dim, dim, bias=False) self.o_proj = Linear(dim, dim, bias=False)
@ -174,7 +176,7 @@ class MLA(nn.Module):
q_nope, q_rope = ( q_nope, q_rope = (
q[..., : self.qk_nope_head_dim], q[..., : self.qk_nope_head_dim],
q[..., self.qk_nope_head_dim :], q[..., self.qk_rope_head_dim :],
) )
q_rope = apply_rotary_emb(q_rope, rotary_emb) q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb) k_rope = apply_rotary_emb(k_rope, rotary_emb)

View File

@ -16,13 +16,13 @@ class DecoderBlock(nn.Module):
n_heads: int, n_heads: int,
dim_ffn: int, dim_ffn: int,
n_kv_heads: int, n_kv_heads: int,
norm_eps: float, norm_eps: int,
use_qk_norm: bool, use_qk_norm: bool,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
attn_type: str = "gqa", attn_type: str = "gqa",
ffn_type: str = "mlp", ffn_type: str = "mlp",
**kwargs, **moe_kwargs,
): ):
super().__init__() super().__init__()
self.attention = AttnFactory.create( self.attention = AttnFactory.create(
@ -34,11 +34,10 @@ class DecoderBlock(nn.Module):
norm_eps=norm_eps, norm_eps=norm_eps,
use_gated_attention=use_gated_attention, use_gated_attention=use_gated_attention,
layer_id=layer_id, layer_id=layer_id,
**kwargs,
) )
self.input_norm = RMSNorm(dim, norm_eps) self.input_norm = RMSNorm(dim, norm_eps)
self.post_attention_norm = RMSNorm(dim, norm_eps) self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs) self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
def forward( def forward(
self, self,

View File

@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
@FFNFactory.register("mlp") @FFNFactory.register("mlp")
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, dim: int, dim_ffn: int): def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
super().__init__() super().__init__()
self.up = Linear(dim, dim_ffn) self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_ffn) self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_ffn, dim) self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x)) gated = self.up(x) * F.silu(self.gate(x))
@ -32,11 +32,12 @@ class DeepSeekMoE(nn.Module):
def __init__( def __init__(
self, self,
dim: int, dim: int,
dim_ffn: int, dim_feed_forward: int,
n_routed_experts: int, n_routed_experts: int,
n_shared_experts: int = 1, n_shared_experts: int = 1,
n_activated_experts: int = 2, n_activated_experts: int = 2,
topk_method: str = "greedy", topk_method: str = "greedy",
**kwargs,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -48,10 +49,10 @@ class DeepSeekMoE(nn.Module):
self.router = Linear(dim, n_routed_experts, bias=False) self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList( self.shared_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)] [MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
) )
self.routed_experts = nn.ModuleList( self.routed_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)] [MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
) )
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:

View File

@ -30,7 +30,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
class RotaryEmbedding(nn.Module): class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: float = 10000): def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.max_len = max_len self.max_len = max_len

View File

@ -53,13 +53,9 @@ class Transformer(AutoModel):
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
rope_dim = ( self.rotary_embedding = RotaryEmbedding(
config.qk_rope_head_dim config.dim // config.n_heads, config.max_len
if config.attn_type == "mla"
else config.dim // config.n_heads
) )
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
self.embed_tokens = Embedding(config.vocab_size, config.dim) self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
@ -79,9 +75,6 @@ class Transformer(AutoModel):
n_shared_experts=config.n_shared_experts, n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts, n_activated_experts=config.n_activated_experts,
topk_method=config.moe_topk_method, topk_method=config.moe_topk_method,
kv_lora_rank=config.kv_lora_rank,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
) )
for layer_id in range(config.n_layers) for layer_id in range(config.n_layers)
] ]
@ -90,7 +83,7 @@ class Transformer(AutoModel):
self.norm = RMSNorm(config.dim, config.norm_eps) self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size) self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight is True: if self.config.tie_weight:
self.lm_head.weight = self.embed_tokens.weight self.lm_head.weight = self.embed_tokens.weight
self._init_weights() self._init_weights()
@ -106,7 +99,7 @@ class Transformer(AutoModel):
state_dict = dict(state_dict) state_dict = dict(state_dict)
if self.config.tie_weight is True: if self.config.tie_weight:
# same tensor for embed and lm_head # same tensor for embed and lm_head
if embed_key in state_dict: if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key] state_dict[lm_head_key] = state_dict[embed_key]
@ -122,7 +115,7 @@ class Transformer(AutoModel):
destination=destination, prefix=prefix, keep_vars=keep_vars destination=destination, prefix=prefix, keep_vars=keep_vars
) )
if self.config.tie_weight is True: if self.config.tie_weight:
lm_head_key = prefix + "lm_head.weight" lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict: if lm_head_key in state_dict:
del state_dict[lm_head_key] del state_dict[lm_head_key]

View File

@ -1,5 +1,4 @@
import json import json
import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
@ -17,13 +16,11 @@ class Checkpoint:
epoch: int = 0, epoch: int = 0,
iteration: int = 0, iteration: int = 0,
extra: Optional[Dict[str, Any]] = None, extra: Optional[Dict[str, Any]] = None,
meta: Optional[Dict[str, Any]] = None,
): ):
self.state_dict = state_dict self.state_dict = state_dict
self.epoch = epoch self.epoch = epoch
self.iteration = iteration self.iteration = iteration
self.extra = extra or {} self.extra = extra or {}
self.meta = meta or {}
def save( def save(
self, self,
@ -38,16 +35,13 @@ class Checkpoint:
meta = { meta = {
"epoch": self.epoch, "epoch": self.epoch,
"iteration": self.iteration, "iteration": self.iteration,
"timestamp": time.time(),
} }
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2) json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors") st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra: if self.extra:
for key, value in self.extra.items(): torch.save(self.extra, save_path / "extra.pt")
torch.save(value, save_path / f"{key}.pt")
@classmethod @classmethod
def load( def load(
@ -70,14 +64,14 @@ class Checkpoint:
state_dict = st.load_file(save_path / "state_dict.safetensors") state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = {} extra = None
for f in save_path.iterdir(): extra_path = save_path / "extra.pt"
if f.suffix == ".pt" and f.stem not in ("meta",): if extra_path.exists():
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False) extra = torch.load(extra_path, map_location="cpu", weights_only=False)
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,
epoch=meta["epoch"], epoch=meta["epoch"],
iteration=meta["iteration"], iteration=meta["iteration"],
extra=extra or None, extra=extra,
) )

View File

@ -79,7 +79,8 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float): def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
def on_step_begin(self, context: TrainContext): def on_step_end(self, context: TrainContext):
_ = context
clip_grad_norm_(context.model.parameters(), self.max_grad_norm) clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@ -89,8 +90,6 @@ class CheckpointCallback(TrainCallback):
Checkpoint callback for trainer. Checkpoint callback for trainer.
""" """
extra_keys = ("optimizer", "scheduler")
def __init__( def __init__(
self, self,
save_dir: str, save_dir: str,
@ -98,14 +97,12 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.save_extra_fn = save_extra_fn
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0) @only_on_rank(0)
@ -119,22 +116,17 @@ class CheckpointCallback(TrainCallback):
else context.model.state_dict() else context.model.state_dict()
) )
extra = self.save_extra_fn(context) extra = self.save_extra_fn(context) if self.save_extra_fn else None
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
state_dict=state_dict, state_dict=state_dict,
epoch=context.epoch, epoch=context.epoch,
iteration=context.iteration, iteration=context.iteration,
extra=extra, extra=extra,
meta=context.config.to_dict(),
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval: if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context) self._save_checkpoint(context)
@ -146,21 +138,6 @@ class CheckpointCallback(TrainCallback):
def on_error(self, context: TrainContext): def on_error(self, context: TrainContext):
self._save_checkpoint(context) self._save_checkpoint(context)
@staticmethod
def save_extra(context: TrainContext) -> dict:
extra = {}
for name in CheckpointCallback.extra_keys:
obj = getattr(context, name, None)
if obj:
extra[name] = obj.state_dict()
return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar") @CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback): class ProgressBarCallback(TrainCallback):

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Self from typing import Callable, Optional, Self
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -21,7 +21,6 @@ class TrainContext:
optimizer: Optimizer = field(default=None) optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None) scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
iteration: int = field(default=0) iteration: int = field(default=0)
@ -36,9 +35,11 @@ class TrainContextBuilder:
def __init__( def __init__(
self, self,
config: TrainConfig, config: TrainConfig,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
): ):
self.config = config self.config = config
self._checkpoint: Optional[Checkpoint] = None self._checkpoint: Optional[Checkpoint] = None
self._load_extra_fn = load_extra_fn
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint self._checkpoint = checkpoint
@ -49,7 +50,6 @@ class TrainContextBuilder:
model=self.config.model, model=self.config.model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
config=self.config,
) )
device = get_current_device() device = get_current_device()
@ -71,8 +71,11 @@ class TrainContextBuilder:
context.optimizer = self.config.optimizer_fn(context.model) context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer) context.scheduler = self.config.scheduler_fn(context.optimizer)
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
self._load_extra_fn(self._checkpoint.extra, context)
cfg = self.config cfg = self.config
sampler_offset = context.iteration * cfg.batch_per_device sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler( sampler = ResumableDistributedSampler(
data_source=cfg.dataset, data_source=cfg.dataset,
start_epoch=context.epoch, start_epoch=context.epoch,
@ -81,7 +84,7 @@ class TrainContextBuilder:
) )
context.dataloader = DataLoader( context.dataloader = DataLoader(
cfg.dataset, cfg.dataset,
batch_size=cfg.batch_per_device, batch_size=cfg.batch_size,
sampler=sampler, sampler=sampler,
num_workers=cfg.num_workers, num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory, pin_memory=cfg.pin_memory,

View File

@ -1,4 +1,5 @@
import logging import logging
from itertools import batched
from typing import List, Optional from typing import List, Optional
from astrai.config import TrainConfig from astrai.config import TrainConfig
@ -32,6 +33,11 @@ class Trainer:
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
] ]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks: for callback in self.callbacks:
method = getattr(callback, method_name, None) method = getattr(callback, method_name, None)
@ -39,47 +45,49 @@ class Trainer:
method(context) method(context)
def train(self, checkpoint: Optional[Checkpoint] = None): def train(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config config = self.train_config
spawn_parallel_fn( spawn_parallel_fn(
self._train_impl, self._train_impl,
backend=cfg.backend, backend=config.backend,
world_size=cfg.nprocs, world_size=config.nprocs,
master_addr=cfg.master_addr, master_addr=config.master_addr,
master_port=cfg.master_port, master_port=config.master_port,
device_type=cfg.device_type, device_type=config.device_type,
checkpoint=checkpoint, checkpoint=checkpoint,
) )
def _train_impl(self, checkpoint: Optional[Checkpoint] = None): def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
cfg = self.train_config context = self._build_context(checkpoint)
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
self._call_callbacks("on_train_begin", context) self._call_callbacks("on_train_begin", context)
try: try:
context.model.train() context.model.train()
grad_accum_steps = cfg.grad_accum_steps accumulation_steps = max(self.train_config.accumulation_steps, 1)
for epoch in range(context.epoch, cfg.n_epoch): for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch context.epoch = epoch
self._call_callbacks("on_epoch_begin", context) self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader: for steps in batched(context.dataloader, accumulation_steps):
self._call_callbacks("on_batch_begin", context) self._call_callbacks("on_step_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / grad_accum_steps
stand_loss.backward()
context.iteration += 1
self._call_callbacks("on_batch_end", context)
if context.iteration % grad_accum_steps == 0: step_batch_nums = len(steps)
self._call_callbacks("on_step_begin", context) for batch in steps:
context.optimizer.step() self._call_callbacks("on_batch_begin", context)
context.optimizer.zero_grad() loss = context.strategy(batch)
self._call_callbacks("on_step_end", context) context.loss = loss.item()
context.iteration += 1
if context.scheduler: stand_loss = loss / step_batch_nums
context.scheduler.step() stand_loss.backward()
self._call_callbacks("on_batch_end", context)
self._call_callbacks("on_step_end", context)
context.optimizer.step()
context.optimizer.zero_grad()
if context.scheduler:
context.scheduler.step()
self._call_callbacks("on_epoch_end", context) self._call_callbacks("on_epoch_end", context)

View File

@ -11,6 +11,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text(): def generate_text():
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT) model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
@ -21,15 +22,16 @@ def generate_text():
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
for token in engine.generate( response = engine.generate(
prompt=query, prompt=query,
stream=True, stream=False,
max_tokens=2048, max_tokens=2048,
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
top_k=50, top_k=50,
): )
print(token, end="", flush=True)
print(response)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -42,20 +42,18 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train." "--n_epoch", type=int, default=1, help="Number of epochs to train."
) )
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
parser.add_argument( parser.add_argument(
"--batch_per_device", type=int, default=1, help="Batch size per GPU." "--accumulation_steps",
)
parser.add_argument(
"--grad_accum_steps",
type=int, type=int,
default=1, default=1,
help="Number of iterations between each optimizer step.", help="Number of iterations between each optimizer step.",
) )
parser.add_argument( parser.add_argument(
"--warmup_ratio", "--warmup_steps",
type=float, type=int,
default=0.05, default=1000,
help="Fraction of total steps used for LR warmup.", help="Number of warmup steps for LR scheduler.",
) )
parser.add_argument( parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training." "--max_lr", type=float, default=3e-4, help="Max learning rate for training."
@ -69,13 +67,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--adamw_beta1", "--adamw_beta1",
type=float, type=float,
default=0.95, default=0.9,
help="Beta values for AdamW optimizer.", help="Beta values for AdamW optimizer.",
) )
parser.add_argument( parser.add_argument(
"--adamw_beta2", "--adamw_beta2",
type=float, type=float,
default=0.99, default=0.95,
help="Beta values for AdamW optimizer.", help="Beta values for AdamW optimizer.",
) )
parser.add_argument( parser.add_argument(
@ -116,7 +114,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--label_smoothing", "--label_smoothing",
type=float, type=float,
default=0.05, default=0.1,
help="cross_entropy function label smoothing parameter", help="cross_entropy function label smoothing parameter",
) )
@ -183,34 +181,17 @@ def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict() return model.module.state_dict()
def compute_total_steps(
dataset_len: int,
n_epoch: int,
batch_per_device: int,
nprocs: int,
grad_accum_steps: int,
) -> int:
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
samples_per_replica = ceil_div(dataset_len, nprocs)
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
return total_steps
def train( def train(
train_type: str, train_type: str,
param_path: str, param_path: str,
data_root_path: str, data_root_path: str,
max_lr: float, max_lr: float,
n_epoch: int, n_epoch: int,
batch_per_device: int, batch_size: int,
start_epoch: int, start_epoch: int,
start_batch: int, start_batch: int,
grad_accum_steps: int, accumulation_steps: int,
warmup_ratio: float, warmup_steps: int,
ckpt_interval: int, ckpt_interval: int,
ckpt_dir: str, ckpt_dir: str,
dpo_beta: float, dpo_beta: float,
@ -235,8 +216,10 @@ def train(
assert os.path.exists(param_path) assert os.path.exists(param_path)
# Load config # Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json") config_path = os.path.join(param_path, "config.json")
config = ModelConfig.from_file(config_path) if os.path.exists(config_path):
config.load(config_path)
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len
@ -277,17 +260,13 @@ def train(
}, },
) )
total_steps = compute_total_steps( total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
)
warmup_steps = int(warmup_ratio * total_steps)
scheduler_fn = partial( scheduler_fn = partial(
create_scheduler, create_scheduler,
**{ **{
"schedule_type": "cosine", "schedule_type": "cosine",
"warmup_steps": min(warmup_steps, total_steps), "warmup_steps": warmup_steps,
"lr_decay_steps": total_steps - min(warmup_steps, total_steps), "lr_decay_steps": total_steps - warmup_steps,
}, },
) )
@ -299,11 +278,11 @@ def train(
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
ckpt_dir=ckpt_dir, ckpt_dir=ckpt_dir,
n_epoch=n_epoch, n_epoch=n_epoch,
batch_per_device=batch_per_device, batch_size=batch_size,
start_epoch=start_epoch, start_epoch=start_epoch,
start_batch=start_batch, start_batch=start_batch,
ckpt_interval=ckpt_interval, ckpt_interval=ckpt_interval,
grad_accum_steps=grad_accum_steps, accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,
num_workers=num_workers, num_workers=num_workers,

View File

@ -107,12 +107,12 @@ def test_model():
"""Session-scoped small Transformer model, created once.""" """Session-scoped small Transformer model, created once."""
config = ModelConfig( config = ModelConfig(
vocab_size=1000, vocab_size=1000,
dim=8, dim=16,
n_heads=2, n_heads=4,
n_kv_heads=1, n_kv_heads=2,
dim_ffn=16, dim_ffn=32,
max_len=64, max_len=1024,
n_layers=2, n_layers=4,
norm_eps=1e-5, norm_eps=1e-5,
) )
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
json.dump( json.dump(
{ {
"vocab_size": 1000, "vocab_size": 1000,
"dim": 8, "dim": 16,
"n_heads": 2, "n_heads": 4,
"n_kv_heads": 1, "n_kv_heads": 2,
"dim_ffn": 16, "dim_ffn": 32,
"max_len": 64, "max_len": 1024,
"n_layers": 2, "n_layers": 4,
"norm_eps": 1e-5, "norm_eps": 1e-5,
}, },
f, f,

View File

@ -35,33 +35,6 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30 assert loaded_checkpoint.iteration == 30
def test_checkpoint_with_extra():
"""Verify extra keys are saved as individual .pt files and loaded back."""
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer.step()
extra = {
"optimizer": optimizer.state_dict(),
"scheduler": {"last_epoch": 5},
}
checkpoint = Checkpoint(
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
import os
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
loaded = Checkpoint.load(tmpdir)
assert loaded.extra["scheduler"]["last_epoch"] == 5
assert "state" in loaded.extra["optimizer"]
def simple_training(): def simple_training():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)

View File

@ -10,7 +10,7 @@ from astrai.dataset.storage import (
BaseSegmentFetcher, BaseSegmentFetcher,
H5Storage, H5Storage,
MultiSegmentFetcher, MultiSegmentFetcher,
StorageFactory, create_storage,
detect_format, detect_format,
load_json, load_json,
save_h5, save_h5,
@ -368,9 +368,9 @@ def test_detect_format_unsupported_file(base_test_env):
def test_create_storage_invalid_type(): def test_create_storage_invalid_type():
"""StorageFactory.create raises ValueError for unknown type""" """create_storage raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"): with pytest.raises(ValueError, match="Unknown storage type"):
StorageFactory.create("parquet") create_storage("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env): def test_json_pretokenized_without_tokenizer(base_test_env):

View File

@ -1,108 +0,0 @@
import pytest
import torch
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
TINY_CONFIG = dict(
vocab_size=128,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
CONFIGS = [
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
id="gqa_mlp",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "mla",
"ffn_type": "mlp",
"kv_lora_rank": 4,
"qk_nope_head_dim": 2,
"qk_rope_head_dim": 2,
},
id="mla_mlp",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "gqa",
"ffn_type": "moe",
"n_routed_experts": 4,
"n_shared_experts": 1,
"n_activated_experts": 2,
"moe_topk_method": "greedy",
},
id="gqa_moe",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "gqa",
"ffn_type": "mlp",
"rope_theta": 100000.0,
},
id="gqa_rope_theta",
),
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
id="gqa_qk_norm",
),
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
id="gqa_tie_weight",
),
]
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
config = ModelConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert "logits" in output
assert "hidden_states" in output
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
assert not torch.isnan(output["logits"]).any()
assert not torch.isnan(output["hidden_states"]).any()
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
config = ModelConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
input_mask[:, 4:] = False
with torch.no_grad():
output = model(input_ids, input_mask=input_mask)
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
assert not torch.isnan(output["logits"]).any()

View File

@ -17,10 +17,10 @@ def transformer_test_env():
config = { config = {
"vocab_size": 1000, "vocab_size": 1000,
"dim": 8, "dim": 128,
"n_heads": 2, "n_heads": 4,
"n_kv_heads": 1, "n_kv_heads": 2,
"dim_ffn": 16, "dim_ffn": 256,
"max_len": 64, "max_len": 64,
"n_layers": 2, "n_layers": 2,
"norm_eps": 1e-5, "norm_eps": 1e-5,
@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = ModelConfig().load(config_path)
model = Transformer(config) model = Transformer(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = ModelConfig().load(config_path)
model = Transformer(config) model = Transformer(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = ModelConfig().load(config_path)
original_model = Transformer(config) original_model = Transformer(config)
st.save_file(original_model.state_dict(), model_path) st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig.from_file(config_path) loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config) model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))
@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
loaded_config = ModelConfig.from_file(config_path) loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config) model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))

View File

@ -31,8 +31,8 @@ def create_train_config(
device: str, device: str,
strategy: str = "seq", strategy: str = "seq",
n_epoch: int = 1, n_epoch: int = 1,
batch_per_device: int = 2, batch_size: int = 2,
grad_accum_steps: int = 1, accumulation_steps: int = 1,
max_grad_norm: float = 1.0, max_grad_norm: float = 1.0,
ckpt_interval: int = 5, ckpt_interval: int = 5,
random_seed: int = 42, random_seed: int = 42,
@ -47,8 +47,8 @@ def create_train_config(
device: Device type ("cuda" or "cpu") device: Device type ("cuda" or "cpu")
strategy: Training strategy type (default: "seq") strategy: Training strategy type (default: "seq")
n_epoch: Number of epochs (default: 1) n_epoch: Number of epochs (default: 1)
batch_per_device: Batch size per device (default: 2) batch_size: Batch size (default: 2)
grad_accum_steps: Gradient accumulation steps (default: 1) accumulation_steps: Gradient accumulation steps (default: 1)
max_grad_norm: Maximum gradient norm for clipping (default: 1.0) max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
ckpt_interval: Checkpoint save interval in iterations (default: 5) ckpt_interval: Checkpoint save interval in iterations (default: 5)
random_seed: Random seed for reproducibility (default: 42) random_seed: Random seed for reproducibility (default: 42)
@ -74,9 +74,9 @@ def create_train_config(
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
ckpt_dir=test_dir, ckpt_dir=test_dir,
n_epoch=n_epoch, n_epoch=n_epoch,
batch_per_device=batch_per_device, batch_size=batch_size,
ckpt_interval=ckpt_interval, ckpt_interval=ckpt_interval,
grad_accum_steps=grad_accum_steps, accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,
device_type=device, device_type=device,

View File

@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset):
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
ckpt_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_per_device=2, batch_size=2,
ckpt_interval=3, ckpt_interval=3,
grad_accum_steps=1, accumulation_steps=1,
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=42, random_seed=42,
device_type=base_test_env["device"], device_type=base_test_env["device"],

View File

@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
dataset=early_stopping_dataset, dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=2, n_epoch=2,
batch_per_device=2, batch_size=2,
ckpt_interval=1, ckpt_interval=1,
grad_accum_steps=2, accumulation_steps=2,
random_seed=np.random.randint(1e4), random_seed=np.random.randint(1e4),
device_type=base_test_env["device"], device_type=base_test_env["device"],
) )

View File

@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
"""Test training with different batch sizes""" """Test training with different batch sizes"""
batch_sizes = [1, 2, 4, 8] batch_sizes = [1, 2, 4, 8]
for batch_per_device in batch_sizes: for batch_size in batch_sizes:
train_config = train_config_factory( train_config = train_config_factory(
model=base_test_env["model"], model=base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
batch_per_device=batch_per_device, batch_size=batch_size,
) )
assert train_config.batch_per_device == batch_per_device assert train_config.batch_size == batch_size
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory): def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
"""Test training with different gradient accumulation steps""" """Test training with different gradient accumulation steps"""
grad_accum_steps_list = [1, 2, 4] accumulation_steps_list = [1, 2, 4]
for grad_accum_steps in grad_accum_steps_list: for accumulation_steps in accumulation_steps_list:
train_config = train_config_factory( train_config = train_config_factory(
model=base_test_env["model"], model=base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
batch_per_device=2, batch_size=2,
grad_accum_steps=grad_accum_steps, accumulation_steps=accumulation_steps,
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
trainer.train() trainer.train()
assert train_config.grad_accum_steps == grad_accum_steps assert train_config.accumulation_steps == accumulation_steps
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory): def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
"""Test training with memory-efficient configurations""" """Test training with memory-efficient configurations"""
# Test with smaller batch sizes and gradient checkpointing # Test with smaller batch sizes and gradient checkpointing
small_batch_configs = [ small_batch_configs = [
{"batch_per_device": 1, "grad_accum_steps": 8}, {"batch_size": 1, "accumulation_steps": 8},
{"batch_per_device": 2, "grad_accum_steps": 4}, {"batch_size": 2, "accumulation_steps": 4},
{"batch_per_device": 4, "grad_accum_steps": 2}, {"batch_size": 4, "accumulation_steps": 2},
] ]
for config in small_batch_configs: for config in small_batch_configs:
@ -54,9 +54,8 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
batch_per_device=config["batch_per_device"], batch_size=config["batch_size"],
grad_accum_steps=config["grad_accum_steps"], accumulation_steps=config["accumulation_steps"],
) )
assert train_config.grad_accum_steps == config["grad_accum_steps"] assert train_config.accumulation_steps == config["accumulation_steps"]
assert train_config.batch_per_device == config["batch_per_device"]