refactor: 训练循环改为两重迭代并统一参数命名

- 训练循环从三重(epoch→batched→batch)改为二重(epoch→batch)
- batch_size → batch_per_device, accumulation_steps → grad_accum_steps
- scheduler 移入 step block 对齐 optimizer 更新步
- GradientClippingCallback 改用 on_step_begin 避免零梯度裁剪
- 移除 _train_impl 误导性的 -> Checkpoint 标注
- total_steps 修除为向下取整并精简为一行
- warmup_steps 改为 warmup_ratio (默认0.05)
This commit is contained in:
ViperEkura 2026-05-16 21:27:35 +08:00
parent 7dea929788
commit d7a7f570ed
14 changed files with 210 additions and 237 deletions

View File

@ -78,15 +78,27 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
#### Train a Model
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=sft \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.99 \
--adamw_beta2=0.95 \
--adamw_weight_decay=1e-5 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.1 \
> out.log 2> err.log &
```
Full reference at [Parameter Guide](assets/docs/params.md).

View File

@ -84,15 +84,27 @@ python scripts/demo/download.py
#### 训练模型
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=sft \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.99 \
--adamw_beta2=0.95 \
--adamw_weight_decay=1e-5 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.1 \
> out.log 2> err.log &
```
完整参数列表见[参数说明](./params.md)。

View File

@ -30,6 +30,9 @@ classDiagram
+int n_shared_experts
+int n_activated_experts
+str moe_topk_method
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+load(config_path) ModelConfig
+save(config_path)
}
@ -41,8 +44,8 @@ classDiagram
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_size
+int accumulation_steps
+int batch_per_device
+int grad_accum_steps
+float max_grad_norm
+int start_epoch
+int start_batch
@ -69,7 +72,7 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+BaseStorage storage
+Optional[BaseStorage] storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
@ -126,8 +129,8 @@ classDiagram
}
class ResumableDistributedSampler {
+int start_epoch
+int start_iter
+int epoch
+int iter
}
class DatasetFactory {
@ -155,7 +158,7 @@ classDiagram
+Registry _registry
+register(model_type) decorator
+get_component_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module
+from_pretrained(path, disable_random_init, strict) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
}
@ -167,7 +170,7 @@ classDiagram
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+load_state_dict(state_dict)
+state_dict()
}
@ -185,6 +188,7 @@ classDiagram
+int n_kv_heads
+int head_dim
+int n_rep
+int layer_id
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, k_proj, v_proj, o_proj
@ -201,6 +205,7 @@ classDiagram
+int qk_nope_head_dim
+int qk_rope_head_dim
+int n_rep
+int layer_id
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
@ -215,6 +220,7 @@ classDiagram
}
class DeepSeekMoE {
+int dim
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
@ -236,6 +242,7 @@ classDiagram
class RMSNorm {
+Parameter weight
+float norm_eps
+tuple normalized_shape
+forward(x) Tensor
}
@ -299,7 +306,6 @@ classDiagram
+TrainConfig train_config
+List[TrainCallback] callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List[TrainCallback]
}
@ -324,7 +330,7 @@ classDiagram
}
class BaseStrategy {
+nn.Module model
+Union[Callable, nn.Module] model
+str device
+compute_loss(batch) Tensor
}
@ -332,7 +338,7 @@ classDiagram
class StrategyFactory {
+Registry _registry
+register(name) decorator
+create(model, train_type, device, **kwargs) BaseStrategy
+create(train_type, model, device, **kwargs) BaseStrategy
}
class SEQStrategy {
@ -400,7 +406,7 @@ classDiagram
class GradientClippingCallback {
+float max_grad_norm
+on_step_end(context)
+on_step_begin(context)
}
class CheckpointCallback {
@ -459,10 +465,7 @@ classDiagram
+TaskManager _task_mgr
+bool _running
+Thread _loop_thread
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
@ -500,10 +503,7 @@ classDiagram
}
class Storage {
+int n_layers
+int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache
+Tensor v_cache
+write(layer_id, page_table, start_pos, k, v)
@ -675,7 +675,6 @@ classDiagram
}
class AnthropicHandler {
+List[str] stop_sequences
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
@ -704,7 +703,7 @@ classDiagram
namespace parallel {
class Functions {
+spawn_parallel_fn(fn, nprocs)
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
@ -878,4 +877,4 @@ classDiagram
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
> Document Update Time: 2026-05-15
> Document Update Time: 2026-05-16

View File

@ -10,14 +10,14 @@
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
| `--batch_per_device` | Batch size per device | 1 |
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--warmup_steps` | Warmup steps | 1000 |
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
@ -69,90 +69,29 @@
### Usage Example
```bash
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=sft \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.99 \
--adamw_beta2=0.95 \
--adamw_weight_decay=1e-5 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.1 \
> out.log 2> err.log &
```
---
## Generation Parameters
### GenerationRequest Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `messages` | List of message dictionaries (role, content) | required |
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
| `stream` | Whether to stream output | False |
### Usage Example
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
)
# Build request with messages format
request = GenerationRequest(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_tokens=None,
)
# Generate (streaming)
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
# Or use simple generate interface
result = engine.generate(
prompt="Hello",
stream=False,
max_tokens=1024,
temperature=0.8,
top_p=0.95,
top_k=50,
)
```
### Generation Modes
| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-05-15
> Document Update Time: 2026-05-16

View File

@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi
## Training Loop
Nested loop: **epoch****step** (accumulation window) → **batch**.
Two-level loop: **epoch****batch**. Optimizer step fires every `grad_accum_steps` batches.
```
on_train_begin
on_epoch_begin
for steps in batched(dataloader, accumulation_steps):
on_step_begin
step_batch_nums = len(steps)
for batch in steps:
on_batch_begin
loss = strategy(batch)
(loss / step_batch_nums).backward()
iteration += 1
on_batch_end
on_step_end
optimizer.step()
optimizer.zero_grad()
scheduler.step()
for batch in dataloader:
on_batch_begin
loss = strategy(batch)
(loss / grad_accum_steps).backward()
iteration += 1
on_batch_end
if iteration % grad_accum_steps == 0:
on_step_begin
optimizer.step()
optimizer.zero_grad()
on_step_end
scheduler.step()
on_epoch_end
on_train_end
```
@ -91,9 +91,9 @@ on_train_end
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_step_end` | Every accumulation window | `GradientClippingCallback` |
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_train_end` | Training ends | `CheckpointCallback` (final save) |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
@ -162,7 +162,7 @@ Checkpoint(state_dict, epoch, iteration, extra)
└── load(save_dir) broadcasts metadata from rank-0
```
Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data.
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
## TrainContextBuilder (Builder Pattern)
@ -183,17 +183,29 @@ context = (
## Training CLI
```bash
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/data \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=sft \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.99 \
--adamw_beta2=0.95 \
--adamw_weight_decay=1e-5 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.1 \
> out.log 2> err.log &
```
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-15
> Document Update Time: 2026-05-16

View File

@ -20,8 +20,10 @@ class TrainConfig:
default=None, metadata={"help": "Scheduler factory for training."}
)
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
accumulation_steps: int = field(
batch_per_device: int = field(
default=4, metadata={"help": "Batch size per device."}
)
grad_accum_steps: int = field(
default=1, metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(

View File

@ -79,8 +79,7 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_step_end(self, context: TrainContext):
_ = context
def on_step_begin(self, context: TrainContext):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)

View File

@ -70,7 +70,7 @@ class TrainContextBuilder:
context.scheduler = self.config.scheduler_fn(context.optimizer)
cfg = self.config
sampler_offset = context.iteration * cfg.batch_size
sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
start_epoch=context.epoch,
@ -79,7 +79,7 @@ class TrainContextBuilder:
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
batch_size=cfg.batch_per_device,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,

View File

@ -1,5 +1,4 @@
import logging
from itertools import batched
from typing import List, Optional
from astrai.config import TrainConfig
@ -33,11 +32,6 @@ class Trainer:
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks:
method = getattr(callback, method_name, None)
@ -45,49 +39,47 @@ class Trainer:
method(context)
def train(self, checkpoint: Optional[Checkpoint] = None):
config = self.train_config
cfg = self.train_config
spawn_parallel_fn(
self._train_impl,
backend=config.backend,
world_size=config.nprocs,
master_addr=config.master_addr,
master_port=config.master_port,
device_type=config.device_type,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
checkpoint=checkpoint,
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
self._call_callbacks("on_train_begin", context)
try:
context.model.train()
accumulation_steps = max(self.train_config.accumulation_steps, 1)
grad_accum_steps = cfg.grad_accum_steps
for epoch in range(context.epoch, self.train_config.n_epoch):
for epoch in range(context.epoch, cfg.n_epoch):
context.epoch = epoch
self._call_callbacks("on_epoch_begin", context)
for steps in batched(context.dataloader, accumulation_steps):
self._call_callbacks("on_step_begin", context)
for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / grad_accum_steps
stand_loss.backward()
context.iteration += 1
self._call_callbacks("on_batch_end", context)
step_batch_nums = len(steps)
for batch in steps:
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
context.iteration += 1
if context.iteration % grad_accum_steps == 0:
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
stand_loss = loss / step_batch_nums
stand_loss.backward()
self._call_callbacks("on_batch_end", context)
self._call_callbacks("on_step_end", context)
context.optimizer.step()
context.optimizer.zero_grad()
if context.scheduler:
context.scheduler.step()
if context.scheduler:
context.scheduler.step()
self._call_callbacks("on_epoch_end", context)

View File

@ -42,18 +42,20 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train."
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
parser.add_argument(
"--accumulation_steps",
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
)
parser.add_argument(
"--grad_accum_steps",
type=int,
default=1,
help="Number of iterations between each optimizer step.",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=1000,
help="Number of warmup steps for LR scheduler.",
"--warmup_ratio",
type=float,
default=0.05,
help="Fraction of total steps used for LR warmup.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
@ -177,24 +179,25 @@ def create_scheduler(
return SchedulerFactory.create(optimizer, **kwargs)
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict()
def compute_total_steps(
dataset_len: int,
n_epoch: int,
batch_size: int,
batch_per_device: int,
nprocs: int,
accumulation_steps: int,
grad_accum_steps: int,
) -> int:
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
samples_per_replica = ceil_div(dataset_len, nprocs)
batches_per_replica = ceil_div(samples_per_replica, batch_size)
return ceil_div(batches_per_replica, accumulation_steps) * n_epoch
def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict()
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
return total_steps
def train(
@ -203,11 +206,11 @@ def train(
data_root_path: str,
max_lr: float,
n_epoch: int,
batch_size: int,
batch_per_device: int,
start_epoch: int,
start_batch: int,
accumulation_steps: int,
warmup_steps: int,
grad_accum_steps: int,
warmup_ratio: float,
ckpt_interval: int,
ckpt_dir: str,
dpo_beta: float,
@ -277,8 +280,10 @@ def train(
)
total_steps = compute_total_steps(
len(dataset), n_epoch, batch_size, nprocs, accumulation_steps
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
)
warmup_steps = int(warmup_ratio * total_steps)
scheduler_fn = partial(
create_scheduler,
**{
@ -296,11 +301,11 @@ def train(
scheduler_fn=scheduler_fn,
ckpt_dir=ckpt_dir,
n_epoch=n_epoch,
batch_size=batch_size,
batch_per_device=batch_per_device,
start_epoch=start_epoch,
start_batch=start_batch,
ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps,
grad_accum_steps=grad_accum_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
num_workers=num_workers,

View File

@ -31,8 +31,8 @@ def create_train_config(
device: str,
strategy: str = "seq",
n_epoch: int = 1,
batch_size: int = 2,
accumulation_steps: int = 1,
batch_per_device: int = 2,
grad_accum_steps: int = 1,
max_grad_norm: float = 1.0,
ckpt_interval: int = 5,
random_seed: int = 42,
@ -47,8 +47,8 @@ def create_train_config(
device: Device type ("cuda" or "cpu")
strategy: Training strategy type (default: "seq")
n_epoch: Number of epochs (default: 1)
batch_size: Batch size (default: 2)
accumulation_steps: Gradient accumulation steps (default: 1)
batch_per_device: Batch size per device (default: 2)
grad_accum_steps: Gradient accumulation steps (default: 1)
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
ckpt_interval: Checkpoint save interval in iterations (default: 5)
random_seed: Random seed for reproducibility (default: 42)
@ -74,9 +74,9 @@ def create_train_config(
scheduler_fn=scheduler_fn,
ckpt_dir=test_dir,
n_epoch=n_epoch,
batch_size=batch_size,
batch_per_device=batch_per_device,
ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps,
grad_accum_steps=grad_accum_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
device_type=device,

View File

@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset):
scheduler_fn=scheduler_fn,
ckpt_dir=base_test_env["test_dir"],
n_epoch=1,
batch_size=2,
batch_per_device=2,
ckpt_interval=3,
accumulation_steps=1,
grad_accum_steps=1,
max_grad_norm=1.0,
random_seed=42,
device_type=base_test_env["device"],

View File

@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"],
n_epoch=2,
batch_size=2,
batch_per_device=2,
ckpt_interval=1,
accumulation_steps=2,
grad_accum_steps=2,
random_seed=np.random.randint(1e4),
device_type=base_test_env["device"],
)

View File

@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
"""Test training with different batch sizes"""
batch_sizes = [1, 2, 4, 8]
for batch_size in batch_sizes:
for batch_per_device in batch_sizes:
train_config = train_config_factory(
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_size=batch_size,
batch_per_device=batch_per_device,
)
assert train_config.batch_size == batch_size
assert train_config.batch_per_device == batch_per_device
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
"""Test training with different gradient accumulation steps"""
accumulation_steps_list = [1, 2, 4]
grad_accum_steps_list = [1, 2, 4]
for accumulation_steps in accumulation_steps_list:
for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory(
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_size=2,
accumulation_steps=accumulation_steps,
batch_per_device=2,
grad_accum_steps=grad_accum_steps,
)
trainer = Trainer(train_config)
trainer.train()
assert train_config.accumulation_steps == accumulation_steps
assert train_config.grad_accum_steps == grad_accum_steps
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
"""Test training with memory-efficient configurations"""
# Test with smaller batch sizes and gradient checkpointing
small_batch_configs = [
{"batch_size": 1, "accumulation_steps": 8},
{"batch_size": 2, "accumulation_steps": 4},
{"batch_size": 4, "accumulation_steps": 2},
{"batch_per_device": 1, "grad_accum_steps": 8},
{"batch_per_device": 2, "grad_accum_steps": 4},
{"batch_per_device": 4, "grad_accum_steps": 2},
]
for config in small_batch_configs:
@ -54,8 +54,9 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
batch_size=config["batch_size"],
accumulation_steps=config["accumulation_steps"],
batch_per_device=config["batch_per_device"],
grad_accum_steps=config["grad_accum_steps"],
)
assert train_config.accumulation_steps == config["accumulation_steps"]
assert train_config.grad_accum_steps == config["grad_accum_steps"]
assert train_config.batch_per_device == config["batch_per_device"]