refactor: 训练循环改为两重迭代并统一参数命名
- 训练循环从三重(epoch→batched→batch)改为二重(epoch→batch) - batch_size → batch_per_device, accumulation_steps → grad_accum_steps - scheduler 移入 step block 对齐 optimizer 更新步 - GradientClippingCallback 改用 on_step_begin 避免零梯度裁剪 - 移除 _train_impl 误导性的 -> Checkpoint 标注 - total_steps 修除为向下取整并精简为一行 - warmup_steps 改为 warmup_ratio (默认0.05)
This commit is contained in:
parent
7dea929788
commit
d7a7f570ed
30
README.md
30
README.md
|
|
@ -78,15 +78,27 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
|
|||
#### Train a Model
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 1000 \
|
||||
--n_epoch 1
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=sft \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.99 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_weight_decay=1e-5 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.1 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||
|
|
|
|||
|
|
@ -84,15 +84,27 @@ python scripts/demo/download.py
|
|||
#### 训练模型
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 1000 \
|
||||
--n_epoch 1
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=sft \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.99 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_weight_decay=1e-5 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.1 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
完整参数列表见[参数说明](./params.md)。
|
||||
|
|
|
|||
|
|
@ -30,6 +30,9 @@ classDiagram
|
|||
+int n_shared_experts
|
||||
+int n_activated_experts
|
||||
+str moe_topk_method
|
||||
+Optional[int] kv_lora_rank
|
||||
+Optional[int] qk_nope_head_dim
|
||||
+Optional[int] qk_rope_head_dim
|
||||
+load(config_path) ModelConfig
|
||||
+save(config_path)
|
||||
}
|
||||
|
|
@ -41,8 +44,8 @@ classDiagram
|
|||
+Callable optimizer_fn
|
||||
+Callable scheduler_fn
|
||||
+int n_epoch
|
||||
+int batch_size
|
||||
+int accumulation_steps
|
||||
+int batch_per_device
|
||||
+int grad_accum_steps
|
||||
+float max_grad_norm
|
||||
+int start_epoch
|
||||
+int start_batch
|
||||
|
|
@ -69,7 +72,7 @@ classDiagram
|
|||
class BaseDataset {
|
||||
+int window_size
|
||||
+int stride
|
||||
+BaseStorage storage
|
||||
+Optional[BaseStorage] storage
|
||||
+load(load_path, storage_type, tokenizer)
|
||||
+__getitem__(index)
|
||||
+__len__()
|
||||
|
|
@ -126,8 +129,8 @@ classDiagram
|
|||
}
|
||||
|
||||
class ResumableDistributedSampler {
|
||||
+int start_epoch
|
||||
+int start_iter
|
||||
+int epoch
|
||||
+int iter
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
|
|
@ -155,7 +158,7 @@ classDiagram
|
|||
+Registry _registry
|
||||
+register(model_type) decorator
|
||||
+get_component_class(model_type) Type
|
||||
+from_pretrained(path, disable_random_init) nn.Module
|
||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||
+save_pretrained(save_directory)
|
||||
+to(*args, **kwargs) Self
|
||||
}
|
||||
|
|
@ -167,7 +170,7 @@ classDiagram
|
|||
+ModuleList layers
|
||||
+RMSNorm norm
|
||||
+Linear lm_head
|
||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
|
||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
||||
+load_state_dict(state_dict)
|
||||
+state_dict()
|
||||
}
|
||||
|
|
@ -185,6 +188,7 @@ classDiagram
|
|||
+int n_kv_heads
|
||||
+int head_dim
|
||||
+int n_rep
|
||||
+int layer_id
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Linear q_proj, k_proj, v_proj, o_proj
|
||||
|
|
@ -201,6 +205,7 @@ classDiagram
|
|||
+int qk_nope_head_dim
|
||||
+int qk_rope_head_dim
|
||||
+int n_rep
|
||||
+int layer_id
|
||||
+bool use_gated_attention
|
||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||
+Linear o_proj
|
||||
|
|
@ -215,6 +220,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class DeepSeekMoE {
|
||||
+int dim
|
||||
+int n_routed_experts
|
||||
+int n_shared_experts
|
||||
+int n_activated_experts
|
||||
|
|
@ -236,6 +242,7 @@ classDiagram
|
|||
class RMSNorm {
|
||||
+Parameter weight
|
||||
+float norm_eps
|
||||
+tuple normalized_shape
|
||||
+forward(x) Tensor
|
||||
}
|
||||
|
||||
|
|
@ -299,7 +306,6 @@ classDiagram
|
|||
+TrainConfig train_config
|
||||
+List[TrainCallback] callbacks
|
||||
+train(checkpoint)
|
||||
+_build_context(checkpoint) TrainContext
|
||||
+_get_default_callbacks() List[TrainCallback]
|
||||
}
|
||||
|
||||
|
|
@ -324,7 +330,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class BaseStrategy {
|
||||
+nn.Module model
|
||||
+Union[Callable, nn.Module] model
|
||||
+str device
|
||||
+compute_loss(batch) Tensor
|
||||
}
|
||||
|
|
@ -332,7 +338,7 @@ classDiagram
|
|||
class StrategyFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(model, train_type, device, **kwargs) BaseStrategy
|
||||
+create(train_type, model, device, **kwargs) BaseStrategy
|
||||
}
|
||||
|
||||
class SEQStrategy {
|
||||
|
|
@ -400,7 +406,7 @@ classDiagram
|
|||
|
||||
class GradientClippingCallback {
|
||||
+float max_grad_norm
|
||||
+on_step_end(context)
|
||||
+on_step_begin(context)
|
||||
}
|
||||
|
||||
class CheckpointCallback {
|
||||
|
|
@ -459,10 +465,7 @@ classDiagram
|
|||
+TaskManager _task_mgr
|
||||
+bool _running
|
||||
+Thread _loop_thread
|
||||
+int max_batch_size
|
||||
+int max_seq_len
|
||||
+int max_prompt_len
|
||||
+int page_size
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
+remove_task(task_id)
|
||||
+start()
|
||||
|
|
@ -500,10 +503,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class Storage {
|
||||
+int n_layers
|
||||
+int page_size
|
||||
+int head_dim
|
||||
+int n_kv_heads
|
||||
+Tensor k_cache
|
||||
+Tensor v_cache
|
||||
+write(layer_id, page_table, start_pos, k, v)
|
||||
|
|
@ -675,7 +675,6 @@ classDiagram
|
|||
}
|
||||
|
||||
class AnthropicHandler {
|
||||
+List[str] stop_sequences
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
+on_token(ctx, token, stop_checker) Optional[str]
|
||||
|
|
@ -704,7 +703,7 @@ classDiagram
|
|||
|
||||
namespace parallel {
|
||||
class Functions {
|
||||
+spawn_parallel_fn(fn, nprocs)
|
||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||
+get_current_device() str
|
||||
+get_world_size() int
|
||||
|
|
@ -878,4 +877,4 @@ classDiagram
|
|||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
|
||||
> Document Update Time: 2026-05-15
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
|
|||
|
|
@ -10,14 +10,14 @@
|
|||
| `--data_root_path` | Dataset root directory | required |
|
||||
| `--param_path` | Model parameters or checkpoint path | required |
|
||||
| `--n_epoch` | Total training epochs | 1 |
|
||||
| `--batch_size` | Batch size | 1 |
|
||||
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||
| `--batch_per_device` | Batch size per device | 1 |
|
||||
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||
|
||||
### Learning Rate Scheduling
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--warmup_steps` | Warmup steps | 1000 |
|
||||
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
|
||||
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||
|
||||
|
|
@ -69,90 +69,29 @@
|
|||
### Usage Example
|
||||
|
||||
```bash
|
||||
python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--n_epoch 3 \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 2000 \
|
||||
--max_grad_norm 1.0 \
|
||||
--ckpt_interval 5000 \
|
||||
--ckpt_dir ./checkpoints \
|
||||
--num_workers 4 \
|
||||
--nprocs 1 \
|
||||
--device_type cuda
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=sft \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.99 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_weight_decay=1e-5 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.1 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Generation Parameters
|
||||
|
||||
### GenerationRequest Parameters
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| `messages` | List of message dictionaries (role, content) | required |
|
||||
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
||||
| `top_p` | Nucleus sampling threshold | 1.0 |
|
||||
| `top_k` | Top-k sampling count | 50 |
|
||||
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
|
||||
| `stream` | Whether to stream output | False |
|
||||
|
||||
### Usage Example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.inference import InferenceEngine, GenerationRequest
|
||||
|
||||
# Load model using AutoModel
|
||||
model = AutoModel.from_pretrained("your_model_dir")
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
|
||||
|
||||
# Create engine with separate model and tokenizer
|
||||
engine = InferenceEngine(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# Build request with messages format
|
||||
request = GenerationRequest(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
# Generate (streaming)
|
||||
for token in engine.generate_with_request(request):
|
||||
print(token, end="", flush=True)
|
||||
|
||||
# Or use simple generate interface
|
||||
result = engine.generate(
|
||||
prompt="Hello",
|
||||
stream=False,
|
||||
max_tokens=1024,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
)
|
||||
```
|
||||
|
||||
### Generation Modes
|
||||
|
||||
| Mode | Description |
|
||||
|------|-------------|
|
||||
| `stream=True` | Streaming output, yields token by token |
|
||||
| `stream=False` | Non-streaming output, returns complete result |
|
||||
|
||||
> Document Update Time: 2026-05-15
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi
|
|||
|
||||
## Training Loop
|
||||
|
||||
Nested loop: **epoch** → **step** (accumulation window) → **batch**.
|
||||
Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches.
|
||||
|
||||
```
|
||||
on_train_begin
|
||||
on_epoch_begin
|
||||
for steps in batched(dataloader, accumulation_steps):
|
||||
on_step_begin
|
||||
step_batch_nums = len(steps)
|
||||
for batch in steps:
|
||||
on_batch_begin
|
||||
loss = strategy(batch)
|
||||
(loss / step_batch_nums).backward()
|
||||
iteration += 1
|
||||
on_batch_end
|
||||
on_step_end
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
scheduler.step()
|
||||
for batch in dataloader:
|
||||
on_batch_begin
|
||||
loss = strategy(batch)
|
||||
(loss / grad_accum_steps).backward()
|
||||
iteration += 1
|
||||
on_batch_end
|
||||
|
||||
if iteration % grad_accum_steps == 0:
|
||||
on_step_begin
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
on_step_end
|
||||
scheduler.step()
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
|
@ -91,9 +91,9 @@ on_train_end
|
|||
|
||||
| Hook | Fires | Default callback |
|
||||
|------|-------|-----------------|
|
||||
| `on_step_end` | Every accumulation window | `GradientClippingCallback` |
|
||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||
| `on_train_end` | Training ends | `CheckpointCallback` (final save) |
|
||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||
|
||||
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
||||
|
||||
|
|
@ -162,7 +162,7 @@ Checkpoint(state_dict, epoch, iteration, extra)
|
|||
└── load(save_dir) broadcasts metadata from rank-0
|
||||
```
|
||||
|
||||
Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data.
|
||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||
|
||||
## TrainContextBuilder (Builder Pattern)
|
||||
|
||||
|
|
@ -183,17 +183,29 @@ context = (
|
|||
## Training CLI
|
||||
|
||||
```bash
|
||||
python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/data \
|
||||
--param_path /path/to/model \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 1000 \
|
||||
--n_epoch 1
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=sft \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
--grad_accum_steps=8 \
|
||||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.99 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_weight_decay=1e-5 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=./checkpoint \
|
||||
--random_seed=3407 \
|
||||
--label_smoothing=0.1 \
|
||||
> out.log 2> err.log &
|
||||
```
|
||||
|
||||
Full parameter reference at [params.md](params.md).
|
||||
|
||||
> Document Update Time: 2026-05-15
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
|
|||
|
|
@ -20,8 +20,10 @@ class TrainConfig:
|
|||
default=None, metadata={"help": "Scheduler factory for training."}
|
||||
)
|
||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
|
||||
accumulation_steps: int = field(
|
||||
batch_per_device: int = field(
|
||||
default=4, metadata={"help": "Batch size per device."}
|
||||
)
|
||||
grad_accum_steps: int = field(
|
||||
default=1, metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
|
|
|
|||
|
|
@ -79,8 +79,7 @@ class GradientClippingCallback(TrainCallback):
|
|||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
_ = context
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class TrainContextBuilder:
|
|||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||
|
||||
cfg = self.config
|
||||
sampler_offset = context.iteration * cfg.batch_size
|
||||
sampler_offset = context.iteration * cfg.batch_per_device
|
||||
sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.dataset,
|
||||
start_epoch=context.epoch,
|
||||
|
|
@ -79,7 +79,7 @@ class TrainContextBuilder:
|
|||
)
|
||||
context.dataloader = DataLoader(
|
||||
cfg.dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=cfg.pin_memory,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
from itertools import batched
|
||||
from typing import List, Optional
|
||||
|
||||
from astrai.config import TrainConfig
|
||||
|
|
@ -33,11 +32,6 @@ class Trainer:
|
|||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
]
|
||||
|
||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||
return (
|
||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||
)
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
for callback in self.callbacks:
|
||||
method = getattr(callback, method_name, None)
|
||||
|
|
@ -45,49 +39,47 @@ class Trainer:
|
|||
method(context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
config = self.train_config
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._train_impl,
|
||||
backend=config.backend,
|
||||
world_size=config.nprocs,
|
||||
master_addr=config.master_addr,
|
||||
master_port=config.master_port,
|
||||
device_type=config.device_type,
|
||||
backend=cfg.backend,
|
||||
world_size=cfg.nprocs,
|
||||
master_addr=cfg.master_addr,
|
||||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||
context = self._build_context(checkpoint)
|
||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
||||
try:
|
||||
context.model.train()
|
||||
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
||||
grad_accum_steps = cfg.grad_accum_steps
|
||||
|
||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||
for epoch in range(context.epoch, cfg.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
for steps in batched(context.dataloader, accumulation_steps):
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
for batch in context.dataloader:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / grad_accum_steps
|
||||
stand_loss.backward()
|
||||
context.iteration += 1
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
step_batch_nums = len(steps)
|
||||
for batch in steps:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
context.iteration += 1
|
||||
if context.iteration % grad_accum_steps == 0:
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks("on_step_end", context)
|
||||
|
||||
stand_loss = loss / step_batch_nums
|
||||
stand_loss.backward()
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
self._call_callbacks("on_step_end", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,18 +42,20 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
|
||||
parser.add_argument(
|
||||
"--accumulation_steps",
|
||||
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad_accum_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of iterations between each optimizer step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of warmup steps for LR scheduler.",
|
||||
"--warmup_ratio",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Fraction of total steps used for LR warmup.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||
|
|
@ -177,24 +179,25 @@ def create_scheduler(
|
|||
return SchedulerFactory.create(optimizer, **kwargs)
|
||||
|
||||
|
||||
def ceil_div(a: int, b: int) -> int:
|
||||
return (a + b - 1) // b
|
||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||
return model.module.state_dict()
|
||||
|
||||
|
||||
def compute_total_steps(
|
||||
dataset_len: int,
|
||||
n_epoch: int,
|
||||
batch_size: int,
|
||||
batch_per_device: int,
|
||||
nprocs: int,
|
||||
accumulation_steps: int,
|
||||
grad_accum_steps: int,
|
||||
) -> int:
|
||||
|
||||
def ceil_div(a: int, b: int) -> int:
|
||||
return (a + b - 1) // b
|
||||
|
||||
samples_per_replica = ceil_div(dataset_len, nprocs)
|
||||
batches_per_replica = ceil_div(samples_per_replica, batch_size)
|
||||
return ceil_div(batches_per_replica, accumulation_steps) * n_epoch
|
||||
|
||||
|
||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||
return model.module.state_dict()
|
||||
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
|
||||
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
|
||||
return total_steps
|
||||
|
||||
|
||||
def train(
|
||||
|
|
@ -203,11 +206,11 @@ def train(
|
|||
data_root_path: str,
|
||||
max_lr: float,
|
||||
n_epoch: int,
|
||||
batch_size: int,
|
||||
batch_per_device: int,
|
||||
start_epoch: int,
|
||||
start_batch: int,
|
||||
accumulation_steps: int,
|
||||
warmup_steps: int,
|
||||
grad_accum_steps: int,
|
||||
warmup_ratio: float,
|
||||
ckpt_interval: int,
|
||||
ckpt_dir: str,
|
||||
dpo_beta: float,
|
||||
|
|
@ -277,8 +280,10 @@ def train(
|
|||
)
|
||||
|
||||
total_steps = compute_total_steps(
|
||||
len(dataset), n_epoch, batch_size, nprocs, accumulation_steps
|
||||
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
|
||||
)
|
||||
warmup_steps = int(warmup_ratio * total_steps)
|
||||
|
||||
scheduler_fn = partial(
|
||||
create_scheduler,
|
||||
**{
|
||||
|
|
@ -296,11 +301,11 @@ def train(
|
|||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=ckpt_dir,
|
||||
n_epoch=n_epoch,
|
||||
batch_size=batch_size,
|
||||
batch_per_device=batch_per_device,
|
||||
start_epoch=start_epoch,
|
||||
start_batch=start_batch,
|
||||
ckpt_interval=ckpt_interval,
|
||||
accumulation_steps=accumulation_steps,
|
||||
grad_accum_steps=grad_accum_steps,
|
||||
max_grad_norm=max_grad_norm,
|
||||
random_seed=random_seed,
|
||||
num_workers=num_workers,
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ def create_train_config(
|
|||
device: str,
|
||||
strategy: str = "seq",
|
||||
n_epoch: int = 1,
|
||||
batch_size: int = 2,
|
||||
accumulation_steps: int = 1,
|
||||
batch_per_device: int = 2,
|
||||
grad_accum_steps: int = 1,
|
||||
max_grad_norm: float = 1.0,
|
||||
ckpt_interval: int = 5,
|
||||
random_seed: int = 42,
|
||||
|
|
@ -47,8 +47,8 @@ def create_train_config(
|
|||
device: Device type ("cuda" or "cpu")
|
||||
strategy: Training strategy type (default: "seq")
|
||||
n_epoch: Number of epochs (default: 1)
|
||||
batch_size: Batch size (default: 2)
|
||||
accumulation_steps: Gradient accumulation steps (default: 1)
|
||||
batch_per_device: Batch size per device (default: 2)
|
||||
grad_accum_steps: Gradient accumulation steps (default: 1)
|
||||
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
|
||||
ckpt_interval: Checkpoint save interval in iterations (default: 5)
|
||||
random_seed: Random seed for reproducibility (default: 42)
|
||||
|
|
@ -74,9 +74,9 @@ def create_train_config(
|
|||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=test_dir,
|
||||
n_epoch=n_epoch,
|
||||
batch_size=batch_size,
|
||||
batch_per_device=batch_per_device,
|
||||
ckpt_interval=ckpt_interval,
|
||||
accumulation_steps=accumulation_steps,
|
||||
grad_accum_steps=grad_accum_steps,
|
||||
max_grad_norm=max_grad_norm,
|
||||
random_seed=random_seed,
|
||||
device_type=device,
|
||||
|
|
|
|||
|
|
@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_size=2,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=3,
|
||||
accumulation_steps=1,
|
||||
grad_accum_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"],
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
dataset=early_stopping_dataset,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=2,
|
||||
batch_size=2,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=1,
|
||||
accumulation_steps=2,
|
||||
grad_accum_steps=2,
|
||||
random_seed=np.random.randint(1e4),
|
||||
device_type=base_test_env["device"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
|||
"""Test training with different batch sizes"""
|
||||
batch_sizes = [1, 2, 4, 8]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for batch_per_device in batch_sizes:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
batch_size=batch_size,
|
||||
batch_per_device=batch_per_device,
|
||||
)
|
||||
|
||||
assert train_config.batch_size == batch_size
|
||||
assert train_config.batch_per_device == batch_per_device
|
||||
|
||||
|
||||
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
||||
"""Test training with different gradient accumulation steps"""
|
||||
accumulation_steps_list = [1, 2, 4]
|
||||
grad_accum_steps_list = [1, 2, 4]
|
||||
|
||||
for accumulation_steps in accumulation_steps_list:
|
||||
for grad_accum_steps in grad_accum_steps_list:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
batch_size=2,
|
||||
accumulation_steps=accumulation_steps,
|
||||
batch_per_device=2,
|
||||
grad_accum_steps=grad_accum_steps,
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
|
||||
assert train_config.accumulation_steps == accumulation_steps
|
||||
assert train_config.grad_accum_steps == grad_accum_steps
|
||||
|
||||
|
||||
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
||||
"""Test training with memory-efficient configurations"""
|
||||
# Test with smaller batch sizes and gradient checkpointing
|
||||
small_batch_configs = [
|
||||
{"batch_size": 1, "accumulation_steps": 8},
|
||||
{"batch_size": 2, "accumulation_steps": 4},
|
||||
{"batch_size": 4, "accumulation_steps": 2},
|
||||
{"batch_per_device": 1, "grad_accum_steps": 8},
|
||||
{"batch_per_device": 2, "grad_accum_steps": 4},
|
||||
{"batch_per_device": 4, "grad_accum_steps": 2},
|
||||
]
|
||||
|
||||
for config in small_batch_configs:
|
||||
|
|
@ -54,8 +54,9 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
|||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
batch_size=config["batch_size"],
|
||||
accumulation_steps=config["accumulation_steps"],
|
||||
batch_per_device=config["batch_per_device"],
|
||||
grad_accum_steps=config["grad_accum_steps"],
|
||||
)
|
||||
|
||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||
assert train_config.grad_accum_steps == config["grad_accum_steps"]
|
||||
assert train_config.batch_per_device == config["batch_per_device"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue