Compare commits

..

No commits in common. "ad9f4d9cf60f35cf742509b8096c7b541252c5be" and "3d12a03909c6dedc6de112a4f53e3ecd1d1a2068" have entirely different histories.

35 changed files with 419 additions and 628 deletions

View File

@ -78,27 +78,15 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
#### Train a Model
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
```
Full reference at [Parameter Guide](assets/docs/params.md).

View File

@ -84,27 +84,15 @@ python scripts/demo/download.py
#### 训练模型
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
```
完整参数列表见[参数说明](./params.md)。

View File

@ -5,15 +5,10 @@
```mermaid
classDiagram
namespace config {
class BaseConfig {
+to_dict() Dict
+from_dict(d) Self
}
class BaseModelConfig {
+Optional[str] model_type
+from_file(config_path) Self
+to_file(config_path)
+load(config_path) Self
+save(config_path)
}
class ModelConfig {
@ -35,9 +30,6 @@ classDiagram
+int n_shared_experts
+int n_activated_experts
+str moe_topk_method
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+load(config_path) ModelConfig
+save(config_path)
}
@ -49,8 +41,8 @@ classDiagram
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_per_device
+int grad_accum_steps
+int batch_size
+int accumulation_steps
+float max_grad_norm
+int start_epoch
+int start_batch
@ -77,7 +69,7 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+Optional[BaseStorage] storage
+BaseStorage storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
@ -134,8 +126,8 @@ classDiagram
}
class ResumableDistributedSampler {
+int epoch
+int iter
+int start_epoch
+int start_iter
}
class DatasetFactory {
@ -152,7 +144,6 @@ classDiagram
+int epoch
+int iteration
+dict extra
+dict meta
+save(save_dir)
+load(save_dir) Checkpoint
}
@ -164,7 +155,7 @@ classDiagram
+Registry _registry
+register(model_type) decorator
+get_component_class(model_type) Type
+from_pretrained(path, disable_random_init, strict) nn.Module
+from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
}
@ -176,7 +167,7 @@ classDiagram
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
+load_state_dict(state_dict)
+state_dict()
}
@ -194,7 +185,6 @@ classDiagram
+int n_kv_heads
+int head_dim
+int n_rep
+int layer_id
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, k_proj, v_proj, o_proj
@ -211,7 +201,6 @@ classDiagram
+int qk_nope_head_dim
+int qk_rope_head_dim
+int n_rep
+int layer_id
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
@ -226,7 +215,6 @@ classDiagram
}
class DeepSeekMoE {
+int dim
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
@ -248,7 +236,6 @@ classDiagram
class RMSNorm {
+Parameter weight
+float norm_eps
+tuple normalized_shape
+forward(x) Tensor
}
@ -312,6 +299,7 @@ classDiagram
+TrainConfig train_config
+List[TrainCallback] callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List[TrainCallback]
}
@ -336,7 +324,7 @@ classDiagram
}
class BaseStrategy {
+Union[Callable, nn.Module] model
+nn.Module model
+str device
+compute_loss(batch) Tensor
}
@ -344,7 +332,7 @@ classDiagram
class StrategyFactory {
+Registry _registry
+register(name) decorator
+create(train_type, model, device, **kwargs) BaseStrategy
+create(model, train_type, device, **kwargs) BaseStrategy
}
class SEQStrategy {
@ -412,7 +400,7 @@ classDiagram
class GradientClippingCallback {
+float max_grad_norm
+on_step_begin(context)
+on_step_end(context)
}
class CheckpointCallback {
@ -471,7 +459,10 @@ classDiagram
+TaskManager _task_mgr
+bool _running
+Thread _loop_thread
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
@ -509,7 +500,10 @@ classDiagram
}
class Storage {
+int n_layers
+int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache
+Tensor v_cache
+write(layer_id, page_table, start_pos, k, v)
@ -681,6 +675,7 @@ classDiagram
}
class AnthropicHandler {
+List[str] stop_sequences
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
@ -709,7 +704,7 @@ classDiagram
namespace parallel {
class Functions {
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
@ -756,8 +751,6 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- Transformer
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseModelConfig <|-- ModelConfig
BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
@ -846,7 +839,7 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
@ -885,4 +878,4 @@ classDiagram
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
> Document Update Time: 2026-05-16
> Document Update Time: 2026-05-15

View File

@ -10,14 +10,14 @@
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_per_device` | Batch size per device | 1 |
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
| `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
@ -25,8 +25,8 @@
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 | 0.95 |
| `--adamw_beta2` | AdamW beta2 | 0.99 |
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading
@ -60,7 +60,7 @@
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` |
| `--group_size` | GRPO group size | 4 | `grpo` |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
@ -69,29 +69,90 @@
### Usage Example
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
```
---
> Document Update Time: 2026-05-16
## Generation Parameters
### GenerationRequest Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `messages` | List of message dictionaries (role, content) | required |
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
| `stream` | Whether to stream output | False |
### Usage Example
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
)
# Build request with messages format
request = GenerationRequest(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_tokens=None,
)
# Generate (streaming)
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
# Or use simple generate interface
result = engine.generate(
prompt="Hello",
stream=False,
max_tokens=1024,
temperature=0.8,
top_p=0.95,
top_k=50,
)
```
### Generation Modes
| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-05-15

View File

@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi
## Training Loop
Two-level loop: **epoch****batch**. Optimizer step fires every `grad_accum_steps` batches.
Nested loop: **epoch****step** (accumulation window) → **batch**.
```
on_train_begin
on_epoch_begin
for batch in dataloader:
on_batch_begin
loss = strategy(batch)
(loss / grad_accum_steps).backward()
iteration += 1
on_batch_end
if iteration % grad_accum_steps == 0:
on_step_begin
optimizer.step()
optimizer.zero_grad()
on_step_end
scheduler.step()
for steps in batched(dataloader, accumulation_steps):
on_step_begin
step_batch_nums = len(steps)
for batch in steps:
on_batch_begin
loss = strategy(batch)
(loss / step_batch_nums).backward()
iteration += 1
on_batch_end
on_step_end
optimizer.step()
optimizer.zero_grad()
scheduler.step()
on_epoch_end
on_train_end
```
@ -91,9 +91,9 @@ on_train_end
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
| `on_step_end` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
| `on_train_end` | Training ends | `CheckpointCallback` (final save) |
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
@ -157,13 +157,12 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Checkpoint
```
Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
Checkpoint(state_dict, epoch, iteration, extra)
├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt
└── load(save_dir) broadcasts metadata from rank-0
```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data.
## TrainContextBuilder (Builder Pattern)
@ -184,29 +183,17 @@ context = (
## Training CLI
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/data \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
```
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-16
> Document Update Time: 2026-05-15

View File

@ -1,77 +0,0 @@
import json
from dataclasses import MISSING, dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass
class BaseConfig:
def to_dict(self) -> Dict[str, Any]:
d = {}
for fld in fields(self):
v = getattr(self, fld.name)
if isinstance(v, (str, int, float, bool)):
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, dict):
try:
json.dumps(v)
d[fld.name] = v
except (TypeError, ValueError):
pass
return d
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
hints = get_type_hints(cls)
inst = cls.__new__(cls)
for fld in fields(cls):
if fld.name in d:
v = d[fld.name]
target = cls._unwrap_optional(hints.get(fld.name))
if target is not None:
try:
v = cls._coerce(v, target)
except (TypeError, ValueError):
pass
object.__setattr__(inst, fld.name, v)
elif fld.default is not MISSING:
object.__setattr__(inst, fld.name, fld.default)
elif fld.default_factory is not MISSING:
object.__setattr__(inst, fld.name, fld.default_factory())
else:
object.__setattr__(inst, fld.name, None)
return inst
@staticmethod
def _unwrap_optional(tp) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError

View File

@ -1,14 +1,12 @@
import json
import warnings
import sys
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass
class BaseModelConfig(BaseConfig):
"""Field-aware JSON from/to file for dataclass configs.
class BaseModelConfig:
"""Field-aware JSON load/save for dataclass configs.
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
@ -16,25 +14,76 @@ class BaseModelConfig(BaseConfig):
model_type: Optional[str] = None
@classmethod
def from_file(cls, config_path: str) -> Self:
def load(self, config_path: str) -> Self:
raw: Dict[str, Any] = {}
with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f)
raw.update(json.load(f))
valid = {fld.name for fld in fields(cls)}
for key in list(raw):
hints = get_type_hints(type(self))
valid = {fld.name for fld in fields(self)}
for key, value in raw.items():
if key not in valid:
warnings.warn(f"Unknown config key '{key}'")
del raw[key]
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
continue
return cls.from_dict(raw)
target_type = self._unwrap_optional(hints.get(key))
if target_type is None:
continue
def to_file(self, config_path: str):
d = self.to_dict()
config_dict = {k: v for k, v in d.items() if v is not None}
try:
value = self._coerce(value, target_type)
except (TypeError, ValueError):
sys.stderr.write(
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
)
continue
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict: Dict[str, Any] = {}
for fld in fields(self):
v = getattr(self, fld.name)
if v is not None:
config_dict[fld.name] = v
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass
class ModelConfig(BaseModelConfig):
@ -57,11 +106,6 @@ class ModelConfig(BaseModelConfig):
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
# MLA
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
# MoE
ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None

View File

@ -6,11 +6,9 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
@dataclass
class TrainConfig(BaseConfig):
class TrainConfig:
# basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."})
@ -22,10 +20,8 @@ class TrainConfig(BaseConfig):
default=None, metadata={"help": "Scheduler factory for training."}
)
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_per_device: int = field(
default=4, metadata={"help": "Batch size per device."}
)
grad_accum_steps: int = field(
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
accumulation_steps: int = field(
default=1, metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(

View File

@ -9,7 +9,8 @@ from astrai.dataset.storage import (
H5Storage,
JSONStorage,
MultiSegmentFetcher,
StorageFactory,
available_storage_types,
create_storage,
detect_format,
load_h5,
load_json,
@ -25,8 +26,9 @@ __all__ = [
"BaseStorage",
"H5Storage",
"JSONStorage",
"StorageFactory",
"create_storage",
"detect_format",
"available_storage_types",
"save_h5",
"load_h5",
"save_json",

View File

@ -9,7 +9,7 @@ from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
StorageFactory,
create_storage,
detect_format,
)
from astrai.factory import BaseFactory
@ -42,7 +42,7 @@ class BaseDataset(Dataset, ABC):
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = StorageFactory.create(storage_type)
self.storage = create_storage(storage_type)
self.storage.load(load_path, tokenizer=tokenizer)
def load_json(self, load_path: str, tokenizer=None):

View File

@ -15,8 +15,6 @@ import h5py
import torch
from torch import Tensor
from astrai.factory import BaseFactory
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
@ -260,24 +258,6 @@ class BaseStorage(ABC):
return self._fetcher.multi_keys
class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating storage backends by type name.
Example:
@StorageFactory.register("custom")
class CustomStorage(BaseStorage):
...
storage = StorageFactory.create("custom")
"""
@classmethod
def _validate_component(cls, storage_cls: type) -> None:
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@StorageFactory.register("h5")
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
@ -286,7 +266,6 @@ class H5Storage(BaseStorage):
self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("json")
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
@ -299,3 +278,35 @@ class JSONStorage(BaseStorage):
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
_STORAGE_REGISTRY: Dict[str, type] = {
"h5": H5Storage,
"json": JSONStorage,
}
def create_storage(storage_type: str) -> BaseStorage:
"""Create a storage instance by type name.
Args:
storage_type: Storage type name ("h5", "json")
Returns:
Storage instance
Raises:
ValueError: If the storage type is unknown
"""
storage_cls = _STORAGE_REGISTRY.get(storage_type)
if storage_cls is None:
raise ValueError(
f"Unknown storage type: '{storage_type}'. "
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
)
return storage_cls()
def available_storage_types() -> List[str]:
"""Return list of registered storage type names."""
return sorted(_STORAGE_REGISTRY.keys())

View File

@ -1,6 +1,5 @@
"""Base factory class for extensible component registration."""
import inspect
from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
@ -123,10 +122,6 @@ class BaseFactory(ABC, Generic[T]):
def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args:
name: Registered name of the component
*args: Positional arguments passed to component constructor
@ -144,17 +139,6 @@ class BaseFactory(ABC, Generic[T]):
f"Supported types: {sorted(cls._registry.list_names())}"
)
component_cls = cls._registry.get(name)
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_kwargs:
valid = {
p.name
for p in sig.parameters.values()
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
}
kwargs = {k: v for k, v in kwargs.items() if k in valid}
return component_cls(*args, **kwargs)
@classmethod

View File

@ -163,5 +163,4 @@ def run_server(
app,
host=host,
port=port,
reload=reload,
)

View File

@ -22,22 +22,14 @@ class InferenceScheduler:
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048,
max_prompt_len: int = 512,
page_size: int = 64,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
config = model.config
if max_seq_len is not None:
self.max_seq_len = max_seq_len
elif config.max_len is not None:
self.max_seq_len = config.max_len
else:
raise ValueError(
"max_seq_len must be provided either as argument "
"or in model config (config.max_len)"
)
self.max_seq_len = max_seq_len or config.max_len
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype

View File

@ -60,9 +60,10 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
model_path = Path(path)
# Load config
config = ModelConfig()
config_path = model_path / "config.json"
if config_path.exists():
config = ModelConfig.from_file(str(config_path))
config.load(str(config_path))
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
@ -88,7 +89,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
save_path.mkdir(parents=True, exist_ok=True)
# Save config
self.config.to_file(str(save_path / "config.json"))
self.config.save(str(save_path / "config.json"))
# Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))

View File

@ -40,6 +40,7 @@ class GQA(nn.Module):
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
**kwargs,
):
super().__init__()
assert dim % n_heads == 0
@ -122,6 +123,7 @@ class MLA(nn.Module):
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
**kwargs,
):
super().__init__()
self.dim = dim
@ -141,7 +143,7 @@ class MLA(nn.Module):
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (2 * self.head_dim),
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
)
self.o_proj = Linear(dim, dim, bias=False)
@ -174,7 +176,7 @@ class MLA(nn.Module):
q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_nope_head_dim :],
q[..., self.qk_rope_head_dim :],
)
q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb)

View File

@ -16,13 +16,13 @@ class DecoderBlock(nn.Module):
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: float,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
attn_type: str = "gqa",
ffn_type: str = "mlp",
**kwargs,
**moe_kwargs,
):
super().__init__()
self.attention = AttnFactory.create(
@ -34,11 +34,10 @@ class DecoderBlock(nn.Module):
norm_eps=norm_eps,
use_gated_attention=use_gated_attention,
layer_id=layer_id,
**kwargs,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
def forward(
self,

View File

@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
@FFNFactory.register("mlp")
class MLP(nn.Module):
def __init__(self, dim: int, dim_ffn: int):
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
super().__init__()
self.up = Linear(dim, dim_ffn)
self.gate = Linear(dim, dim_ffn)
self.down = Linear(dim_ffn, dim)
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
@ -32,11 +32,12 @@ class DeepSeekMoE(nn.Module):
def __init__(
self,
dim: int,
dim_ffn: int,
dim_feed_forward: int,
n_routed_experts: int,
n_shared_experts: int = 1,
n_activated_experts: int = 2,
topk_method: str = "greedy",
**kwargs,
):
super().__init__()
self.dim = dim
@ -48,10 +49,10 @@ class DeepSeekMoE(nn.Module):
self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
)
self.routed_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
)
def forward(self, x: Tensor) -> Tensor:

View File

@ -30,7 +30,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: float = 10000):
def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len

View File

@ -53,13 +53,9 @@ class Transformer(AutoModel):
def __init__(self, config: ModelConfig):
super().__init__(config)
self.config = config
rope_dim = (
config.qk_rope_head_dim
if config.attn_type == "mla"
else config.dim // config.n_heads
self.rotary_embedding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
)
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
@ -79,9 +75,6 @@ class Transformer(AutoModel):
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.moe_topk_method,
kv_lora_rank=config.kv_lora_rank,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
)
for layer_id in range(config.n_layers)
]
@ -90,7 +83,7 @@ class Transformer(AutoModel):
self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight is True:
if self.config.tie_weight:
self.lm_head.weight = self.embed_tokens.weight
self._init_weights()
@ -106,7 +99,7 @@ class Transformer(AutoModel):
state_dict = dict(state_dict)
if self.config.tie_weight is True:
if self.config.tie_weight:
# same tensor for embed and lm_head
if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key]
@ -122,7 +115,7 @@ class Transformer(AutoModel):
destination=destination, prefix=prefix, keep_vars=keep_vars
)
if self.config.tie_weight is True:
if self.config.tie_weight:
lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict:
del state_dict[lm_head_key]

View File

@ -1,5 +1,4 @@
import json
import time
from pathlib import Path
from typing import Any, Dict, Optional
@ -17,13 +16,11 @@ class Checkpoint:
epoch: int = 0,
iteration: int = 0,
extra: Optional[Dict[str, Any]] = None,
meta: Optional[Dict[str, Any]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.extra = extra or {}
self.meta = meta or {}
def save(
self,
@ -38,16 +35,13 @@ class Checkpoint:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.time(),
}
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra:
for key, value in self.extra.items():
torch.save(value, save_path / f"{key}.pt")
torch.save(self.extra, save_path / "extra.pt")
@classmethod
def load(
@ -70,14 +64,14 @@ class Checkpoint:
state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = {}
for f in save_path.iterdir():
if f.suffix == ".pt" and f.stem not in ("meta",):
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
extra = None
extra_path = save_path / "extra.pt"
if extra_path.exists():
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
return cls(
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
extra=extra or None,
extra=extra,
)

View File

@ -79,7 +79,8 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_step_begin(self, context: TrainContext):
def on_step_end(self, context: TrainContext):
_ = context
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@ -89,8 +90,6 @@ class CheckpointCallback(TrainCallback):
Checkpoint callback for trainer.
"""
extra_keys = ("optimizer", "scheduler")
def __init__(
self,
save_dir: str,
@ -98,14 +97,12 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.save_extra_fn = save_extra_fn
self.last_ckpt_iter = 0
@only_on_rank(0)
@ -119,22 +116,17 @@ class CheckpointCallback(TrainCallback):
else context.model.state_dict()
)
extra = self.save_extra_fn(context)
extra = self.save_extra_fn(context) if self.save_extra_fn else None
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
meta=context.config.to_dict(),
)
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context)
@ -146,21 +138,6 @@ class CheckpointCallback(TrainCallback):
def on_error(self, context: TrainContext):
self._save_checkpoint(context)
@staticmethod
def save_extra(context: TrainContext) -> dict:
extra = {}
for name in CheckpointCallback.extra_keys:
obj = getattr(context, name, None)
if obj:
extra[name] = obj.state_dict()
return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Self
from typing import Callable, Optional, Self
import torch.nn as nn
from torch.optim import Optimizer
@ -21,7 +21,6 @@ class TrainContext:
optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
epoch: int = field(default=0)
iteration: int = field(default=0)
@ -36,9 +35,11 @@ class TrainContextBuilder:
def __init__(
self,
config: TrainConfig,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.config = config
self._checkpoint: Optional[Checkpoint] = None
self._load_extra_fn = load_extra_fn
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
@ -49,7 +50,6 @@ class TrainContextBuilder:
model=self.config.model,
world_size=get_world_size(),
rank=get_rank(),
config=self.config,
)
device = get_current_device()
@ -71,8 +71,11 @@ class TrainContextBuilder:
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
self._load_extra_fn(self._checkpoint.extra, context)
cfg = self.config
sampler_offset = context.iteration * cfg.batch_per_device
sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
start_epoch=context.epoch,
@ -81,7 +84,7 @@ class TrainContextBuilder:
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_per_device,
batch_size=cfg.batch_size,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,

View File

@ -1,4 +1,5 @@
import logging
from itertools import batched
from typing import List, Optional
from astrai.config import TrainConfig
@ -32,6 +33,11 @@ class Trainer:
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks:
method = getattr(callback, method_name, None)
@ -39,47 +45,49 @@ class Trainer:
method(context)
def train(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
config = self.train_config
spawn_parallel_fn(
self._train_impl,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
backend=config.backend,
world_size=config.nprocs,
master_addr=config.master_addr,
master_port=config.master_port,
device_type=config.device_type,
checkpoint=checkpoint,
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint)
self._call_callbacks("on_train_begin", context)
try:
context.model.train()
grad_accum_steps = cfg.grad_accum_steps
accumulation_steps = max(self.train_config.accumulation_steps, 1)
for epoch in range(context.epoch, cfg.n_epoch):
for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch
self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / grad_accum_steps
stand_loss.backward()
context.iteration += 1
self._call_callbacks("on_batch_end", context)
for steps in batched(context.dataloader, accumulation_steps):
self._call_callbacks("on_step_begin", context)
if context.iteration % grad_accum_steps == 0:
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
step_batch_nums = len(steps)
for batch in steps:
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
context.iteration += 1
if context.scheduler:
context.scheduler.step()
stand_loss = loss / step_batch_nums
stand_loss.backward()
self._call_callbacks("on_batch_end", context)
self._call_callbacks("on_step_end", context)
context.optimizer.step()
context.optimizer.zero_grad()
if context.scheduler:
context.scheduler.step()
self._call_callbacks("on_epoch_end", context)

View File

@ -11,6 +11,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text():
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
@ -21,15 +22,16 @@ def generate_text():
model=model,
tokenizer=tokenizer,
)
for token in engine.generate(
response = engine.generate(
prompt=query,
stream=True,
stream=False,
max_tokens=2048,
temperature=0.8,
top_p=0.95,
top_k=50,
):
print(token, end="", flush=True)
)
print(response)
if __name__ == "__main__":

View File

@ -42,20 +42,18 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train."
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
parser.add_argument(
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
)
parser.add_argument(
"--grad_accum_steps",
"--accumulation_steps",
type=int,
default=1,
help="Number of iterations between each optimizer step.",
)
parser.add_argument(
"--warmup_ratio",
type=float,
default=0.05,
help="Fraction of total steps used for LR warmup.",
"--warmup_steps",
type=int,
default=1000,
help="Number of warmup steps for LR scheduler.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
@ -69,13 +67,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--adamw_beta1",
type=float,
default=0.95,
default=0.9,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_beta2",
type=float,
default=0.99,
default=0.95,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
@ -116,7 +114,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--label_smoothing",
type=float,
default=0.05,
default=0.1,
help="cross_entropy function label smoothing parameter",
)
@ -183,34 +181,17 @@ def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict()
def compute_total_steps(
dataset_len: int,
n_epoch: int,
batch_per_device: int,
nprocs: int,
grad_accum_steps: int,
) -> int:
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
samples_per_replica = ceil_div(dataset_len, nprocs)
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
return total_steps
def train(
train_type: str,
param_path: str,
data_root_path: str,
max_lr: float,
n_epoch: int,
batch_per_device: int,
batch_size: int,
start_epoch: int,
start_batch: int,
grad_accum_steps: int,
warmup_ratio: float,
accumulation_steps: int,
warmup_steps: int,
ckpt_interval: int,
ckpt_dir: str,
dpo_beta: float,
@ -235,8 +216,10 @@ def train(
assert os.path.exists(param_path)
# Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json")
config = ModelConfig.from_file(config_path)
if os.path.exists(config_path):
config.load(config_path)
if window_size is None:
window_size = config.max_len
@ -277,17 +260,13 @@ def train(
},
)
total_steps = compute_total_steps(
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
)
warmup_steps = int(warmup_ratio * total_steps)
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
scheduler_fn = partial(
create_scheduler,
**{
"schedule_type": "cosine",
"warmup_steps": min(warmup_steps, total_steps),
"lr_decay_steps": total_steps - min(warmup_steps, total_steps),
"warmup_steps": warmup_steps,
"lr_decay_steps": total_steps - warmup_steps,
},
)
@ -299,11 +278,11 @@ def train(
scheduler_fn=scheduler_fn,
ckpt_dir=ckpt_dir,
n_epoch=n_epoch,
batch_per_device=batch_per_device,
batch_size=batch_size,
start_epoch=start_epoch,
start_batch=start_batch,
ckpt_interval=ckpt_interval,
grad_accum_steps=grad_accum_steps,
accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
num_workers=num_workers,

View File

@ -107,12 +107,12 @@ def test_model():
"""Session-scoped small Transformer model, created once."""
config = ModelConfig(
vocab_size=1000,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
dim=16,
n_heads=4,
n_kv_heads=2,
dim_ffn=32,
max_len=1024,
n_layers=4,
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
json.dump(
{
"vocab_size": 1000,
"dim": 8,
"n_heads": 2,
"n_kv_heads": 1,
"dim_ffn": 16,
"max_len": 64,
"n_layers": 2,
"dim": 16,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 32,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5,
},
f,

View File

@ -35,33 +35,6 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30
def test_checkpoint_with_extra():
"""Verify extra keys are saved as individual .pt files and loaded back."""
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer.step()
extra = {
"optimizer": optimizer.state_dict(),
"scheduler": {"last_epoch": 5},
}
checkpoint = Checkpoint(
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
import os
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
loaded = Checkpoint.load(tmpdir)
assert loaded.extra["scheduler"]["last_epoch"] == 5
assert "state" in loaded.extra["optimizer"]
def simple_training():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)

View File

@ -10,7 +10,7 @@ from astrai.dataset.storage import (
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
StorageFactory,
create_storage,
detect_format,
load_json,
save_h5,
@ -368,9 +368,9 @@ def test_detect_format_unsupported_file(base_test_env):
def test_create_storage_invalid_type():
"""StorageFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StorageFactory.create("parquet")
"""create_storage raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown storage type"):
create_storage("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):

View File

@ -1,108 +0,0 @@
import pytest
import torch
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
TINY_CONFIG = dict(
vocab_size=128,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
CONFIGS = [
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
id="gqa_mlp",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "mla",
"ffn_type": "mlp",
"kv_lora_rank": 4,
"qk_nope_head_dim": 2,
"qk_rope_head_dim": 2,
},
id="mla_mlp",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "gqa",
"ffn_type": "moe",
"n_routed_experts": 4,
"n_shared_experts": 1,
"n_activated_experts": 2,
"moe_topk_method": "greedy",
},
id="gqa_moe",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "gqa",
"ffn_type": "mlp",
"rope_theta": 100000.0,
},
id="gqa_rope_theta",
),
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
id="gqa_qk_norm",
),
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
id="gqa_tie_weight",
),
]
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
config = ModelConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert "logits" in output
assert "hidden_states" in output
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
assert not torch.isnan(output["logits"]).any()
assert not torch.isnan(output["hidden_states"]).any()
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
config = ModelConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
input_mask[:, 4:] = False
with torch.no_grad():
output = model(input_ids, input_mask=input_mask)
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
assert not torch.isnan(output["logits"]).any()

View File

@ -17,10 +17,10 @@ def transformer_test_env():
config = {
"vocab_size": 1000,
"dim": 8,
"n_heads": 2,
"n_kv_heads": 1,
"dim_ffn": 16,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig.from_file(config_path)
config = ModelConfig().load(config_path)
model = Transformer(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig.from_file(config_path)
config = ModelConfig().load(config_path)
model = Transformer(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig.from_file(config_path)
config = ModelConfig().load(config_path)
original_model = Transformer(config)
st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig.from_file(config_path)
loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))
@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
loaded_config = ModelConfig.from_file(config_path)
loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))

View File

@ -31,8 +31,8 @@ def create_train_config(
device: str,
strategy: str = "seq",
n_epoch: int = 1,
batch_per_device: int = 2,
grad_accum_steps: int = 1,
batch_size: int = 2,
accumulation_steps: int = 1,
max_grad_norm: float = 1.0,
ckpt_interval: int = 5,
random_seed: int = 42,
@ -47,8 +47,8 @@ def create_train_config(
device: Device type ("cuda" or "cpu")
strategy: Training strategy type (default: "seq")
n_epoch: Number of epochs (default: 1)
batch_per_device: Batch size per device (default: 2)
grad_accum_steps: Gradient accumulation steps (default: 1)
batch_size: Batch size (default: 2)
accumulation_steps: Gradient accumulation steps (default: 1)
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
ckpt_interval: Checkpoint save interval in iterations (default: 5)
random_seed: Random seed for reproducibility (default: 42)
@ -74,9 +74,9 @@ def create_train_config(
scheduler_fn=scheduler_fn,
ckpt_dir=test_dir,
n_epoch=n_epoch,
batch_per_device=batch_per_device,
batch_size=batch_size,
ckpt_interval=ckpt_interval,
grad_accum_steps=grad_accum_steps,
accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
device_type=device,

View File

@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset):
scheduler_fn=scheduler_fn,
ckpt_dir=base_test_env["test_dir"],
n_epoch=1,
batch_per_device=2,
batch_size=2,
ckpt_interval=3,
grad_accum_steps=1,
accumulation_steps=1,
max_grad_norm=1.0,
random_seed=42,
device_type=base_test_env["device"],

View File

@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"],
n_epoch=2,
batch_per_device=2,
batch_size=2,
ckpt_interval=1,
grad_accum_steps=2,
accumulation_steps=2,
random_seed=np.random.randint(1e4),
device_type=base_test_env["device"],
)

View File

@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
"""Test training with different batch sizes"""
batch_sizes = [1, 2, 4, 8]
for batch_per_device in batch_sizes:
for batch_size in batch_sizes:
train_config = train_config_factory(
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_per_device=batch_per_device,
batch_size=batch_size,
)
assert train_config.batch_per_device == batch_per_device
assert train_config.batch_size == batch_size
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
"""Test training with different gradient accumulation steps"""
grad_accum_steps_list = [1, 2, 4]
accumulation_steps_list = [1, 2, 4]
for grad_accum_steps in grad_accum_steps_list:
for accumulation_steps in accumulation_steps_list:
train_config = train_config_factory(
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_per_device=2,
grad_accum_steps=grad_accum_steps,
batch_size=2,
accumulation_steps=accumulation_steps,
)
trainer = Trainer(train_config)
trainer.train()
assert train_config.grad_accum_steps == grad_accum_steps
assert train_config.accumulation_steps == accumulation_steps
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
"""Test training with memory-efficient configurations"""
# Test with smaller batch sizes and gradient checkpointing
small_batch_configs = [
{"batch_per_device": 1, "grad_accum_steps": 8},
{"batch_per_device": 2, "grad_accum_steps": 4},
{"batch_per_device": 4, "grad_accum_steps": 2},
{"batch_size": 1, "accumulation_steps": 8},
{"batch_size": 2, "accumulation_steps": 4},
{"batch_size": 4, "accumulation_steps": 2},
]
for config in small_batch_configs:
@ -54,9 +54,8 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_per_device=config["batch_per_device"],
grad_accum_steps=config["grad_accum_steps"],
batch_size=config["batch_size"],
accumulation_steps=config["accumulation_steps"],
)
assert train_config.grad_accum_steps == config["grad_accum_steps"]
assert train_config.batch_per_device == config["batch_per_device"]
assert train_config.accumulation_steps == config["accumulation_steps"]