From d7a7f570ed0430c73567148afbfe90cd49da21fe Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 16 May 2026 21:27:35 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E8=AE=AD=E7=BB=83=E5=BE=AA?= =?UTF-8?q?=E7=8E=AF=E6=94=B9=E4=B8=BA=E4=B8=A4=E9=87=8D=E8=BF=AD=E4=BB=A3?= =?UTF-8?q?=E5=B9=B6=E7=BB=9F=E4=B8=80=E5=8F=82=E6=95=B0=E5=91=BD=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 训练循环从三重(epoch→batched→batch)改为二重(epoch→batch) - batch_size → batch_per_device, accumulation_steps → grad_accum_steps - scheduler 移入 step block 对齐 optimizer 更新步 - GradientClippingCallback 改用 on_step_begin 避免零梯度裁剪 - 移除 _train_impl 误导性的 -> Checkpoint 标注 - total_steps 修除为向下取整并精简为一行 - warmup_steps 改为 warmup_ratio (默认0.05) --- README.md | 30 +++++--- assets/docs/README-zh-CN.md | 30 +++++--- assets/docs/architecture.md | 39 +++++----- assets/docs/params.md | 111 ++++++--------------------- assets/docs/training.md | 66 +++++++++------- astrai/config/train_config.py | 6 +- astrai/trainer/train_callback.py | 3 +- astrai/trainer/train_context.py | 4 +- astrai/trainer/trainer.py | 60 +++++++-------- scripts/tools/train.py | 49 ++++++------ tests/trainer/conftest.py | 12 +-- tests/trainer/test_callbacks.py | 4 +- tests/trainer/test_early_stopping.py | 4 +- tests/trainer/test_trainer.py | 29 +++---- 14 files changed, 210 insertions(+), 237 deletions(-) diff --git a/README.md b/README.md index a268082..8d6e1bf 100644 --- a/README.md +++ b/README.md @@ -78,15 +78,27 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i #### Train a Model ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \ - --train_type seq \ - --data_root_path /path/to/dataset \ - --param_path /path/to/model \ - --batch_size 4 \ - --accumulation_steps 8 \ - --max_lr 3e-4 \ - --warmup_steps 1000 \ - --n_epoch 1 +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +nohup python scripts/tools/train.py \ + --nprocs=4 \ + --train_type=sft \ + --data_root_path=/path/to/dataset \ + --param_path=/path/to/model \ + --batch_per_device=4 \ + --grad_accum_steps=8 \ + --warmup_ratio=0.05 \ + --max_lr=1e-4 \ + --max_grad_norm=1.0 \ + --adamw_beta1=0.99 \ + --adamw_beta2=0.95 \ + --adamw_weight_decay=1e-5 \ + --window_size=2048 \ + --ckpt_interval=10000 \ + --ckpt_dir=./checkpoint \ + --random_seed=3407 \ + --label_smoothing=0.1 \ + > out.log 2> err.log & ``` Full reference at [Parameter Guide](assets/docs/params.md). diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index 13cf22b..af6ca8b 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -84,15 +84,27 @@ python scripts/demo/download.py #### 训练模型 ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \ - --train_type seq \ - --data_root_path /path/to/dataset \ - --param_path /path/to/model \ - --batch_size 4 \ - --accumulation_steps 8 \ - --max_lr 3e-4 \ - --warmup_steps 1000 \ - --n_epoch 1 +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +nohup python scripts/tools/train.py \ + --nprocs=4 \ + --train_type=sft \ + --data_root_path=/path/to/dataset \ + --param_path=/path/to/model \ + --batch_per_device=4 \ + --grad_accum_steps=8 \ + --warmup_ratio=0.05 \ + --max_lr=1e-4 \ + --max_grad_norm=1.0 \ + --adamw_beta1=0.99 \ + --adamw_beta2=0.95 \ + --adamw_weight_decay=1e-5 \ + --window_size=2048 \ + --ckpt_interval=10000 \ + --ckpt_dir=./checkpoint \ + --random_seed=3407 \ + --label_smoothing=0.1 \ + > out.log 2> err.log & ``` 完整参数列表见[参数说明](./params.md)。 diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index 805f7c4..b1208ea 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -30,6 +30,9 @@ classDiagram +int n_shared_experts +int n_activated_experts +str moe_topk_method + +Optional[int] kv_lora_rank + +Optional[int] qk_nope_head_dim + +Optional[int] qk_rope_head_dim +load(config_path) ModelConfig +save(config_path) } @@ -41,8 +44,8 @@ classDiagram +Callable optimizer_fn +Callable scheduler_fn +int n_epoch - +int batch_size - +int accumulation_steps + +int batch_per_device + +int grad_accum_steps +float max_grad_norm +int start_epoch +int start_batch @@ -69,7 +72,7 @@ classDiagram class BaseDataset { +int window_size +int stride - +BaseStorage storage + +Optional[BaseStorage] storage +load(load_path, storage_type, tokenizer) +__getitem__(index) +__len__() @@ -126,8 +129,8 @@ classDiagram } class ResumableDistributedSampler { - +int start_epoch - +int start_iter + +int epoch + +int iter } class DatasetFactory { @@ -155,7 +158,7 @@ classDiagram +Registry _registry +register(model_type) decorator +get_component_class(model_type) Type - +from_pretrained(path, disable_random_init) nn.Module + +from_pretrained(path, disable_random_init, strict) nn.Module +save_pretrained(save_directory) +to(*args, **kwargs) Self } @@ -167,7 +170,7 @@ classDiagram +ModuleList layers +RMSNorm norm +Linear lm_head - +forward(input_ids, input_mask, paged_cache, position_ids) Dict + +forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor] +load_state_dict(state_dict) +state_dict() } @@ -185,6 +188,7 @@ classDiagram +int n_kv_heads +int head_dim +int n_rep + +int layer_id +bool use_qk_norm +bool use_gated_attention +Linear q_proj, k_proj, v_proj, o_proj @@ -201,6 +205,7 @@ classDiagram +int qk_nope_head_dim +int qk_rope_head_dim +int n_rep + +int layer_id +bool use_gated_attention +Linear q_proj, kv_a_proj, kv_b_proj +Linear o_proj @@ -215,6 +220,7 @@ classDiagram } class DeepSeekMoE { + +int dim +int n_routed_experts +int n_shared_experts +int n_activated_experts @@ -236,6 +242,7 @@ classDiagram class RMSNorm { +Parameter weight +float norm_eps + +tuple normalized_shape +forward(x) Tensor } @@ -299,7 +306,6 @@ classDiagram +TrainConfig train_config +List[TrainCallback] callbacks +train(checkpoint) - +_build_context(checkpoint) TrainContext +_get_default_callbacks() List[TrainCallback] } @@ -324,7 +330,7 @@ classDiagram } class BaseStrategy { - +nn.Module model + +Union[Callable, nn.Module] model +str device +compute_loss(batch) Tensor } @@ -332,7 +338,7 @@ classDiagram class StrategyFactory { +Registry _registry +register(name) decorator - +create(model, train_type, device, **kwargs) BaseStrategy + +create(train_type, model, device, **kwargs) BaseStrategy } class SEQStrategy { @@ -400,7 +406,7 @@ classDiagram class GradientClippingCallback { +float max_grad_norm - +on_step_end(context) + +on_step_begin(context) } class CheckpointCallback { @@ -459,10 +465,7 @@ classDiagram +TaskManager _task_mgr +bool _running +Thread _loop_thread - +int max_batch_size +int max_seq_len - +int max_prompt_len - +int page_size +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +remove_task(task_id) +start() @@ -500,10 +503,7 @@ classDiagram } class Storage { - +int n_layers +int page_size - +int head_dim - +int n_kv_heads +Tensor k_cache +Tensor v_cache +write(layer_id, page_table, start_pos, k, v) @@ -675,7 +675,6 @@ classDiagram } class AnthropicHandler { - +List[str] stop_sequences +build_prompt() str +create_response_id() str +on_token(ctx, token, stop_checker) Optional[str] @@ -704,7 +703,7 @@ classDiagram namespace parallel { class Functions { - +spawn_parallel_fn(fn, nprocs) + +spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs) +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) +get_current_device() str +get_world_size() int @@ -878,4 +877,4 @@ classDiagram 8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops -> Document Update Time: 2026-05-15 +> Document Update Time: 2026-05-16 diff --git a/assets/docs/params.md b/assets/docs/params.md index 2b09de5..2e1c54b 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -10,14 +10,14 @@ | `--data_root_path` | Dataset root directory | required | | `--param_path` | Model parameters or checkpoint path | required | | `--n_epoch` | Total training epochs | 1 | -| `--batch_size` | Batch size | 1 | -| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 | +| `--batch_per_device` | Batch size per device | 1 | +| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 | ### Learning Rate Scheduling | Parameter | Description | Default | |-----------|-------------|---------| -| `--warmup_steps` | Warmup steps | 1000 | +| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 | | `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 | | `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 | @@ -69,90 +69,29 @@ ### Usage Example ```bash -python scripts/tools/train.py \ - --train_type seq \ - --data_root_path /path/to/dataset \ - --param_path /path/to/model \ - --n_epoch 3 \ - --batch_size 4 \ - --accumulation_steps 8 \ - --max_lr 3e-4 \ - --warmup_steps 2000 \ - --max_grad_norm 1.0 \ - --ckpt_interval 5000 \ - --ckpt_dir ./checkpoints \ - --num_workers 4 \ - --nprocs 1 \ - --device_type cuda +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +nohup python scripts/tools/train.py \ + --nprocs=4 \ + --train_type=sft \ + --data_root_path=/path/to/dataset \ + --param_path=/path/to/model \ + --batch_per_device=4 \ + --grad_accum_steps=8 \ + --warmup_ratio=0.05 \ + --max_lr=1e-4 \ + --max_grad_norm=1.0 \ + --adamw_beta1=0.99 \ + --adamw_beta2=0.95 \ + --adamw_weight_decay=1e-5 \ + --window_size=2048 \ + --ckpt_interval=10000 \ + --ckpt_dir=./checkpoint \ + --random_seed=3407 \ + --label_smoothing=0.1 \ + > out.log 2> err.log & ``` --- -## Generation Parameters - -### GenerationRequest Parameters - -| Parameter | Description | Default Value | -|-----------|-------------|---------------| -| `messages` | List of message dictionaries (role, content) | required | -| `temperature` | Sampling temperature (higher = more random) | 1.0 | -| `top_p` | Nucleus sampling threshold | 1.0 | -| `top_k` | Top-k sampling count | 50 | -| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) | -| `stream` | Whether to stream output | False | - -### Usage Example - -```python -import torch -from astrai.model import AutoModel -from astrai.tokenize import AutoTokenizer -from astrai.inference import InferenceEngine, GenerationRequest - -# Load model using AutoModel -model = AutoModel.from_pretrained("your_model_dir") - -# Load tokenizer -tokenizer = AutoTokenizer.from_pretrained("your_model_dir") - -# Create engine with separate model and tokenizer -engine = InferenceEngine( - model=model, - tokenizer=tokenizer, -) - -# Build request with messages format -request = GenerationRequest( - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}, - ], - temperature=0.8, - top_p=0.95, - top_k=50, - max_tokens=None, -) - -# Generate (streaming) -for token in engine.generate_with_request(request): - print(token, end="", flush=True) - -# Or use simple generate interface -result = engine.generate( - prompt="Hello", - stream=False, - max_tokens=1024, - temperature=0.8, - top_p=0.95, - top_k=50, -) -``` - -### Generation Modes - -| Mode | Description | -|------|-------------| -| `stream=True` | Streaming output, yields token by token | -| `stream=False` | Non-streaming output, returns complete result | - -> Document Update Time: 2026-05-15 \ No newline at end of file +> Document Update Time: 2026-05-16 \ No newline at end of file diff --git a/assets/docs/training.md b/assets/docs/training.md index d17736d..a7aa1b0 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi ## Training Loop -Nested loop: **epoch** → **step** (accumulation window) → **batch**. +Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches. ``` on_train_begin on_epoch_begin - for steps in batched(dataloader, accumulation_steps): - on_step_begin - step_batch_nums = len(steps) - for batch in steps: - on_batch_begin - loss = strategy(batch) - (loss / step_batch_nums).backward() - iteration += 1 - on_batch_end - on_step_end - optimizer.step() - optimizer.zero_grad() - scheduler.step() + for batch in dataloader: + on_batch_begin + loss = strategy(batch) + (loss / grad_accum_steps).backward() + iteration += 1 + on_batch_end + + if iteration % grad_accum_steps == 0: + on_step_begin + optimizer.step() + optimizer.zero_grad() + on_step_end + scheduler.step() on_epoch_end on_train_end ``` @@ -91,9 +91,9 @@ on_train_end | Hook | Fires | Default callback | |------|-------|-----------------| -| `on_step_end` | Every accumulation window | `GradientClippingCallback` | +| `on_step_begin` | Every accumulation window | `GradientClippingCallback` | | `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | -| `on_train_end` | Training ends | `CheckpointCallback` (final save) | +| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) | Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`. @@ -162,7 +162,7 @@ Checkpoint(state_dict, epoch, iteration, extra) └── load(save_dir) broadcasts metadata from rank-0 ``` -Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data. +Optimizer/scheduler state persisted by default via `Checkpoint.extra`. ## TrainContextBuilder (Builder Pattern) @@ -183,17 +183,29 @@ context = ( ## Training CLI ```bash -python scripts/tools/train.py \ - --train_type seq \ - --data_root_path /path/to/data \ - --param_path /path/to/model \ - --batch_size 4 \ - --accumulation_steps 8 \ - --max_lr 3e-4 \ - --warmup_steps 1000 \ - --n_epoch 1 +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +nohup python scripts/tools/train.py \ + --nprocs=4 \ + --train_type=sft \ + --data_root_path=/path/to/dataset \ + --param_path=/path/to/model \ + --batch_per_device=4 \ + --grad_accum_steps=8 \ + --warmup_ratio=0.05 \ + --max_lr=1e-4 \ + --max_grad_norm=1.0 \ + --adamw_beta1=0.99 \ + --adamw_beta2=0.95 \ + --adamw_weight_decay=1e-5 \ + --window_size=2048 \ + --ckpt_interval=10000 \ + --ckpt_dir=./checkpoint \ + --random_seed=3407 \ + --label_smoothing=0.1 \ + > out.log 2> err.log & ``` Full parameter reference at [params.md](params.md). -> Document Update Time: 2026-05-15 +> Document Update Time: 2026-05-16 diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index a41a23a..fdfdb9b 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -20,8 +20,10 @@ class TrainConfig: default=None, metadata={"help": "Scheduler factory for training."} ) n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."}) - batch_size: int = field(default=4, metadata={"help": "Batch size for training."}) - accumulation_steps: int = field( + batch_per_device: int = field( + default=4, metadata={"help": "Batch size per device."} + ) + grad_accum_steps: int = field( default=1, metadata={"help": "Number of iterations between steps."} ) max_grad_norm: float = field( diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 07ab4eb..0fcef47 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -79,8 +79,7 @@ class GradientClippingCallback(TrainCallback): def __init__(self, max_grad_norm: float): self.max_grad_norm = max_grad_norm - def on_step_end(self, context: TrainContext): - _ = context + def on_step_begin(self, context: TrainContext): clip_grad_norm_(context.model.parameters(), self.max_grad_norm) diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index a81d23a..0350144 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -70,7 +70,7 @@ class TrainContextBuilder: context.scheduler = self.config.scheduler_fn(context.optimizer) cfg = self.config - sampler_offset = context.iteration * cfg.batch_size + sampler_offset = context.iteration * cfg.batch_per_device sampler = ResumableDistributedSampler( data_source=cfg.dataset, start_epoch=context.epoch, @@ -79,7 +79,7 @@ class TrainContextBuilder: ) context.dataloader = DataLoader( cfg.dataset, - batch_size=cfg.batch_size, + batch_size=cfg.batch_per_device, sampler=sampler, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index d1f5c57..2fa385b 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -1,5 +1,4 @@ import logging -from itertools import batched from typing import List, Optional from astrai.config import TrainConfig @@ -33,11 +32,6 @@ class Trainer: CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), ] - def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: - return ( - TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build() - ) - def _call_callbacks(self, method_name: str, context: TrainContext): for callback in self.callbacks: method = getattr(callback, method_name, None) @@ -45,49 +39,47 @@ class Trainer: method(context) def train(self, checkpoint: Optional[Checkpoint] = None): - config = self.train_config + cfg = self.train_config spawn_parallel_fn( self._train_impl, - backend=config.backend, - world_size=config.nprocs, - master_addr=config.master_addr, - master_port=config.master_port, - device_type=config.device_type, + backend=cfg.backend, + world_size=cfg.nprocs, + master_addr=cfg.master_addr, + master_port=cfg.master_port, + device_type=cfg.device_type, checkpoint=checkpoint, ) - def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint: - context = self._build_context(checkpoint) + def _train_impl(self, checkpoint: Optional[Checkpoint] = None): + cfg = self.train_config + context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build() self._call_callbacks("on_train_begin", context) try: context.model.train() - accumulation_steps = max(self.train_config.accumulation_steps, 1) + grad_accum_steps = cfg.grad_accum_steps - for epoch in range(context.epoch, self.train_config.n_epoch): + for epoch in range(context.epoch, cfg.n_epoch): context.epoch = epoch self._call_callbacks("on_epoch_begin", context) - for steps in batched(context.dataloader, accumulation_steps): - self._call_callbacks("on_step_begin", context) + for batch in context.dataloader: + self._call_callbacks("on_batch_begin", context) + loss = context.strategy(batch) + context.loss = loss.item() + stand_loss = loss / grad_accum_steps + stand_loss.backward() + context.iteration += 1 + self._call_callbacks("on_batch_end", context) - step_batch_nums = len(steps) - for batch in steps: - self._call_callbacks("on_batch_begin", context) - loss = context.strategy(batch) - context.loss = loss.item() - context.iteration += 1 + if context.iteration % grad_accum_steps == 0: + self._call_callbacks("on_step_begin", context) + context.optimizer.step() + context.optimizer.zero_grad() + self._call_callbacks("on_step_end", context) - stand_loss = loss / step_batch_nums - stand_loss.backward() - self._call_callbacks("on_batch_end", context) - - self._call_callbacks("on_step_end", context) - context.optimizer.step() - context.optimizer.zero_grad() - - if context.scheduler: - context.scheduler.step() + if context.scheduler: + context.scheduler.step() self._call_callbacks("on_epoch_end", context) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index fe6c6cb..795aa0e 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -42,18 +42,20 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--n_epoch", type=int, default=1, help="Number of epochs to train." ) - parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.") parser.add_argument( - "--accumulation_steps", + "--batch_per_device", type=int, default=1, help="Batch size per GPU." + ) + parser.add_argument( + "--grad_accum_steps", type=int, default=1, help="Number of iterations between each optimizer step.", ) parser.add_argument( - "--warmup_steps", - type=int, - default=1000, - help="Number of warmup steps for LR scheduler.", + "--warmup_ratio", + type=float, + default=0.05, + help="Fraction of total steps used for LR warmup.", ) parser.add_argument( "--max_lr", type=float, default=3e-4, help="Max learning rate for training." @@ -177,24 +179,25 @@ def create_scheduler( return SchedulerFactory.create(optimizer, **kwargs) -def ceil_div(a: int, b: int) -> int: - return (a + b - 1) // b +def prepare_checkpoint(model: nn.Module) -> dict: + return model.module.state_dict() def compute_total_steps( dataset_len: int, n_epoch: int, - batch_size: int, + batch_per_device: int, nprocs: int, - accumulation_steps: int, + grad_accum_steps: int, ) -> int: + + def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + samples_per_replica = ceil_div(dataset_len, nprocs) - batches_per_replica = ceil_div(samples_per_replica, batch_size) - return ceil_div(batches_per_replica, accumulation_steps) * n_epoch - - -def prepare_checkpoint(model: nn.Module) -> dict: - return model.module.state_dict() + batches_per_replica = ceil_div(samples_per_replica, batch_per_device) + total_steps = (batches_per_replica // grad_accum_steps) * n_epoch + return total_steps def train( @@ -203,11 +206,11 @@ def train( data_root_path: str, max_lr: float, n_epoch: int, - batch_size: int, + batch_per_device: int, start_epoch: int, start_batch: int, - accumulation_steps: int, - warmup_steps: int, + grad_accum_steps: int, + warmup_ratio: float, ckpt_interval: int, ckpt_dir: str, dpo_beta: float, @@ -277,8 +280,10 @@ def train( ) total_steps = compute_total_steps( - len(dataset), n_epoch, batch_size, nprocs, accumulation_steps + len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps ) + warmup_steps = int(warmup_ratio * total_steps) + scheduler_fn = partial( create_scheduler, **{ @@ -296,11 +301,11 @@ def train( scheduler_fn=scheduler_fn, ckpt_dir=ckpt_dir, n_epoch=n_epoch, - batch_size=batch_size, + batch_per_device=batch_per_device, start_epoch=start_epoch, start_batch=start_batch, ckpt_interval=ckpt_interval, - accumulation_steps=accumulation_steps, + grad_accum_steps=grad_accum_steps, max_grad_norm=max_grad_norm, random_seed=random_seed, num_workers=num_workers, diff --git a/tests/trainer/conftest.py b/tests/trainer/conftest.py index 4efa745..265ea5f 100644 --- a/tests/trainer/conftest.py +++ b/tests/trainer/conftest.py @@ -31,8 +31,8 @@ def create_train_config( device: str, strategy: str = "seq", n_epoch: int = 1, - batch_size: int = 2, - accumulation_steps: int = 1, + batch_per_device: int = 2, + grad_accum_steps: int = 1, max_grad_norm: float = 1.0, ckpt_interval: int = 5, random_seed: int = 42, @@ -47,8 +47,8 @@ def create_train_config( device: Device type ("cuda" or "cpu") strategy: Training strategy type (default: "seq") n_epoch: Number of epochs (default: 1) - batch_size: Batch size (default: 2) - accumulation_steps: Gradient accumulation steps (default: 1) + batch_per_device: Batch size per device (default: 2) + grad_accum_steps: Gradient accumulation steps (default: 1) max_grad_norm: Maximum gradient norm for clipping (default: 1.0) ckpt_interval: Checkpoint save interval in iterations (default: 5) random_seed: Random seed for reproducibility (default: 42) @@ -74,9 +74,9 @@ def create_train_config( scheduler_fn=scheduler_fn, ckpt_dir=test_dir, n_epoch=n_epoch, - batch_size=batch_size, + batch_per_device=batch_per_device, ckpt_interval=ckpt_interval, - accumulation_steps=accumulation_steps, + grad_accum_steps=grad_accum_steps, max_grad_norm=max_grad_norm, random_seed=random_seed, device_type=device, diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 0238eab..f7ae8ad 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset): scheduler_fn=scheduler_fn, ckpt_dir=base_test_env["test_dir"], n_epoch=1, - batch_size=2, + batch_per_device=2, ckpt_interval=3, - accumulation_steps=1, + grad_accum_steps=1, max_grad_norm=1.0, random_seed=42, device_type=base_test_env["device"], diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index c2d84c5..83e431d 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): dataset=early_stopping_dataset, ckpt_dir=base_test_env["test_dir"], n_epoch=2, - batch_size=2, + batch_per_device=2, ckpt_interval=1, - accumulation_steps=2, + grad_accum_steps=2, random_seed=np.random.randint(1e4), device_type=base_test_env["device"], ) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index bb6520d..f51cb40 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto """Test training with different batch sizes""" batch_sizes = [1, 2, 4, 8] - for batch_size in batch_sizes: + for batch_per_device in batch_sizes: train_config = train_config_factory( model=base_test_env["model"], dataset=random_dataset, test_dir=base_test_env["test_dir"], device=base_test_env["device"], - batch_size=batch_size, + batch_per_device=batch_per_device, ) - assert train_config.batch_size == batch_size + assert train_config.batch_per_device == batch_per_device def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory): """Test training with different gradient accumulation steps""" - accumulation_steps_list = [1, 2, 4] + grad_accum_steps_list = [1, 2, 4] - for accumulation_steps in accumulation_steps_list: + for grad_accum_steps in grad_accum_steps_list: train_config = train_config_factory( model=base_test_env["model"], dataset=random_dataset, test_dir=base_test_env["test_dir"], device=base_test_env["device"], - batch_size=2, - accumulation_steps=accumulation_steps, + batch_per_device=2, + grad_accum_steps=grad_accum_steps, ) trainer = Trainer(train_config) trainer.train() - assert train_config.accumulation_steps == accumulation_steps + assert train_config.grad_accum_steps == grad_accum_steps def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory): """Test training with memory-efficient configurations""" # Test with smaller batch sizes and gradient checkpointing small_batch_configs = [ - {"batch_size": 1, "accumulation_steps": 8}, - {"batch_size": 2, "accumulation_steps": 4}, - {"batch_size": 4, "accumulation_steps": 2}, + {"batch_per_device": 1, "grad_accum_steps": 8}, + {"batch_per_device": 2, "grad_accum_steps": 4}, + {"batch_per_device": 4, "grad_accum_steps": 2}, ] for config in small_batch_configs: @@ -54,8 +54,9 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f dataset=random_dataset, test_dir=base_test_env["test_dir"], device=base_test_env["device"], - batch_size=config["batch_size"], - accumulation_steps=config["accumulation_steps"], + batch_per_device=config["batch_per_device"], + grad_accum_steps=config["grad_accum_steps"], ) - assert train_config.accumulation_steps == config["accumulation_steps"] + assert train_config.grad_accum_steps == config["grad_accum_steps"] + assert train_config.batch_per_device == config["batch_per_device"]