Compare commits
10 Commits
3d12a03909
...
ad9f4d9cf6
| Author | SHA1 | Date |
|---|---|---|
|
|
ad9f4d9cf6 | |
|
|
e1638a7ade | |
|
|
f91bfee33e | |
|
|
d7a7f570ed | |
|
|
7dea929788 | |
|
|
026d1fc33d | |
|
|
7242eedbf4 | |
|
|
04c0dc7a47 | |
|
|
48a53121ba | |
|
|
0ba8c70ce1 |
30
README.md
30
README.md
|
|
@ -78,15 +78,27 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
|
|||
#### Train a Model
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 1000 \
|
||||
--n_epoch 1
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=pt \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.95 \
|
||||
--adamw_beta2=0.99 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.05 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||
|
|
|
|||
|
|
@ -84,15 +84,27 @@ python scripts/demo/download.py
|
|||
#### 训练模型
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 1000 \
|
||||
--n_epoch 1
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=pt \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.95 \
|
||||
--adamw_beta2=0.99 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.05 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
完整参数列表见[参数说明](./params.md)。
|
||||
|
|
|
|||
|
|
@ -5,10 +5,15 @@
|
|||
```mermaid
|
||||
classDiagram
|
||||
namespace config {
|
||||
class BaseConfig {
|
||||
+to_dict() Dict
|
||||
+from_dict(d) Self
|
||||
}
|
||||
|
||||
class BaseModelConfig {
|
||||
+Optional[str] model_type
|
||||
+load(config_path) Self
|
||||
+save(config_path)
|
||||
+from_file(config_path) Self
|
||||
+to_file(config_path)
|
||||
}
|
||||
|
||||
class ModelConfig {
|
||||
|
|
@ -30,6 +35,9 @@ classDiagram
|
|||
+int n_shared_experts
|
||||
+int n_activated_experts
|
||||
+str moe_topk_method
|
||||
+Optional[int] kv_lora_rank
|
||||
+Optional[int] qk_nope_head_dim
|
||||
+Optional[int] qk_rope_head_dim
|
||||
+load(config_path) ModelConfig
|
||||
+save(config_path)
|
||||
}
|
||||
|
|
@ -41,8 +49,8 @@ classDiagram
|
|||
+Callable optimizer_fn
|
||||
+Callable scheduler_fn
|
||||
+int n_epoch
|
||||
+int batch_size
|
||||
+int accumulation_steps
|
||||
+int batch_per_device
|
||||
+int grad_accum_steps
|
||||
+float max_grad_norm
|
||||
+int start_epoch
|
||||
+int start_batch
|
||||
|
|
@ -69,7 +77,7 @@ classDiagram
|
|||
class BaseDataset {
|
||||
+int window_size
|
||||
+int stride
|
||||
+BaseStorage storage
|
||||
+Optional[BaseStorage] storage
|
||||
+load(load_path, storage_type, tokenizer)
|
||||
+__getitem__(index)
|
||||
+__len__()
|
||||
|
|
@ -126,8 +134,8 @@ classDiagram
|
|||
}
|
||||
|
||||
class ResumableDistributedSampler {
|
||||
+int start_epoch
|
||||
+int start_iter
|
||||
+int epoch
|
||||
+int iter
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
|
|
@ -144,6 +152,7 @@ classDiagram
|
|||
+int epoch
|
||||
+int iteration
|
||||
+dict extra
|
||||
+dict meta
|
||||
+save(save_dir)
|
||||
+load(save_dir) Checkpoint
|
||||
}
|
||||
|
|
@ -155,7 +164,7 @@ classDiagram
|
|||
+Registry _registry
|
||||
+register(model_type) decorator
|
||||
+get_component_class(model_type) Type
|
||||
+from_pretrained(path, disable_random_init) nn.Module
|
||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||
+save_pretrained(save_directory)
|
||||
+to(*args, **kwargs) Self
|
||||
}
|
||||
|
|
@ -167,7 +176,7 @@ classDiagram
|
|||
+ModuleList layers
|
||||
+RMSNorm norm
|
||||
+Linear lm_head
|
||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
|
||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
||||
+load_state_dict(state_dict)
|
||||
+state_dict()
|
||||
}
|
||||
|
|
@ -185,6 +194,7 @@ classDiagram
|
|||
+int n_kv_heads
|
||||
+int head_dim
|
||||
+int n_rep
|
||||
+int layer_id
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Linear q_proj, k_proj, v_proj, o_proj
|
||||
|
|
@ -201,6 +211,7 @@ classDiagram
|
|||
+int qk_nope_head_dim
|
||||
+int qk_rope_head_dim
|
||||
+int n_rep
|
||||
+int layer_id
|
||||
+bool use_gated_attention
|
||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||
+Linear o_proj
|
||||
|
|
@ -215,6 +226,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class DeepSeekMoE {
|
||||
+int dim
|
||||
+int n_routed_experts
|
||||
+int n_shared_experts
|
||||
+int n_activated_experts
|
||||
|
|
@ -236,6 +248,7 @@ classDiagram
|
|||
class RMSNorm {
|
||||
+Parameter weight
|
||||
+float norm_eps
|
||||
+tuple normalized_shape
|
||||
+forward(x) Tensor
|
||||
}
|
||||
|
||||
|
|
@ -299,7 +312,6 @@ classDiagram
|
|||
+TrainConfig train_config
|
||||
+List[TrainCallback] callbacks
|
||||
+train(checkpoint)
|
||||
+_build_context(checkpoint) TrainContext
|
||||
+_get_default_callbacks() List[TrainCallback]
|
||||
}
|
||||
|
||||
|
|
@ -324,7 +336,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class BaseStrategy {
|
||||
+nn.Module model
|
||||
+Union[Callable, nn.Module] model
|
||||
+str device
|
||||
+compute_loss(batch) Tensor
|
||||
}
|
||||
|
|
@ -332,7 +344,7 @@ classDiagram
|
|||
class StrategyFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(model, train_type, device, **kwargs) BaseStrategy
|
||||
+create(train_type, model, device, **kwargs) BaseStrategy
|
||||
}
|
||||
|
||||
class SEQStrategy {
|
||||
|
|
@ -400,7 +412,7 @@ classDiagram
|
|||
|
||||
class GradientClippingCallback {
|
||||
+float max_grad_norm
|
||||
+on_step_end(context)
|
||||
+on_step_begin(context)
|
||||
}
|
||||
|
||||
class CheckpointCallback {
|
||||
|
|
@ -459,10 +471,7 @@ classDiagram
|
|||
+TaskManager _task_mgr
|
||||
+bool _running
|
||||
+Thread _loop_thread
|
||||
+int max_batch_size
|
||||
+int max_seq_len
|
||||
+int max_prompt_len
|
||||
+int page_size
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
+remove_task(task_id)
|
||||
+start()
|
||||
|
|
@ -500,10 +509,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class Storage {
|
||||
+int n_layers
|
||||
+int page_size
|
||||
+int head_dim
|
||||
+int n_kv_heads
|
||||
+Tensor k_cache
|
||||
+Tensor v_cache
|
||||
+write(layer_id, page_table, start_pos, k, v)
|
||||
|
|
@ -675,7 +681,6 @@ classDiagram
|
|||
}
|
||||
|
||||
class AnthropicHandler {
|
||||
+List[str] stop_sequences
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
+on_token(ctx, token, stop_checker) Optional[str]
|
||||
|
|
@ -704,7 +709,7 @@ classDiagram
|
|||
|
||||
namespace parallel {
|
||||
class Functions {
|
||||
+spawn_parallel_fn(fn, nprocs)
|
||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||
+get_current_device() str
|
||||
+get_world_size() int
|
||||
|
|
@ -751,6 +756,8 @@ classDiagram
|
|||
ParallelModel <|-- RowParallelLinear
|
||||
ParallelModel <|-- ColumnParallelLinear
|
||||
AutoModel <|-- Transformer
|
||||
BaseConfig <|-- BaseModelConfig
|
||||
BaseConfig <|-- TrainConfig
|
||||
BaseModelConfig <|-- ModelConfig
|
||||
BaseFactory <|-- AutoModel
|
||||
BaseFactory <|-- AttnFactory
|
||||
|
|
@ -839,7 +846,7 @@ classDiagram
|
|||
|
||||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
|
|
@ -878,4 +885,4 @@ classDiagram
|
|||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
|
||||
> Document Update Time: 2026-05-15
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
|
|||
|
|
@ -10,14 +10,14 @@
|
|||
| `--data_root_path` | Dataset root directory | required |
|
||||
| `--param_path` | Model parameters or checkpoint path | required |
|
||||
| `--n_epoch` | Total training epochs | 1 |
|
||||
| `--batch_size` | Batch size | 1 |
|
||||
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||
| `--batch_per_device` | Batch size per device | 1 |
|
||||
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||
|
||||
### Learning Rate Scheduling
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--warmup_steps` | Warmup steps | 1000 |
|
||||
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
|
||||
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||
|
||||
|
|
@ -25,8 +25,8 @@
|
|||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||
| `--adamw_beta1` | AdamW beta1 | 0.95 |
|
||||
| `--adamw_beta2` | AdamW beta2 | 0.99 |
|
||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||
|
||||
### Data Loading
|
||||
|
|
@ -60,7 +60,7 @@
|
|||
| Parameter | Description | Default | Used by |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` |
|
||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
|
||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||
|
|
@ -69,90 +69,29 @@
|
|||
### Usage Example
|
||||
|
||||
```bash
|
||||
python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--n_epoch 3 \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 2000 \
|
||||
--max_grad_norm 1.0 \
|
||||
--ckpt_interval 5000 \
|
||||
--ckpt_dir ./checkpoints \
|
||||
--num_workers 4 \
|
||||
--nprocs 1 \
|
||||
--device_type cuda
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=pt \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.95 \
|
||||
--adamw_beta2=0.99 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.05 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Generation Parameters
|
||||
|
||||
### GenerationRequest Parameters
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| `messages` | List of message dictionaries (role, content) | required |
|
||||
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
||||
| `top_p` | Nucleus sampling threshold | 1.0 |
|
||||
| `top_k` | Top-k sampling count | 50 |
|
||||
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
|
||||
| `stream` | Whether to stream output | False |
|
||||
|
||||
### Usage Example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.inference import InferenceEngine, GenerationRequest
|
||||
|
||||
# Load model using AutoModel
|
||||
model = AutoModel.from_pretrained("your_model_dir")
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
|
||||
|
||||
# Create engine with separate model and tokenizer
|
||||
engine = InferenceEngine(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# Build request with messages format
|
||||
request = GenerationRequest(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
# Generate (streaming)
|
||||
for token in engine.generate_with_request(request):
|
||||
print(token, end="", flush=True)
|
||||
|
||||
# Or use simple generate interface
|
||||
result = engine.generate(
|
||||
prompt="Hello",
|
||||
stream=False,
|
||||
max_tokens=1024,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
)
|
||||
```
|
||||
|
||||
### Generation Modes
|
||||
|
||||
| Mode | Description |
|
||||
|------|-------------|
|
||||
| `stream=True` | Streaming output, yields token by token |
|
||||
| `stream=False` | Non-streaming output, returns complete result |
|
||||
|
||||
> Document Update Time: 2026-05-15
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi
|
|||
|
||||
## Training Loop
|
||||
|
||||
Nested loop: **epoch** → **step** (accumulation window) → **batch**.
|
||||
Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches.
|
||||
|
||||
```
|
||||
on_train_begin
|
||||
on_epoch_begin
|
||||
for steps in batched(dataloader, accumulation_steps):
|
||||
on_step_begin
|
||||
step_batch_nums = len(steps)
|
||||
for batch in steps:
|
||||
on_batch_begin
|
||||
loss = strategy(batch)
|
||||
(loss / step_batch_nums).backward()
|
||||
iteration += 1
|
||||
on_batch_end
|
||||
on_step_end
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
scheduler.step()
|
||||
for batch in dataloader:
|
||||
on_batch_begin
|
||||
loss = strategy(batch)
|
||||
(loss / grad_accum_steps).backward()
|
||||
iteration += 1
|
||||
on_batch_end
|
||||
|
||||
if iteration % grad_accum_steps == 0:
|
||||
on_step_begin
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
on_step_end
|
||||
scheduler.step()
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
|
@ -91,9 +91,9 @@ on_train_end
|
|||
|
||||
| Hook | Fires | Default callback |
|
||||
|------|-------|-----------------|
|
||||
| `on_step_end` | Every accumulation window | `GradientClippingCallback` |
|
||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||
| `on_train_end` | Training ends | `CheckpointCallback` (final save) |
|
||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||
|
||||
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
||||
|
||||
|
|
@ -157,12 +157,13 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
|||
## Checkpoint
|
||||
|
||||
```
|
||||
Checkpoint(state_dict, epoch, iteration, extra)
|
||||
├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt
|
||||
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
||||
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
|
||||
└── load(save_dir) broadcasts metadata from rank-0
|
||||
```
|
||||
|
||||
Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data.
|
||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
|
||||
|
||||
## TrainContextBuilder (Builder Pattern)
|
||||
|
||||
|
|
@ -183,17 +184,29 @@ context = (
|
|||
## Training CLI
|
||||
|
||||
```bash
|
||||
python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/data \
|
||||
--param_path /path/to/model \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 1000 \
|
||||
--n_epoch 1
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=pt \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.95 \
|
||||
--adamw_beta2=0.99 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.05 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
Full parameter reference at [params.md](params.md).
|
||||
|
||||
> Document Update Time: 2026-05-15
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
|
|||
|
|
@ -0,0 +1,77 @@
|
|||
import json
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = {}
|
||||
for fld in fields(self):
|
||||
v = getattr(self, fld.name)
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
d[fld.name] = v
|
||||
elif v is None:
|
||||
d[fld.name] = None
|
||||
elif isinstance(v, dict):
|
||||
try:
|
||||
json.dumps(v)
|
||||
d[fld.name] = v
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> Self:
|
||||
hints = get_type_hints(cls)
|
||||
inst = cls.__new__(cls)
|
||||
for fld in fields(cls):
|
||||
if fld.name in d:
|
||||
v = d[fld.name]
|
||||
target = cls._unwrap_optional(hints.get(fld.name))
|
||||
if target is not None:
|
||||
try:
|
||||
v = cls._coerce(v, target)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
object.__setattr__(inst, fld.name, v)
|
||||
elif fld.default is not MISSING:
|
||||
object.__setattr__(inst, fld.name, fld.default)
|
||||
elif fld.default_factory is not MISSING:
|
||||
object.__setattr__(inst, fld.name, fld.default_factory())
|
||||
else:
|
||||
object.__setattr__(inst, fld.name, None)
|
||||
return inst
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_optional(tp) -> Optional[type]:
|
||||
if tp is None:
|
||||
return None
|
||||
origin = getattr(tp, "__origin__", None)
|
||||
if origin is not None:
|
||||
args = getattr(tp, "__args__", ())
|
||||
non_none = [a for a in args if a is not type(None)]
|
||||
return non_none[0] if non_none else None
|
||||
return tp
|
||||
|
||||
@staticmethod
|
||||
def _coerce(value: Any, target_type: type) -> Any:
|
||||
if target_type is bool and isinstance(value, bool):
|
||||
return value
|
||||
if (
|
||||
target_type is int
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return int(value)
|
||||
if (
|
||||
target_type is float
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return float(value)
|
||||
if target_type is str and isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
raise TypeError
|
||||
|
|
@ -1,12 +1,14 @@
|
|||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||
from typing import Any, Dict, Optional, Self
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
"""Field-aware JSON load/save for dataclass configs.
|
||||
class BaseModelConfig(BaseConfig):
|
||||
"""Field-aware JSON from/to file for dataclass configs.
|
||||
|
||||
Subclass with additional fields. The base ``model_type`` field
|
||||
enables ``AutoModel`` to pick the correct subclass.
|
||||
|
|
@ -14,76 +16,25 @@ class BaseModelConfig:
|
|||
|
||||
model_type: Optional[str] = None
|
||||
|
||||
def load(self, config_path: str) -> Self:
|
||||
raw: Dict[str, Any] = {}
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str) -> Self:
|
||||
with open(config_path, "r") as f:
|
||||
raw.update(json.load(f))
|
||||
raw: Dict[str, Any] = json.load(f)
|
||||
|
||||
hints = get_type_hints(type(self))
|
||||
valid = {fld.name for fld in fields(self)}
|
||||
for key, value in raw.items():
|
||||
valid = {fld.name for fld in fields(cls)}
|
||||
for key in list(raw):
|
||||
if key not in valid:
|
||||
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
|
||||
continue
|
||||
warnings.warn(f"Unknown config key '{key}'")
|
||||
del raw[key]
|
||||
|
||||
target_type = self._unwrap_optional(hints.get(key))
|
||||
if target_type is None:
|
||||
continue
|
||||
return cls.from_dict(raw)
|
||||
|
||||
try:
|
||||
value = self._coerce(value, target_type)
|
||||
except (TypeError, ValueError):
|
||||
sys.stderr.write(
|
||||
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
|
||||
)
|
||||
continue
|
||||
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
def save(self, config_path: str):
|
||||
config_dict: Dict[str, Any] = {}
|
||||
for fld in fields(self):
|
||||
v = getattr(self, fld.name)
|
||||
if v is not None:
|
||||
config_dict[fld.name] = v
|
||||
def to_file(self, config_path: str):
|
||||
d = self.to_dict()
|
||||
config_dict = {k: v for k, v in d.items() if v is not None}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_optional(tp: type) -> Optional[type]:
|
||||
if tp is None:
|
||||
return None
|
||||
origin = getattr(tp, "__origin__", None)
|
||||
if origin is not None:
|
||||
args = getattr(tp, "__args__", ())
|
||||
non_none = [a for a in args if a is not type(None)]
|
||||
return non_none[0] if non_none else None
|
||||
return tp
|
||||
|
||||
@staticmethod
|
||||
def _coerce(value: Any, target_type: type) -> Any:
|
||||
if target_type is bool and isinstance(value, bool):
|
||||
return value
|
||||
if (
|
||||
target_type is int
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return int(value)
|
||||
if (
|
||||
target_type is float
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return float(value)
|
||||
if target_type is str and isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
raise TypeError
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig(BaseModelConfig):
|
||||
|
|
@ -106,6 +57,11 @@ class ModelConfig(BaseModelConfig):
|
|||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
# MLA
|
||||
kv_lora_rank: Optional[int] = None
|
||||
qk_nope_head_dim: Optional[int] = None
|
||||
qk_rope_head_dim: Optional[int] = None
|
||||
|
||||
# MoE
|
||||
ffn_type: str = "mlp"
|
||||
n_routed_experts: Optional[int] = None
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
||||
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
||||
|
|
@ -20,8 +22,10 @@ class TrainConfig:
|
|||
default=None, metadata={"help": "Scheduler factory for training."}
|
||||
)
|
||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
|
||||
accumulation_steps: int = field(
|
||||
batch_per_device: int = field(
|
||||
default=4, metadata={"help": "Batch size per device."}
|
||||
)
|
||||
grad_accum_steps: int = field(
|
||||
default=1, metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from astrai.dataset.storage import (
|
|||
H5Storage,
|
||||
JSONStorage,
|
||||
MultiSegmentFetcher,
|
||||
available_storage_types,
|
||||
create_storage,
|
||||
StorageFactory,
|
||||
detect_format,
|
||||
load_h5,
|
||||
load_json,
|
||||
|
|
@ -26,9 +25,8 @@ __all__ = [
|
|||
"BaseStorage",
|
||||
"H5Storage",
|
||||
"JSONStorage",
|
||||
"create_storage",
|
||||
"StorageFactory",
|
||||
"detect_format",
|
||||
"available_storage_types",
|
||||
"save_h5",
|
||||
"load_h5",
|
||||
"save_json",
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from torch.utils.data import Dataset
|
|||
|
||||
from astrai.dataset.storage import (
|
||||
BaseStorage,
|
||||
create_storage,
|
||||
StorageFactory,
|
||||
detect_format,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
|
|
@ -42,7 +42,7 @@ class BaseDataset(Dataset, ABC):
|
|||
"""
|
||||
if storage_type is None:
|
||||
storage_type = detect_format(load_path)
|
||||
self.storage = create_storage(storage_type)
|
||||
self.storage = StorageFactory.create(storage_type)
|
||||
self.storage.load(load_path, tokenizer=tokenizer)
|
||||
|
||||
def load_json(self, load_path: str, tokenizer=None):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ import h5py
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
|
|
@ -258,6 +260,24 @@ class BaseStorage(ABC):
|
|||
return self._fetcher.multi_keys
|
||||
|
||||
|
||||
class StorageFactory(BaseFactory["BaseStorage"]):
|
||||
"""Factory for creating storage backends by type name.
|
||||
|
||||
Example:
|
||||
@StorageFactory.register("custom")
|
||||
class CustomStorage(BaseStorage):
|
||||
...
|
||||
|
||||
storage = StorageFactory.create("custom")
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, storage_cls: type) -> None:
|
||||
if not issubclass(storage_cls, BaseStorage):
|
||||
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
||||
|
||||
|
||||
@StorageFactory.register("h5")
|
||||
class H5Storage(BaseStorage):
|
||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||
|
||||
|
|
@ -266,6 +286,7 @@ class H5Storage(BaseStorage):
|
|||
self._fetcher = MultiSegmentFetcher(segments)
|
||||
|
||||
|
||||
@StorageFactory.register("json")
|
||||
class JSONStorage(BaseStorage):
|
||||
"""JSON-based storage backend.
|
||||
|
||||
|
|
@ -278,35 +299,3 @@ class JSONStorage(BaseStorage):
|
|||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
segments = load_json(load_path, tokenizer=tokenizer)
|
||||
self._fetcher = MultiSegmentFetcher(segments)
|
||||
|
||||
|
||||
_STORAGE_REGISTRY: Dict[str, type] = {
|
||||
"h5": H5Storage,
|
||||
"json": JSONStorage,
|
||||
}
|
||||
|
||||
|
||||
def create_storage(storage_type: str) -> BaseStorage:
|
||||
"""Create a storage instance by type name.
|
||||
|
||||
Args:
|
||||
storage_type: Storage type name ("h5", "json")
|
||||
|
||||
Returns:
|
||||
Storage instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the storage type is unknown
|
||||
"""
|
||||
storage_cls = _STORAGE_REGISTRY.get(storage_type)
|
||||
if storage_cls is None:
|
||||
raise ValueError(
|
||||
f"Unknown storage type: '{storage_type}'. "
|
||||
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
|
||||
)
|
||||
return storage_cls()
|
||||
|
||||
|
||||
def available_storage_types() -> List[str]:
|
||||
"""Return list of registered storage type names."""
|
||||
return sorted(_STORAGE_REGISTRY.keys())
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Base factory class for extensible component registration."""
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
|
|
@ -122,6 +123,10 @@ class BaseFactory(ABC, Generic[T]):
|
|||
def create(cls, name: str, *args, **kwargs) -> T:
|
||||
"""Create a component instance by name.
|
||||
|
||||
Filters kwargs to match the component's __init__ signature,
|
||||
so components don't need to declare **kwargs just to absorb
|
||||
parameters meant for other components.
|
||||
|
||||
Args:
|
||||
name: Registered name of the component
|
||||
*args: Positional arguments passed to component constructor
|
||||
|
|
@ -139,6 +144,17 @@ class BaseFactory(ABC, Generic[T]):
|
|||
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||
)
|
||||
component_cls = cls._registry.get(name)
|
||||
sig = inspect.signature(component_cls.__init__)
|
||||
has_var_kwargs = any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||
)
|
||||
if not has_var_kwargs:
|
||||
valid = {
|
||||
p.name
|
||||
for p in sig.parameters.values()
|
||||
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
|
||||
}
|
||||
kwargs = {k: v for k, v in kwargs.items() if k in valid}
|
||||
return component_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -163,4 +163,5 @@ def run_server(
|
|||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
reload=reload,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,14 +22,22 @@ class InferenceScheduler:
|
|||
tokenizer: AutoTokenizer,
|
||||
max_batch_size: int = 16,
|
||||
max_seq_len: Optional[int] = None,
|
||||
max_prompt_len: int = 512,
|
||||
max_prompt_len: int = 2048,
|
||||
page_size: int = 64,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
config = model.config
|
||||
|
||||
self.max_seq_len = max_seq_len or config.max_len
|
||||
if max_seq_len is not None:
|
||||
self.max_seq_len = max_seq_len
|
||||
elif config.max_len is not None:
|
||||
self.max_seq_len = config.max_len
|
||||
else:
|
||||
raise ValueError(
|
||||
"max_seq_len must be provided either as argument "
|
||||
"or in model config (config.max_len)"
|
||||
)
|
||||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
|
|
|
|||
|
|
@ -60,10 +60,9 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
model_path = Path(path)
|
||||
|
||||
# Load config
|
||||
config = ModelConfig()
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
config.load(str(config_path))
|
||||
config = ModelConfig.from_file(str(config_path))
|
||||
else:
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
|
|
@ -89,7 +88,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
self.config.save(str(save_path / "config.json"))
|
||||
self.config.to_file(str(save_path / "config.json"))
|
||||
|
||||
# Save weights
|
||||
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ class GQA(nn.Module):
|
|||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % n_heads == 0
|
||||
|
|
@ -123,7 +122,6 @@ class MLA(nn.Module):
|
|||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -143,7 +141,7 @@ class MLA(nn.Module):
|
|||
|
||||
self.kv_b_proj = Linear(
|
||||
kv_lora_rank,
|
||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||
n_kv_heads * (2 * self.head_dim),
|
||||
)
|
||||
|
||||
self.o_proj = Linear(dim, dim, bias=False)
|
||||
|
|
@ -176,7 +174,7 @@ class MLA(nn.Module):
|
|||
|
||||
q_nope, q_rope = (
|
||||
q[..., : self.qk_nope_head_dim],
|
||||
q[..., self.qk_rope_head_dim :],
|
||||
q[..., self.qk_nope_head_dim :],
|
||||
)
|
||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||
|
|
|
|||
|
|
@ -16,13 +16,13 @@ class DecoderBlock(nn.Module):
|
|||
n_heads: int,
|
||||
dim_ffn: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: int,
|
||||
norm_eps: float,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
attn_type: str = "gqa",
|
||||
ffn_type: str = "mlp",
|
||||
**moe_kwargs,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = AttnFactory.create(
|
||||
|
|
@ -34,10 +34,11 @@ class DecoderBlock(nn.Module):
|
|||
norm_eps=norm_eps,
|
||||
use_gated_attention=use_gated_attention,
|
||||
layer_id=layer_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.input_norm = RMSNorm(dim, norm_eps)
|
||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
|
||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
|
|||
|
||||
@FFNFactory.register("mlp")
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
|
||||
def __init__(self, dim: int, dim_ffn: int):
|
||||
super().__init__()
|
||||
self.up = Linear(dim, dim_feed_forward)
|
||||
self.gate = Linear(dim, dim_feed_forward)
|
||||
self.down = Linear(dim_feed_forward, dim)
|
||||
self.up = Linear(dim, dim_ffn)
|
||||
self.gate = Linear(dim, dim_ffn)
|
||||
self.down = Linear(dim_ffn, dim)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
|
|
@ -32,12 +32,11 @@ class DeepSeekMoE(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_feed_forward: int,
|
||||
dim_ffn: int,
|
||||
n_routed_experts: int,
|
||||
n_shared_experts: int = 1,
|
||||
n_activated_experts: int = 2,
|
||||
topk_method: str = "greedy",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -49,10 +48,10 @@ class DeepSeekMoE(nn.Module):
|
|||
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||
|
||||
self.shared_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
|
||||
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
|
||||
)
|
||||
self.routed_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
|
||||
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
|||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||
def __init__(self, dim: int, max_len: int, base: float = 10000):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
|
|
|
|||
|
|
@ -53,9 +53,13 @@ class Transformer(AutoModel):
|
|||
def __init__(self, config: ModelConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
config.dim // config.n_heads, config.max_len
|
||||
rope_dim = (
|
||||
config.qk_rope_head_dim
|
||||
if config.attn_type == "mla"
|
||||
else config.dim // config.n_heads
|
||||
)
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
|
|
@ -75,6 +79,9 @@ class Transformer(AutoModel):
|
|||
n_shared_experts=config.n_shared_experts,
|
||||
n_activated_experts=config.n_activated_experts,
|
||||
topk_method=config.moe_topk_method,
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
|
|
@ -83,7 +90,7 @@ class Transformer(AutoModel):
|
|||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||
|
||||
if self.config.tie_weight:
|
||||
if self.config.tie_weight is True:
|
||||
self.lm_head.weight = self.embed_tokens.weight
|
||||
|
||||
self._init_weights()
|
||||
|
|
@ -99,7 +106,7 @@ class Transformer(AutoModel):
|
|||
|
||||
state_dict = dict(state_dict)
|
||||
|
||||
if self.config.tie_weight:
|
||||
if self.config.tie_weight is True:
|
||||
# same tensor for embed and lm_head
|
||||
if embed_key in state_dict:
|
||||
state_dict[lm_head_key] = state_dict[embed_key]
|
||||
|
|
@ -115,7 +122,7 @@ class Transformer(AutoModel):
|
|||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||
)
|
||||
|
||||
if self.config.tie_weight:
|
||||
if self.config.tie_weight is True:
|
||||
lm_head_key = prefix + "lm_head.weight"
|
||||
if lm_head_key in state_dict:
|
||||
del state_dict[lm_head_key]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
|
@ -16,11 +17,13 @@ class Checkpoint:
|
|||
epoch: int = 0,
|
||||
iteration: int = 0,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.state_dict = state_dict
|
||||
self.epoch = epoch
|
||||
self.iteration = iteration
|
||||
self.extra = extra or {}
|
||||
self.meta = meta or {}
|
||||
|
||||
def save(
|
||||
self,
|
||||
|
|
@ -35,13 +38,16 @@ class Checkpoint:
|
|||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
meta.update(self.meta)
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||
if self.extra:
|
||||
torch.save(self.extra, save_path / "extra.pt")
|
||||
for key, value in self.extra.items():
|
||||
torch.save(value, save_path / f"{key}.pt")
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
|
|
@ -64,14 +70,14 @@ class Checkpoint:
|
|||
|
||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||
|
||||
extra = None
|
||||
extra_path = save_path / "extra.pt"
|
||||
if extra_path.exists():
|
||||
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
|
||||
extra = {}
|
||||
for f in save_path.iterdir():
|
||||
if f.suffix == ".pt" and f.stem not in ("meta",):
|
||||
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
extra=extra,
|
||||
extra=extra or None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -79,8 +79,7 @@ class GradientClippingCallback(TrainCallback):
|
|||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
_ = context
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
|
|
@ -90,6 +89,8 @@ class CheckpointCallback(TrainCallback):
|
|||
Checkpoint callback for trainer.
|
||||
"""
|
||||
|
||||
extra_keys = ("optimizer", "scheduler")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
|
|
@ -97,12 +98,14 @@ class CheckpointCallback(TrainCallback):
|
|||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.save_extra_fn = save_extra_fn
|
||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
@ -116,17 +119,22 @@ class CheckpointCallback(TrainCallback):
|
|||
else context.model.state_dict()
|
||||
)
|
||||
|
||||
extra = self.save_extra_fn(context) if self.save_extra_fn else None
|
||||
extra = self.save_extra_fn(context)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
meta=context.config.to_dict(),
|
||||
)
|
||||
|
||||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
self.load_extra_fn(context.checkpoint.extra, context)
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||
self._save_checkpoint(context)
|
||||
|
|
@ -138,6 +146,21 @@ class CheckpointCallback(TrainCallback):
|
|||
def on_error(self, context: TrainContext):
|
||||
self._save_checkpoint(context)
|
||||
|
||||
@staticmethod
|
||||
def save_extra(context: TrainContext) -> dict:
|
||||
extra = {}
|
||||
for name in CheckpointCallback.extra_keys:
|
||||
obj = getattr(context, name, None)
|
||||
if obj:
|
||||
extra[name] = obj.state_dict()
|
||||
return extra
|
||||
|
||||
@staticmethod
|
||||
def load_extra(extra: dict, context: TrainContext):
|
||||
for name in CheckpointCallback.extra_keys:
|
||||
if name in extra:
|
||||
getattr(context, name).load_state_dict(extra[name])
|
||||
|
||||
|
||||
@CallbackFactory.register("progress_bar")
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional, Self
|
||||
from typing import Optional, Self
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
|
@ -21,6 +21,7 @@ class TrainContext:
|
|||
optimizer: Optimizer = field(default=None)
|
||||
scheduler: LRScheduler = field(default=None)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
config: TrainConfig = field(default=None)
|
||||
|
||||
epoch: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
|
|
@ -35,11 +36,9 @@ class TrainContextBuilder:
|
|||
def __init__(
|
||||
self,
|
||||
config: TrainConfig,
|
||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||
):
|
||||
self.config = config
|
||||
self._checkpoint: Optional[Checkpoint] = None
|
||||
self._load_extra_fn = load_extra_fn
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._checkpoint = checkpoint
|
||||
|
|
@ -50,6 +49,7 @@ class TrainContextBuilder:
|
|||
model=self.config.model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
device = get_current_device()
|
||||
|
|
@ -71,11 +71,8 @@ class TrainContextBuilder:
|
|||
context.optimizer = self.config.optimizer_fn(context.model)
|
||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||
|
||||
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
|
||||
self._load_extra_fn(self._checkpoint.extra, context)
|
||||
|
||||
cfg = self.config
|
||||
sampler_offset = context.iteration * cfg.batch_size
|
||||
sampler_offset = context.iteration * cfg.batch_per_device
|
||||
sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.dataset,
|
||||
start_epoch=context.epoch,
|
||||
|
|
@ -84,7 +81,7 @@ class TrainContextBuilder:
|
|||
)
|
||||
context.dataloader = DataLoader(
|
||||
cfg.dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=cfg.pin_memory,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
from itertools import batched
|
||||
from typing import List, Optional
|
||||
|
||||
from astrai.config import TrainConfig
|
||||
|
|
@ -33,11 +32,6 @@ class Trainer:
|
|||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
]
|
||||
|
||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||
return (
|
||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||
)
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
for callback in self.callbacks:
|
||||
method = getattr(callback, method_name, None)
|
||||
|
|
@ -45,49 +39,47 @@ class Trainer:
|
|||
method(context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
config = self.train_config
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._train_impl,
|
||||
backend=config.backend,
|
||||
world_size=config.nprocs,
|
||||
master_addr=config.master_addr,
|
||||
master_port=config.master_port,
|
||||
device_type=config.device_type,
|
||||
backend=cfg.backend,
|
||||
world_size=cfg.nprocs,
|
||||
master_addr=cfg.master_addr,
|
||||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||
context = self._build_context(checkpoint)
|
||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
||||
try:
|
||||
context.model.train()
|
||||
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
||||
grad_accum_steps = cfg.grad_accum_steps
|
||||
|
||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||
for epoch in range(context.epoch, cfg.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
for steps in batched(context.dataloader, accumulation_steps):
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
for batch in context.dataloader:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / grad_accum_steps
|
||||
stand_loss.backward()
|
||||
context.iteration += 1
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
step_batch_nums = len(steps)
|
||||
for batch in steps:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
context.iteration += 1
|
||||
if context.iteration % grad_accum_steps == 0:
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks("on_step_end", context)
|
||||
|
||||
stand_loss = loss / step_batch_nums
|
||||
stand_loss.backward()
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
self._call_callbacks("on_step_end", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
|||
|
||||
|
||||
def generate_text():
|
||||
# Load model from pretrained
|
||||
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
|
@ -22,16 +21,15 @@ def generate_text():
|
|||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
response = engine.generate(
|
||||
for token in engine.generate(
|
||||
prompt=query,
|
||||
stream=False,
|
||||
stream=True,
|
||||
max_tokens=2048,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
)
|
||||
|
||||
print(response)
|
||||
):
|
||||
print(token, end="", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -42,18 +42,20 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
|
||||
parser.add_argument(
|
||||
"--accumulation_steps",
|
||||
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad_accum_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of iterations between each optimizer step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of warmup steps for LR scheduler.",
|
||||
"--warmup_ratio",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Fraction of total steps used for LR warmup.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||
|
|
@ -67,13 +69,13 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--adamw_beta1",
|
||||
type=float,
|
||||
default=0.9,
|
||||
default=0.95,
|
||||
help="Beta values for AdamW optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adamw_beta2",
|
||||
type=float,
|
||||
default=0.95,
|
||||
default=0.99,
|
||||
help="Beta values for AdamW optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
@ -114,7 +116,7 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--label_smoothing",
|
||||
type=float,
|
||||
default=0.1,
|
||||
default=0.05,
|
||||
help="cross_entropy function label smoothing parameter",
|
||||
)
|
||||
|
||||
|
|
@ -181,17 +183,34 @@ def prepare_checkpoint(model: nn.Module) -> dict:
|
|||
return model.module.state_dict()
|
||||
|
||||
|
||||
def compute_total_steps(
|
||||
dataset_len: int,
|
||||
n_epoch: int,
|
||||
batch_per_device: int,
|
||||
nprocs: int,
|
||||
grad_accum_steps: int,
|
||||
) -> int:
|
||||
|
||||
def ceil_div(a: int, b: int) -> int:
|
||||
return (a + b - 1) // b
|
||||
|
||||
samples_per_replica = ceil_div(dataset_len, nprocs)
|
||||
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
|
||||
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
|
||||
return total_steps
|
||||
|
||||
|
||||
def train(
|
||||
train_type: str,
|
||||
param_path: str,
|
||||
data_root_path: str,
|
||||
max_lr: float,
|
||||
n_epoch: int,
|
||||
batch_size: int,
|
||||
batch_per_device: int,
|
||||
start_epoch: int,
|
||||
start_batch: int,
|
||||
accumulation_steps: int,
|
||||
warmup_steps: int,
|
||||
grad_accum_steps: int,
|
||||
warmup_ratio: float,
|
||||
ckpt_interval: int,
|
||||
ckpt_dir: str,
|
||||
dpo_beta: float,
|
||||
|
|
@ -216,10 +235,8 @@ def train(
|
|||
assert os.path.exists(param_path)
|
||||
|
||||
# Load config
|
||||
config = ModelConfig()
|
||||
config_path = os.path.join(param_path, "config.json")
|
||||
if os.path.exists(config_path):
|
||||
config.load(config_path)
|
||||
config = ModelConfig.from_file(config_path)
|
||||
|
||||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
|
@ -260,13 +277,17 @@ def train(
|
|||
},
|
||||
)
|
||||
|
||||
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
|
||||
total_steps = compute_total_steps(
|
||||
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
|
||||
)
|
||||
warmup_steps = int(warmup_ratio * total_steps)
|
||||
|
||||
scheduler_fn = partial(
|
||||
create_scheduler,
|
||||
**{
|
||||
"schedule_type": "cosine",
|
||||
"warmup_steps": warmup_steps,
|
||||
"lr_decay_steps": total_steps - warmup_steps,
|
||||
"warmup_steps": min(warmup_steps, total_steps),
|
||||
"lr_decay_steps": total_steps - min(warmup_steps, total_steps),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -278,11 +299,11 @@ def train(
|
|||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=ckpt_dir,
|
||||
n_epoch=n_epoch,
|
||||
batch_size=batch_size,
|
||||
batch_per_device=batch_per_device,
|
||||
start_epoch=start_epoch,
|
||||
start_batch=start_batch,
|
||||
ckpt_interval=ckpt_interval,
|
||||
accumulation_steps=accumulation_steps,
|
||||
grad_accum_steps=grad_accum_steps,
|
||||
max_grad_norm=max_grad_norm,
|
||||
random_seed=random_seed,
|
||||
num_workers=num_workers,
|
||||
|
|
|
|||
|
|
@ -107,12 +107,12 @@ def test_model():
|
|||
"""Session-scoped small Transformer model, created once."""
|
||||
config = ModelConfig(
|
||||
vocab_size=1000,
|
||||
dim=16,
|
||||
n_heads=4,
|
||||
n_kv_heads=2,
|
||||
dim_ffn=32,
|
||||
max_len=1024,
|
||||
n_layers=4,
|
||||
dim=8,
|
||||
n_heads=2,
|
||||
n_kv_heads=1,
|
||||
dim_ffn=16,
|
||||
max_len=64,
|
||||
n_layers=2,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
|
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
|
|||
json.dump(
|
||||
{
|
||||
"vocab_size": 1000,
|
||||
"dim": 16,
|
||||
"n_heads": 4,
|
||||
"n_kv_heads": 2,
|
||||
"dim_ffn": 32,
|
||||
"max_len": 1024,
|
||||
"n_layers": 4,
|
||||
"dim": 8,
|
||||
"n_heads": 2,
|
||||
"n_kv_heads": 1,
|
||||
"dim_ffn": 16,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5,
|
||||
},
|
||||
f,
|
||||
|
|
|
|||
|
|
@ -35,6 +35,33 @@ def test_single_process():
|
|||
assert loaded_checkpoint.iteration == 30
|
||||
|
||||
|
||||
def test_checkpoint_with_extra():
|
||||
"""Verify extra keys are saved as individual .pt files and loaded back."""
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
optimizer.step()
|
||||
|
||||
extra = {
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": {"last_epoch": 5},
|
||||
}
|
||||
checkpoint = Checkpoint(
|
||||
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint.save(tmpdir)
|
||||
|
||||
import os
|
||||
|
||||
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
||||
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
||||
|
||||
loaded = Checkpoint.load(tmpdir)
|
||||
assert loaded.extra["scheduler"]["last_epoch"] == 5
|
||||
assert "state" in loaded.extra["optimizer"]
|
||||
|
||||
|
||||
def simple_training():
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from astrai.dataset.storage import (
|
|||
BaseSegmentFetcher,
|
||||
H5Storage,
|
||||
MultiSegmentFetcher,
|
||||
create_storage,
|
||||
StorageFactory,
|
||||
detect_format,
|
||||
load_json,
|
||||
save_h5,
|
||||
|
|
@ -368,9 +368,9 @@ def test_detect_format_unsupported_file(base_test_env):
|
|||
|
||||
|
||||
def test_create_storage_invalid_type():
|
||||
"""create_storage raises ValueError for unknown type"""
|
||||
with pytest.raises(ValueError, match="Unknown storage type"):
|
||||
create_storage("parquet")
|
||||
"""StorageFactory.create raises ValueError for unknown type"""
|
||||
with pytest.raises(ValueError, match="Unknown component"):
|
||||
StorageFactory.create("parquet")
|
||||
|
||||
|
||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,108 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
TINY_CONFIG = dict(
|
||||
vocab_size=128,
|
||||
dim=8,
|
||||
n_heads=2,
|
||||
n_kv_heads=1,
|
||||
dim_ffn=16,
|
||||
max_len=64,
|
||||
n_layers=2,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
|
||||
|
||||
CONFIGS = [
|
||||
pytest.param(
|
||||
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
|
||||
id="gqa_mlp",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
**TINY_CONFIG,
|
||||
"attn_type": "mla",
|
||||
"ffn_type": "mlp",
|
||||
"kv_lora_rank": 4,
|
||||
"qk_nope_head_dim": 2,
|
||||
"qk_rope_head_dim": 2,
|
||||
},
|
||||
id="mla_mlp",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
**TINY_CONFIG,
|
||||
"attn_type": "gqa",
|
||||
"ffn_type": "moe",
|
||||
"n_routed_experts": 4,
|
||||
"n_shared_experts": 1,
|
||||
"n_activated_experts": 2,
|
||||
"moe_topk_method": "greedy",
|
||||
},
|
||||
id="gqa_moe",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
**TINY_CONFIG,
|
||||
"attn_type": "gqa",
|
||||
"ffn_type": "mlp",
|
||||
"rope_theta": 100000.0,
|
||||
},
|
||||
id="gqa_rope_theta",
|
||||
),
|
||||
pytest.param(
|
||||
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
|
||||
id="gqa_qk_norm",
|
||||
),
|
||||
pytest.param(
|
||||
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
|
||||
id="gqa_tie_weight",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||
def test_model_forward(config_kwargs):
|
||||
config = ModelConfig(**config_kwargs)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
assert "logits" in output
|
||||
assert "hidden_states" in output
|
||||
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
|
||||
assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
|
||||
assert not torch.isnan(output["logits"]).any()
|
||||
assert not torch.isnan(output["hidden_states"]).any()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||
def test_model_forward_with_padding(config_kwargs):
|
||||
config = ModelConfig(**config_kwargs)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||
input_mask[:, 4:] = False
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids, input_mask=input_mask)
|
||||
|
||||
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
|
||||
assert not torch.isnan(output["logits"]).any()
|
||||
|
|
@ -17,10 +17,10 @@ def transformer_test_env():
|
|||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"dim": 128,
|
||||
"n_heads": 4,
|
||||
"n_kv_heads": 2,
|
||||
"dim_ffn": 256,
|
||||
"dim": 8,
|
||||
"n_heads": 2,
|
||||
"n_kv_heads": 1,
|
||||
"dim_ffn": 16,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5,
|
||||
|
|
@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(config)
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
|
|
@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(config)
|
||||
|
||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
|
|
@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
config = ModelConfig.from_file(config_path)
|
||||
original_model = Transformer(config)
|
||||
|
||||
st.save_file(original_model.state_dict(), model_path)
|
||||
|
||||
loaded_config = ModelConfig().load(config_path)
|
||||
loaded_config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
|
|
@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
loaded_config = ModelConfig().load(config_path)
|
||||
loaded_config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ def create_train_config(
|
|||
device: str,
|
||||
strategy: str = "seq",
|
||||
n_epoch: int = 1,
|
||||
batch_size: int = 2,
|
||||
accumulation_steps: int = 1,
|
||||
batch_per_device: int = 2,
|
||||
grad_accum_steps: int = 1,
|
||||
max_grad_norm: float = 1.0,
|
||||
ckpt_interval: int = 5,
|
||||
random_seed: int = 42,
|
||||
|
|
@ -47,8 +47,8 @@ def create_train_config(
|
|||
device: Device type ("cuda" or "cpu")
|
||||
strategy: Training strategy type (default: "seq")
|
||||
n_epoch: Number of epochs (default: 1)
|
||||
batch_size: Batch size (default: 2)
|
||||
accumulation_steps: Gradient accumulation steps (default: 1)
|
||||
batch_per_device: Batch size per device (default: 2)
|
||||
grad_accum_steps: Gradient accumulation steps (default: 1)
|
||||
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
|
||||
ckpt_interval: Checkpoint save interval in iterations (default: 5)
|
||||
random_seed: Random seed for reproducibility (default: 42)
|
||||
|
|
@ -74,9 +74,9 @@ def create_train_config(
|
|||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=test_dir,
|
||||
n_epoch=n_epoch,
|
||||
batch_size=batch_size,
|
||||
batch_per_device=batch_per_device,
|
||||
ckpt_interval=ckpt_interval,
|
||||
accumulation_steps=accumulation_steps,
|
||||
grad_accum_steps=grad_accum_steps,
|
||||
max_grad_norm=max_grad_norm,
|
||||
random_seed=random_seed,
|
||||
device_type=device,
|
||||
|
|
|
|||
|
|
@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_size=2,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=3,
|
||||
accumulation_steps=1,
|
||||
grad_accum_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"],
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
dataset=early_stopping_dataset,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=2,
|
||||
batch_size=2,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=1,
|
||||
accumulation_steps=2,
|
||||
grad_accum_steps=2,
|
||||
random_seed=np.random.randint(1e4),
|
||||
device_type=base_test_env["device"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
|||
"""Test training with different batch sizes"""
|
||||
batch_sizes = [1, 2, 4, 8]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for batch_per_device in batch_sizes:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
batch_size=batch_size,
|
||||
batch_per_device=batch_per_device,
|
||||
)
|
||||
|
||||
assert train_config.batch_size == batch_size
|
||||
assert train_config.batch_per_device == batch_per_device
|
||||
|
||||
|
||||
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
||||
"""Test training with different gradient accumulation steps"""
|
||||
accumulation_steps_list = [1, 2, 4]
|
||||
grad_accum_steps_list = [1, 2, 4]
|
||||
|
||||
for accumulation_steps in accumulation_steps_list:
|
||||
for grad_accum_steps in grad_accum_steps_list:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
batch_size=2,
|
||||
accumulation_steps=accumulation_steps,
|
||||
batch_per_device=2,
|
||||
grad_accum_steps=grad_accum_steps,
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
|
||||
assert train_config.accumulation_steps == accumulation_steps
|
||||
assert train_config.grad_accum_steps == grad_accum_steps
|
||||
|
||||
|
||||
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
||||
"""Test training with memory-efficient configurations"""
|
||||
# Test with smaller batch sizes and gradient checkpointing
|
||||
small_batch_configs = [
|
||||
{"batch_size": 1, "accumulation_steps": 8},
|
||||
{"batch_size": 2, "accumulation_steps": 4},
|
||||
{"batch_size": 4, "accumulation_steps": 2},
|
||||
{"batch_per_device": 1, "grad_accum_steps": 8},
|
||||
{"batch_per_device": 2, "grad_accum_steps": 4},
|
||||
{"batch_per_device": 4, "grad_accum_steps": 2},
|
||||
]
|
||||
|
||||
for config in small_batch_configs:
|
||||
|
|
@ -54,8 +54,9 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
|||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
batch_size=config["batch_size"],
|
||||
accumulation_steps=config["accumulation_steps"],
|
||||
batch_per_device=config["batch_per_device"],
|
||||
grad_accum_steps=config["grad_accum_steps"],
|
||||
)
|
||||
|
||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||
assert train_config.grad_accum_steps == config["grad_accum_steps"]
|
||||
assert train_config.batch_per_device == config["batch_per_device"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue