Compare commits

..

10 Commits

Author SHA1 Message Date
ViperEkura ad9f4d9cf6 refactor: generate_ar 改用流式输出并去除冗余注释 2026-05-17 10:23:42 +08:00
ViperEkura e1638a7ade fix: 修正AdamW超参数默认值与文档示例
- 交换adamw_beta1/adamw_beta2默认值:beta1=0.95, beta2=0.99
- label_smoothing默认值改为0.05
- 文档示例统一更新:train_type=pt, weight_decay=0.01
- 移除文档中过时的strategy default标注
2026-05-16 22:46:17 +08:00
ViperEkura f91bfee33e refactor: Config序列化统一BaseConfig基类
- 新增astrai/config/base.py,提供to_dict/from_dict基类
- 统一命名:load/save → from_file/to_file
- Checkpoint.meta合并训练配置到meta.json
- sys.stderr.warn → warnings.warn
- from_file改为classmethod
2026-05-16 22:06:39 +08:00
ViperEkura d7a7f570ed refactor: 训练循环改为两重迭代并统一参数命名
- 训练循环从三重(epoch→batched→batch)改为二重(epoch→batch)
- batch_size → batch_per_device, accumulation_steps → grad_accum_steps
- scheduler 移入 step block 对齐 optimizer 更新步
- GradientClippingCallback 改用 on_step_begin 避免零梯度裁剪
- 移除 _train_impl 误导性的 -> Checkpoint 标注
- total_steps 修除为向下取整并精简为一行
- warmup_steps 改为 warmup_ratio (默认0.05)
2026-05-16 21:27:35 +08:00
ViperEkura 7dea929788 refactor: checkpoint 按 HF 方式存独立 .pt 文件,callback 接管恢复
- Checkpoint.save/load: extra 逐 key 写为 {key}.pt 而非单个 extra.pt
- meta.json 新增 timestamp
- CheckpointCallback: save_extra/load_extra 静态方法 + extra_keys 类属性
- on_train_begin 接管 optimizer/scheduler 恢复,TrainContextBuilder 不再传 load_extra_fn
2026-05-16 18:29:04 +08:00
ViperEkura 026d1fc33d fix: total_steps 改用 ceiling 匹配实际步数
原公式全用 floor 少算 optimizer step,改用逐层 ceiling
(ceil_div via (a+b-1)//b)对齐 DDP sampler padding +
DataLoader drop_last=False 尾批 + batched 尾组截断。
2026-05-16 17:53:18 +08:00
ViperEkura 7242eedbf4 fix: 学习率调度按 optimizer step 计数并防止 warmup 越界
- total_steps 除以 accumulation_steps,匹配 optimizer.step() 频率
- warmup_steps 用 min 截断,避免 lr_decay_steps 为负
2026-05-16 17:07:36 +08:00
ViperEkura 04c0dc7a47 refactor: Storage 改用工厂模式,server reload 接入 uvicorn
- 新增 StorageFactory(BaseFactory[BaseStorage]) 替代手写 dict 注册
- H5Storage / JSONStorage 通过 @StorageFactory.register 注册
- dataset.py 使用 StorageFactory.create() 替代 create_storage()
- 删除 create_storage / available_storage_types 死函数
- server.py reload 参数正式传入 uvicorn.run()
2026-05-16 17:00:26 +08:00
ViperEkura 48a53121ba refactor: 工厂 kwargs 过滤及组件参数清理
- BaseFactory.create() 按 __init__ 签名过滤多余 kwargs
- 移除 GQA/MLA/MLP/DeepSeekMoE 中多余的 **kwargs
- MLP/DeepSeekMoE 参数名统一为 dim_ffn
- scheduler max_seq_len 增加 None 显式判断
- 默认 max_prompt_len 提升至 2048
2026-05-16 16:47:41 +08:00
ViperEkura 0ba8c70ce1 fix: 修复 MLA 多个 bug 并缩小测试模型参数
- MLA kv_b_proj 输出维度和 q_rope 切分偏移修复
- 打通 MLA 配置从 ModelConfig 到 DecoderBlock 的传递路径
- rope_theta 配置不再被忽略,MLA 使用 qk_rope_head_dim
- tie_weight 使用 is True 避免 None 隐式生效
- norm_eps/rope base 类型标注修正
- 测试模型参数缩小 (dim=8, head_dim=4)
- 新增 6 种架构配置 × 2 场景的前向传播测试
2026-05-16 14:57:43 +08:00
35 changed files with 628 additions and 419 deletions

View File

@ -78,15 +78,27 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
#### Train a Model
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
Full reference at [Parameter Guide](assets/docs/params.md).

View File

@ -84,15 +84,27 @@ python scripts/demo/download.py
#### 训练模型
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
完整参数列表见[参数说明](./params.md)。

View File

@ -5,10 +5,15 @@
```mermaid
classDiagram
namespace config {
class BaseConfig {
+to_dict() Dict
+from_dict(d) Self
}
class BaseModelConfig {
+Optional[str] model_type
+load(config_path) Self
+save(config_path)
+from_file(config_path) Self
+to_file(config_path)
}
class ModelConfig {
@ -30,6 +35,9 @@ classDiagram
+int n_shared_experts
+int n_activated_experts
+str moe_topk_method
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+load(config_path) ModelConfig
+save(config_path)
}
@ -41,8 +49,8 @@ classDiagram
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_size
+int accumulation_steps
+int batch_per_device
+int grad_accum_steps
+float max_grad_norm
+int start_epoch
+int start_batch
@ -69,7 +77,7 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+BaseStorage storage
+Optional[BaseStorage] storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
@ -126,8 +134,8 @@ classDiagram
}
class ResumableDistributedSampler {
+int start_epoch
+int start_iter
+int epoch
+int iter
}
class DatasetFactory {
@ -144,6 +152,7 @@ classDiagram
+int epoch
+int iteration
+dict extra
+dict meta
+save(save_dir)
+load(save_dir) Checkpoint
}
@ -155,7 +164,7 @@ classDiagram
+Registry _registry
+register(model_type) decorator
+get_component_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module
+from_pretrained(path, disable_random_init, strict) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
}
@ -167,7 +176,7 @@ classDiagram
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+load_state_dict(state_dict)
+state_dict()
}
@ -185,6 +194,7 @@ classDiagram
+int n_kv_heads
+int head_dim
+int n_rep
+int layer_id
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, k_proj, v_proj, o_proj
@ -201,6 +211,7 @@ classDiagram
+int qk_nope_head_dim
+int qk_rope_head_dim
+int n_rep
+int layer_id
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
@ -215,6 +226,7 @@ classDiagram
}
class DeepSeekMoE {
+int dim
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
@ -236,6 +248,7 @@ classDiagram
class RMSNorm {
+Parameter weight
+float norm_eps
+tuple normalized_shape
+forward(x) Tensor
}
@ -299,7 +312,6 @@ classDiagram
+TrainConfig train_config
+List[TrainCallback] callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List[TrainCallback]
}
@ -324,7 +336,7 @@ classDiagram
}
class BaseStrategy {
+nn.Module model
+Union[Callable, nn.Module] model
+str device
+compute_loss(batch) Tensor
}
@ -332,7 +344,7 @@ classDiagram
class StrategyFactory {
+Registry _registry
+register(name) decorator
+create(model, train_type, device, **kwargs) BaseStrategy
+create(train_type, model, device, **kwargs) BaseStrategy
}
class SEQStrategy {
@ -400,7 +412,7 @@ classDiagram
class GradientClippingCallback {
+float max_grad_norm
+on_step_end(context)
+on_step_begin(context)
}
class CheckpointCallback {
@ -459,10 +471,7 @@ classDiagram
+TaskManager _task_mgr
+bool _running
+Thread _loop_thread
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
@ -500,10 +509,7 @@ classDiagram
}
class Storage {
+int n_layers
+int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache
+Tensor v_cache
+write(layer_id, page_table, start_pos, k, v)
@ -675,7 +681,6 @@ classDiagram
}
class AnthropicHandler {
+List[str] stop_sequences
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
@ -704,7 +709,7 @@ classDiagram
namespace parallel {
class Functions {
+spawn_parallel_fn(fn, nprocs)
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
@ -751,6 +756,8 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- Transformer
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseModelConfig <|-- ModelConfig
BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
@ -839,7 +846,7 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
@ -878,4 +885,4 @@ classDiagram
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
> Document Update Time: 2026-05-15
> Document Update Time: 2026-05-16

View File

@ -10,14 +10,14 @@
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
| `--batch_per_device` | Batch size per device | 1 |
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--warmup_steps` | Warmup steps | 1000 |
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
@ -25,8 +25,8 @@
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_beta1` | AdamW beta1 | 0.95 |
| `--adamw_beta2` | AdamW beta2 | 0.99 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading
@ -60,7 +60,7 @@
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
| `--group_size` | GRPO group size | 4 | `grpo` |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
@ -69,90 +69,29 @@
### Usage Example
```bash
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
---
## Generation Parameters
### GenerationRequest Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `messages` | List of message dictionaries (role, content) | required |
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
| `stream` | Whether to stream output | False |
### Usage Example
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
)
# Build request with messages format
request = GenerationRequest(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_tokens=None,
)
# Generate (streaming)
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
# Or use simple generate interface
result = engine.generate(
prompt="Hello",
stream=False,
max_tokens=1024,
temperature=0.8,
top_p=0.95,
top_k=50,
)
```
### Generation Modes
| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-05-15
> Document Update Time: 2026-05-16

View File

@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi
## Training Loop
Nested loop: **epoch****step** (accumulation window) → **batch**.
Two-level loop: **epoch****batch**. Optimizer step fires every `grad_accum_steps` batches.
```
on_train_begin
on_epoch_begin
for steps in batched(dataloader, accumulation_steps):
on_step_begin
step_batch_nums = len(steps)
for batch in steps:
on_batch_begin
loss = strategy(batch)
(loss / step_batch_nums).backward()
iteration += 1
on_batch_end
on_step_end
optimizer.step()
optimizer.zero_grad()
scheduler.step()
for batch in dataloader:
on_batch_begin
loss = strategy(batch)
(loss / grad_accum_steps).backward()
iteration += 1
on_batch_end
if iteration % grad_accum_steps == 0:
on_step_begin
optimizer.step()
optimizer.zero_grad()
on_step_end
scheduler.step()
on_epoch_end
on_train_end
```
@ -91,9 +91,9 @@ on_train_end
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_step_end` | Every accumulation window | `GradientClippingCallback` |
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_train_end` | Training ends | `CheckpointCallback` (final save) |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
@ -157,12 +157,13 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Checkpoint
```
Checkpoint(state_dict, epoch, iteration, extra)
├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt
Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
└── load(save_dir) broadcasts metadata from rank-0
```
Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data.
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern)
@ -183,17 +184,29 @@ context = (
## Training CLI
```bash
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/data \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-15
> Document Update Time: 2026-05-16

77
astrai/config/base.py Normal file
View File

@ -0,0 +1,77 @@
import json
from dataclasses import MISSING, dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass
class BaseConfig:
def to_dict(self) -> Dict[str, Any]:
d = {}
for fld in fields(self):
v = getattr(self, fld.name)
if isinstance(v, (str, int, float, bool)):
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, dict):
try:
json.dumps(v)
d[fld.name] = v
except (TypeError, ValueError):
pass
return d
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
hints = get_type_hints(cls)
inst = cls.__new__(cls)
for fld in fields(cls):
if fld.name in d:
v = d[fld.name]
target = cls._unwrap_optional(hints.get(fld.name))
if target is not None:
try:
v = cls._coerce(v, target)
except (TypeError, ValueError):
pass
object.__setattr__(inst, fld.name, v)
elif fld.default is not MISSING:
object.__setattr__(inst, fld.name, fld.default)
elif fld.default_factory is not MISSING:
object.__setattr__(inst, fld.name, fld.default_factory())
else:
object.__setattr__(inst, fld.name, None)
return inst
@staticmethod
def _unwrap_optional(tp) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError

View File

@ -1,12 +1,14 @@
import json
import sys
import warnings
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig
@dataclass
class BaseModelConfig:
"""Field-aware JSON load/save for dataclass configs.
class BaseModelConfig(BaseConfig):
"""Field-aware JSON from/to file for dataclass configs.
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
@ -14,76 +16,25 @@ class BaseModelConfig:
model_type: Optional[str] = None
def load(self, config_path: str) -> Self:
raw: Dict[str, Any] = {}
@classmethod
def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f:
raw.update(json.load(f))
raw: Dict[str, Any] = json.load(f)
hints = get_type_hints(type(self))
valid = {fld.name for fld in fields(self)}
for key, value in raw.items():
valid = {fld.name for fld in fields(cls)}
for key in list(raw):
if key not in valid:
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
continue
warnings.warn(f"Unknown config key '{key}'")
del raw[key]
target_type = self._unwrap_optional(hints.get(key))
if target_type is None:
continue
return cls.from_dict(raw)
try:
value = self._coerce(value, target_type)
except (TypeError, ValueError):
sys.stderr.write(
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
)
continue
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict: Dict[str, Any] = {}
for fld in fields(self):
v = getattr(self, fld.name)
if v is not None:
config_dict[fld.name] = v
def to_file(self, config_path: str):
d = self.to_dict()
config_dict = {k: v for k, v in d.items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass
class ModelConfig(BaseModelConfig):
@ -106,6 +57,11 @@ class ModelConfig(BaseModelConfig):
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
# MLA
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
# MoE
ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None

View File

@ -6,9 +6,11 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
@dataclass
class TrainConfig:
class TrainConfig(BaseConfig):
# basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."})
@ -20,8 +22,10 @@ class TrainConfig:
default=None, metadata={"help": "Scheduler factory for training."}
)
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
accumulation_steps: int = field(
batch_per_device: int = field(
default=4, metadata={"help": "Batch size per device."}
)
grad_accum_steps: int = field(
default=1, metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(

View File

@ -9,8 +9,7 @@ from astrai.dataset.storage import (
H5Storage,
JSONStorage,
MultiSegmentFetcher,
available_storage_types,
create_storage,
StorageFactory,
detect_format,
load_h5,
load_json,
@ -26,9 +25,8 @@ __all__ = [
"BaseStorage",
"H5Storage",
"JSONStorage",
"create_storage",
"StorageFactory",
"detect_format",
"available_storage_types",
"save_h5",
"load_h5",
"save_json",

View File

@ -9,7 +9,7 @@ from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
create_storage,
StorageFactory,
detect_format,
)
from astrai.factory import BaseFactory
@ -42,7 +42,7 @@ class BaseDataset(Dataset, ABC):
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = create_storage(storage_type)
self.storage = StorageFactory.create(storage_type)
self.storage.load(load_path, tokenizer=tokenizer)
def load_json(self, load_path: str, tokenizer=None):

View File

@ -15,6 +15,8 @@ import h5py
import torch
from torch import Tensor
from astrai.factory import BaseFactory
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
@ -258,6 +260,24 @@ class BaseStorage(ABC):
return self._fetcher.multi_keys
class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating storage backends by type name.
Example:
@StorageFactory.register("custom")
class CustomStorage(BaseStorage):
...
storage = StorageFactory.create("custom")
"""
@classmethod
def _validate_component(cls, storage_cls: type) -> None:
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@StorageFactory.register("h5")
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
@ -266,6 +286,7 @@ class H5Storage(BaseStorage):
self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("json")
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
@ -278,35 +299,3 @@ class JSONStorage(BaseStorage):
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
_STORAGE_REGISTRY: Dict[str, type] = {
"h5": H5Storage,
"json": JSONStorage,
}
def create_storage(storage_type: str) -> BaseStorage:
"""Create a storage instance by type name.
Args:
storage_type: Storage type name ("h5", "json")
Returns:
Storage instance
Raises:
ValueError: If the storage type is unknown
"""
storage_cls = _STORAGE_REGISTRY.get(storage_type)
if storage_cls is None:
raise ValueError(
f"Unknown storage type: '{storage_type}'. "
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
)
return storage_cls()
def available_storage_types() -> List[str]:
"""Return list of registered storage type names."""
return sorted(_STORAGE_REGISTRY.keys())

View File

@ -1,5 +1,6 @@
"""Base factory class for extensible component registration."""
import inspect
from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
@ -122,6 +123,10 @@ class BaseFactory(ABC, Generic[T]):
def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args:
name: Registered name of the component
*args: Positional arguments passed to component constructor
@ -139,6 +144,17 @@ class BaseFactory(ABC, Generic[T]):
f"Supported types: {sorted(cls._registry.list_names())}"
)
component_cls = cls._registry.get(name)
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_kwargs:
valid = {
p.name
for p in sig.parameters.values()
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
}
kwargs = {k: v for k, v in kwargs.items() if k in valid}
return component_cls(*args, **kwargs)
@classmethod

View File

@ -163,4 +163,5 @@ def run_server(
app,
host=host,
port=port,
reload=reload,
)

View File

@ -22,14 +22,22 @@ class InferenceScheduler:
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 512,
max_prompt_len: int = 2048,
page_size: int = 64,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
config = model.config
self.max_seq_len = max_seq_len or config.max_len
if max_seq_len is not None:
self.max_seq_len = max_seq_len
elif config.max_len is not None:
self.max_seq_len = config.max_len
else:
raise ValueError(
"max_seq_len must be provided either as argument "
"or in model config (config.max_len)"
)
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype

View File

@ -60,10 +60,9 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
model_path = Path(path)
# Load config
config = ModelConfig()
config_path = model_path / "config.json"
if config_path.exists():
config.load(str(config_path))
config = ModelConfig.from_file(str(config_path))
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
@ -89,7 +88,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
save_path.mkdir(parents=True, exist_ok=True)
# Save config
self.config.save(str(save_path / "config.json"))
self.config.to_file(str(save_path / "config.json"))
# Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))

View File

@ -40,7 +40,6 @@ class GQA(nn.Module):
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
**kwargs,
):
super().__init__()
assert dim % n_heads == 0
@ -123,7 +122,6 @@ class MLA(nn.Module):
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
**kwargs,
):
super().__init__()
self.dim = dim
@ -143,7 +141,7 @@ class MLA(nn.Module):
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
n_kv_heads * (2 * self.head_dim),
)
self.o_proj = Linear(dim, dim, bias=False)
@ -176,7 +174,7 @@ class MLA(nn.Module):
q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_rope_head_dim :],
q[..., self.qk_nope_head_dim :],
)
q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb)

View File

@ -16,13 +16,13 @@ class DecoderBlock(nn.Module):
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
norm_eps: float,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
attn_type: str = "gqa",
ffn_type: str = "mlp",
**moe_kwargs,
**kwargs,
):
super().__init__()
self.attention = AttnFactory.create(
@ -34,10 +34,11 @@ class DecoderBlock(nn.Module):
norm_eps=norm_eps,
use_gated_attention=use_gated_attention,
layer_id=layer_id,
**kwargs,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
def forward(
self,

View File

@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
@FFNFactory.register("mlp")
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
self.up = Linear(dim, dim_ffn)
self.gate = Linear(dim, dim_ffn)
self.down = Linear(dim_ffn, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
@ -32,12 +32,11 @@ class DeepSeekMoE(nn.Module):
def __init__(
self,
dim: int,
dim_feed_forward: int,
dim_ffn: int,
n_routed_experts: int,
n_shared_experts: int = 1,
n_activated_experts: int = 2,
topk_method: str = "greedy",
**kwargs,
):
super().__init__()
self.dim = dim
@ -49,10 +48,10 @@ class DeepSeekMoE(nn.Module):
self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
)
self.routed_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
)
def forward(self, x: Tensor) -> Tensor:

View File

@ -30,7 +30,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int = 10000):
def __init__(self, dim: int, max_len: int, base: float = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len

View File

@ -53,9 +53,13 @@ class Transformer(AutoModel):
def __init__(self, config: ModelConfig):
super().__init__(config)
self.config = config
self.rotary_embedding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
rope_dim = (
config.qk_rope_head_dim
if config.attn_type == "mla"
else config.dim // config.n_heads
)
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
@ -75,6 +79,9 @@ class Transformer(AutoModel):
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.moe_topk_method,
kv_lora_rank=config.kv_lora_rank,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
)
for layer_id in range(config.n_layers)
]
@ -83,7 +90,7 @@ class Transformer(AutoModel):
self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight:
if self.config.tie_weight is True:
self.lm_head.weight = self.embed_tokens.weight
self._init_weights()
@ -99,7 +106,7 @@ class Transformer(AutoModel):
state_dict = dict(state_dict)
if self.config.tie_weight:
if self.config.tie_weight is True:
# same tensor for embed and lm_head
if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key]
@ -115,7 +122,7 @@ class Transformer(AutoModel):
destination=destination, prefix=prefix, keep_vars=keep_vars
)
if self.config.tie_weight:
if self.config.tie_weight is True:
lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict:
del state_dict[lm_head_key]

View File

@ -1,4 +1,5 @@
import json
import time
from pathlib import Path
from typing import Any, Dict, Optional
@ -16,11 +17,13 @@ class Checkpoint:
epoch: int = 0,
iteration: int = 0,
extra: Optional[Dict[str, Any]] = None,
meta: Optional[Dict[str, Any]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.extra = extra or {}
self.meta = meta or {}
def save(
self,
@ -35,13 +38,16 @@ class Checkpoint:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.time(),
}
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra:
torch.save(self.extra, save_path / "extra.pt")
for key, value in self.extra.items():
torch.save(value, save_path / f"{key}.pt")
@classmethod
def load(
@ -64,14 +70,14 @@ class Checkpoint:
state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = None
extra_path = save_path / "extra.pt"
if extra_path.exists():
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
extra = {}
for f in save_path.iterdir():
if f.suffix == ".pt" and f.stem not in ("meta",):
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
return cls(
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
extra=extra,
extra=extra or None,
)

View File

@ -79,8 +79,7 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_step_end(self, context: TrainContext):
_ = context
def on_step_begin(self, context: TrainContext):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@ -90,6 +89,8 @@ class CheckpointCallback(TrainCallback):
Checkpoint callback for trainer.
"""
extra_keys = ("optimizer", "scheduler")
def __init__(
self,
save_dir: str,
@ -97,12 +98,14 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0
@only_on_rank(0)
@ -116,17 +119,22 @@ class CheckpointCallback(TrainCallback):
else context.model.state_dict()
)
extra = self.save_extra_fn(context) if self.save_extra_fn else None
extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
meta=context.config.to_dict(),
)
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context)
@ -138,6 +146,21 @@ class CheckpointCallback(TrainCallback):
def on_error(self, context: TrainContext):
self._save_checkpoint(context)
@staticmethod
def save_extra(context: TrainContext) -> dict:
extra = {}
for name in CheckpointCallback.extra_keys:
obj = getattr(context, name, None)
if obj:
extra[name] = obj.state_dict()
return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Callable, Optional, Self
from typing import Optional, Self
import torch.nn as nn
from torch.optim import Optimizer
@ -21,6 +21,7 @@ class TrainContext:
optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
epoch: int = field(default=0)
iteration: int = field(default=0)
@ -35,11 +36,9 @@ class TrainContextBuilder:
def __init__(
self,
config: TrainConfig,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.config = config
self._checkpoint: Optional[Checkpoint] = None
self._load_extra_fn = load_extra_fn
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
@ -50,6 +49,7 @@ class TrainContextBuilder:
model=self.config.model,
world_size=get_world_size(),
rank=get_rank(),
config=self.config,
)
device = get_current_device()
@ -71,11 +71,8 @@ class TrainContextBuilder:
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
self._load_extra_fn(self._checkpoint.extra, context)
cfg = self.config
sampler_offset = context.iteration * cfg.batch_size
sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
start_epoch=context.epoch,
@ -84,7 +81,7 @@ class TrainContextBuilder:
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
batch_size=cfg.batch_per_device,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,

View File

@ -1,5 +1,4 @@
import logging
from itertools import batched
from typing import List, Optional
from astrai.config import TrainConfig
@ -33,11 +32,6 @@ class Trainer:
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks:
method = getattr(callback, method_name, None)
@ -45,49 +39,47 @@ class Trainer:
method(context)
def train(self, checkpoint: Optional[Checkpoint] = None):
config = self.train_config
cfg = self.train_config
spawn_parallel_fn(
self._train_impl,
backend=config.backend,
world_size=config.nprocs,
master_addr=config.master_addr,
master_port=config.master_port,
device_type=config.device_type,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
checkpoint=checkpoint,
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
self._call_callbacks("on_train_begin", context)
try:
context.model.train()
accumulation_steps = max(self.train_config.accumulation_steps, 1)
grad_accum_steps = cfg.grad_accum_steps
for epoch in range(context.epoch, self.train_config.n_epoch):
for epoch in range(context.epoch, cfg.n_epoch):
context.epoch = epoch
self._call_callbacks("on_epoch_begin", context)
for steps in batched(context.dataloader, accumulation_steps):
self._call_callbacks("on_step_begin", context)
for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / grad_accum_steps
stand_loss.backward()
context.iteration += 1
self._call_callbacks("on_batch_end", context)
step_batch_nums = len(steps)
for batch in steps:
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
context.iteration += 1
if context.iteration % grad_accum_steps == 0:
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
stand_loss = loss / step_batch_nums
stand_loss.backward()
self._call_callbacks("on_batch_end", context)
self._call_callbacks("on_step_end", context)
context.optimizer.step()
context.optimizer.zero_grad()
if context.scheduler:
context.scheduler.step()
if context.scheduler:
context.scheduler.step()
self._call_callbacks("on_epoch_end", context)

View File

@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text():
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
@ -22,16 +21,15 @@ def generate_text():
model=model,
tokenizer=tokenizer,
)
response = engine.generate(
for token in engine.generate(
prompt=query,
stream=False,
stream=True,
max_tokens=2048,
temperature=0.8,
top_p=0.95,
top_k=50,
)
print(response)
):
print(token, end="", flush=True)
if __name__ == "__main__":

View File

@ -42,18 +42,20 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train."
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
parser.add_argument(
"--accumulation_steps",
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
)
parser.add_argument(
"--grad_accum_steps",
type=int,
default=1,
help="Number of iterations between each optimizer step.",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=1000,
help="Number of warmup steps for LR scheduler.",
"--warmup_ratio",
type=float,
default=0.05,
help="Fraction of total steps used for LR warmup.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
@ -67,13 +69,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--adamw_beta1",
type=float,
default=0.9,
default=0.95,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_beta2",
type=float,
default=0.95,
default=0.99,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
@ -114,7 +116,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--label_smoothing",
type=float,
default=0.1,
default=0.05,
help="cross_entropy function label smoothing parameter",
)
@ -181,17 +183,34 @@ def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict()
def compute_total_steps(
dataset_len: int,
n_epoch: int,
batch_per_device: int,
nprocs: int,
grad_accum_steps: int,
) -> int:
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
samples_per_replica = ceil_div(dataset_len, nprocs)
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
return total_steps
def train(
train_type: str,
param_path: str,
data_root_path: str,
max_lr: float,
n_epoch: int,
batch_size: int,
batch_per_device: int,
start_epoch: int,
start_batch: int,
accumulation_steps: int,
warmup_steps: int,
grad_accum_steps: int,
warmup_ratio: float,
ckpt_interval: int,
ckpt_dir: str,
dpo_beta: float,
@ -216,10 +235,8 @@ def train(
assert os.path.exists(param_path)
# Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json")
if os.path.exists(config_path):
config.load(config_path)
config = ModelConfig.from_file(config_path)
if window_size is None:
window_size = config.max_len
@ -260,13 +277,17 @@ def train(
},
)
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
total_steps = compute_total_steps(
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
)
warmup_steps = int(warmup_ratio * total_steps)
scheduler_fn = partial(
create_scheduler,
**{
"schedule_type": "cosine",
"warmup_steps": warmup_steps,
"lr_decay_steps": total_steps - warmup_steps,
"warmup_steps": min(warmup_steps, total_steps),
"lr_decay_steps": total_steps - min(warmup_steps, total_steps),
},
)
@ -278,11 +299,11 @@ def train(
scheduler_fn=scheduler_fn,
ckpt_dir=ckpt_dir,
n_epoch=n_epoch,
batch_size=batch_size,
batch_per_device=batch_per_device,
start_epoch=start_epoch,
start_batch=start_batch,
ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps,
grad_accum_steps=grad_accum_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
num_workers=num_workers,

View File

@ -107,12 +107,12 @@ def test_model():
"""Session-scoped small Transformer model, created once."""
config = ModelConfig(
vocab_size=1000,
dim=16,
n_heads=4,
n_kv_heads=2,
dim_ffn=32,
max_len=1024,
n_layers=4,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
json.dump(
{
"vocab_size": 1000,
"dim": 16,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 32,
"max_len": 1024,
"n_layers": 4,
"dim": 8,
"n_heads": 2,
"n_kv_heads": 1,
"dim_ffn": 16,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
},
f,

View File

@ -35,6 +35,33 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30
def test_checkpoint_with_extra():
"""Verify extra keys are saved as individual .pt files and loaded back."""
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer.step()
extra = {
"optimizer": optimizer.state_dict(),
"scheduler": {"last_epoch": 5},
}
checkpoint = Checkpoint(
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
import os
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
loaded = Checkpoint.load(tmpdir)
assert loaded.extra["scheduler"]["last_epoch"] == 5
assert "state" in loaded.extra["optimizer"]
def simple_training():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)

View File

@ -10,7 +10,7 @@ from astrai.dataset.storage import (
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
create_storage,
StorageFactory,
detect_format,
load_json,
save_h5,
@ -368,9 +368,9 @@ def test_detect_format_unsupported_file(base_test_env):
def test_create_storage_invalid_type():
"""create_storage raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown storage type"):
create_storage("parquet")
"""StorageFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StorageFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):

View File

@ -0,0 +1,108 @@
import pytest
import torch
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
TINY_CONFIG = dict(
vocab_size=128,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
CONFIGS = [
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
id="gqa_mlp",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "mla",
"ffn_type": "mlp",
"kv_lora_rank": 4,
"qk_nope_head_dim": 2,
"qk_rope_head_dim": 2,
},
id="mla_mlp",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "gqa",
"ffn_type": "moe",
"n_routed_experts": 4,
"n_shared_experts": 1,
"n_activated_experts": 2,
"moe_topk_method": "greedy",
},
id="gqa_moe",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "gqa",
"ffn_type": "mlp",
"rope_theta": 100000.0,
},
id="gqa_rope_theta",
),
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
id="gqa_qk_norm",
),
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
id="gqa_tie_weight",
),
]
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
config = ModelConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert "logits" in output
assert "hidden_states" in output
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
assert not torch.isnan(output["logits"]).any()
assert not torch.isnan(output["hidden_states"]).any()
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
config = ModelConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
input_mask[:, 4:] = False
with torch.no_grad():
output = model(input_ids, input_mask=input_mask)
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
assert not torch.isnan(output["logits"]).any()

View File

@ -17,10 +17,10 @@ def transformer_test_env():
config = {
"vocab_size": 1000,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"dim": 8,
"n_heads": 2,
"n_kv_heads": 1,
"dim_ffn": 16,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
config = ModelConfig.from_file(config_path)
model = Transformer(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
config = ModelConfig.from_file(config_path)
model = Transformer(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
config = ModelConfig.from_file(config_path)
original_model = Transformer(config)
st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig().load(config_path)
loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))
@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
loaded_config = ModelConfig().load(config_path)
loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))

View File

@ -31,8 +31,8 @@ def create_train_config(
device: str,
strategy: str = "seq",
n_epoch: int = 1,
batch_size: int = 2,
accumulation_steps: int = 1,
batch_per_device: int = 2,
grad_accum_steps: int = 1,
max_grad_norm: float = 1.0,
ckpt_interval: int = 5,
random_seed: int = 42,
@ -47,8 +47,8 @@ def create_train_config(
device: Device type ("cuda" or "cpu")
strategy: Training strategy type (default: "seq")
n_epoch: Number of epochs (default: 1)
batch_size: Batch size (default: 2)
accumulation_steps: Gradient accumulation steps (default: 1)
batch_per_device: Batch size per device (default: 2)
grad_accum_steps: Gradient accumulation steps (default: 1)
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
ckpt_interval: Checkpoint save interval in iterations (default: 5)
random_seed: Random seed for reproducibility (default: 42)
@ -74,9 +74,9 @@ def create_train_config(
scheduler_fn=scheduler_fn,
ckpt_dir=test_dir,
n_epoch=n_epoch,
batch_size=batch_size,
batch_per_device=batch_per_device,
ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps,
grad_accum_steps=grad_accum_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
device_type=device,

View File

@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset):
scheduler_fn=scheduler_fn,
ckpt_dir=base_test_env["test_dir"],
n_epoch=1,
batch_size=2,
batch_per_device=2,
ckpt_interval=3,
accumulation_steps=1,
grad_accum_steps=1,
max_grad_norm=1.0,
random_seed=42,
device_type=base_test_env["device"],

View File

@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"],
n_epoch=2,
batch_size=2,
batch_per_device=2,
ckpt_interval=1,
accumulation_steps=2,
grad_accum_steps=2,
random_seed=np.random.randint(1e4),
device_type=base_test_env["device"],
)

View File

@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
"""Test training with different batch sizes"""
batch_sizes = [1, 2, 4, 8]
for batch_size in batch_sizes:
for batch_per_device in batch_sizes:
train_config = train_config_factory(
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_size=batch_size,
batch_per_device=batch_per_device,
)
assert train_config.batch_size == batch_size
assert train_config.batch_per_device == batch_per_device
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
"""Test training with different gradient accumulation steps"""
accumulation_steps_list = [1, 2, 4]
grad_accum_steps_list = [1, 2, 4]
for accumulation_steps in accumulation_steps_list:
for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory(
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_size=2,
accumulation_steps=accumulation_steps,
batch_per_device=2,
grad_accum_steps=grad_accum_steps,
)
trainer = Trainer(train_config)
trainer.train()
assert train_config.accumulation_steps == accumulation_steps
assert train_config.grad_accum_steps == grad_accum_steps
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
"""Test training with memory-efficient configurations"""
# Test with smaller batch sizes and gradient checkpointing
small_batch_configs = [
{"batch_size": 1, "accumulation_steps": 8},
{"batch_size": 2, "accumulation_steps": 4},
{"batch_size": 4, "accumulation_steps": 2},
{"batch_per_device": 1, "grad_accum_steps": 8},
{"batch_per_device": 2, "grad_accum_steps": 4},
{"batch_per_device": 4, "grad_accum_steps": 2},
]
for config in small_batch_configs:
@ -54,8 +54,9 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_size=config["batch_size"],
accumulation_steps=config["accumulation_steps"],
batch_per_device=config["batch_per_device"],
grad_accum_steps=config["grad_accum_steps"],
)
assert train_config.accumulation_steps == config["accumulation_steps"]
assert train_config.grad_accum_steps == config["grad_accum_steps"]
assert train_config.batch_per_device == config["batch_per_device"]