refactor: 训练循环改为两重迭代并统一参数命名
- 训练循环从三重(epoch→batched→batch)改为二重(epoch→batch) - batch_size → batch_per_device, accumulation_steps → grad_accum_steps - scheduler 移入 step block 对齐 optimizer 更新步 - GradientClippingCallback 改用 on_step_begin 避免零梯度裁剪 - 移除 _train_impl 误导性的 -> Checkpoint 标注 - total_steps 修除为向下取整并精简为一行 - warmup_steps 改为 warmup_ratio (默认0.05)
This commit is contained in:
parent
7dea929788
commit
d7a7f570ed
30
README.md
30
README.md
|
|
@ -78,15 +78,27 @@ Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) i
|
||||||
#### Train a Model
|
#### Train a Model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
--train_type seq \
|
|
||||||
--data_root_path /path/to/dataset \
|
nohup python scripts/tools/train.py \
|
||||||
--param_path /path/to/model \
|
--nprocs=4 \
|
||||||
--batch_size 4 \
|
--train_type=sft \
|
||||||
--accumulation_steps 8 \
|
--data_root_path=/path/to/dataset \
|
||||||
--max_lr 3e-4 \
|
--param_path=/path/to/model \
|
||||||
--warmup_steps 1000 \
|
--batch_per_device=4 \
|
||||||
--n_epoch 1
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.99 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=1e-5 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.1 \
|
||||||
|
> out.log 2> err.log &
|
||||||
```
|
```
|
||||||
|
|
||||||
Full reference at [Parameter Guide](assets/docs/params.md).
|
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||||
|
|
|
||||||
|
|
@ -84,15 +84,27 @@ python scripts/demo/download.py
|
||||||
#### 训练模型
|
#### 训练模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
--train_type seq \
|
|
||||||
--data_root_path /path/to/dataset \
|
nohup python scripts/tools/train.py \
|
||||||
--param_path /path/to/model \
|
--nprocs=4 \
|
||||||
--batch_size 4 \
|
--train_type=sft \
|
||||||
--accumulation_steps 8 \
|
--data_root_path=/path/to/dataset \
|
||||||
--max_lr 3e-4 \
|
--param_path=/path/to/model \
|
||||||
--warmup_steps 1000 \
|
--batch_per_device=4 \
|
||||||
--n_epoch 1
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.99 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=1e-5 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.1 \
|
||||||
|
> out.log 2> err.log &
|
||||||
```
|
```
|
||||||
|
|
||||||
完整参数列表见[参数说明](./params.md)。
|
完整参数列表见[参数说明](./params.md)。
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,9 @@ classDiagram
|
||||||
+int n_shared_experts
|
+int n_shared_experts
|
||||||
+int n_activated_experts
|
+int n_activated_experts
|
||||||
+str moe_topk_method
|
+str moe_topk_method
|
||||||
|
+Optional[int] kv_lora_rank
|
||||||
|
+Optional[int] qk_nope_head_dim
|
||||||
|
+Optional[int] qk_rope_head_dim
|
||||||
+load(config_path) ModelConfig
|
+load(config_path) ModelConfig
|
||||||
+save(config_path)
|
+save(config_path)
|
||||||
}
|
}
|
||||||
|
|
@ -41,8 +44,8 @@ classDiagram
|
||||||
+Callable optimizer_fn
|
+Callable optimizer_fn
|
||||||
+Callable scheduler_fn
|
+Callable scheduler_fn
|
||||||
+int n_epoch
|
+int n_epoch
|
||||||
+int batch_size
|
+int batch_per_device
|
||||||
+int accumulation_steps
|
+int grad_accum_steps
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
+int start_epoch
|
+int start_epoch
|
||||||
+int start_batch
|
+int start_batch
|
||||||
|
|
@ -69,7 +72,7 @@ classDiagram
|
||||||
class BaseDataset {
|
class BaseDataset {
|
||||||
+int window_size
|
+int window_size
|
||||||
+int stride
|
+int stride
|
||||||
+BaseStorage storage
|
+Optional[BaseStorage] storage
|
||||||
+load(load_path, storage_type, tokenizer)
|
+load(load_path, storage_type, tokenizer)
|
||||||
+__getitem__(index)
|
+__getitem__(index)
|
||||||
+__len__()
|
+__len__()
|
||||||
|
|
@ -126,8 +129,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
class ResumableDistributedSampler {
|
||||||
+int start_epoch
|
+int epoch
|
||||||
+int start_iter
|
+int iter
|
||||||
}
|
}
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
|
|
@ -155,7 +158,7 @@ classDiagram
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(model_type) decorator
|
+register(model_type) decorator
|
||||||
+get_component_class(model_type) Type
|
+get_component_class(model_type) Type
|
||||||
+from_pretrained(path, disable_random_init) nn.Module
|
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||||
+save_pretrained(save_directory)
|
+save_pretrained(save_directory)
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
}
|
}
|
||||||
|
|
@ -167,7 +170,7 @@ classDiagram
|
||||||
+ModuleList layers
|
+ModuleList layers
|
||||||
+RMSNorm norm
|
+RMSNorm norm
|
||||||
+Linear lm_head
|
+Linear lm_head
|
||||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
|
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
||||||
+load_state_dict(state_dict)
|
+load_state_dict(state_dict)
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
@ -185,6 +188,7 @@ classDiagram
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
+int n_rep
|
+int n_rep
|
||||||
|
+int layer_id
|
||||||
+bool use_qk_norm
|
+bool use_qk_norm
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
+Linear q_proj, k_proj, v_proj, o_proj
|
+Linear q_proj, k_proj, v_proj, o_proj
|
||||||
|
|
@ -201,6 +205,7 @@ classDiagram
|
||||||
+int qk_nope_head_dim
|
+int qk_nope_head_dim
|
||||||
+int qk_rope_head_dim
|
+int qk_rope_head_dim
|
||||||
+int n_rep
|
+int n_rep
|
||||||
|
+int layer_id
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||||
+Linear o_proj
|
+Linear o_proj
|
||||||
|
|
@ -215,6 +220,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class DeepSeekMoE {
|
class DeepSeekMoE {
|
||||||
|
+int dim
|
||||||
+int n_routed_experts
|
+int n_routed_experts
|
||||||
+int n_shared_experts
|
+int n_shared_experts
|
||||||
+int n_activated_experts
|
+int n_activated_experts
|
||||||
|
|
@ -236,6 +242,7 @@ classDiagram
|
||||||
class RMSNorm {
|
class RMSNorm {
|
||||||
+Parameter weight
|
+Parameter weight
|
||||||
+float norm_eps
|
+float norm_eps
|
||||||
|
+tuple normalized_shape
|
||||||
+forward(x) Tensor
|
+forward(x) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -299,7 +306,6 @@ classDiagram
|
||||||
+TrainConfig train_config
|
+TrainConfig train_config
|
||||||
+List[TrainCallback] callbacks
|
+List[TrainCallback] callbacks
|
||||||
+train(checkpoint)
|
+train(checkpoint)
|
||||||
+_build_context(checkpoint) TrainContext
|
|
||||||
+_get_default_callbacks() List[TrainCallback]
|
+_get_default_callbacks() List[TrainCallback]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -324,7 +330,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseStrategy {
|
class BaseStrategy {
|
||||||
+nn.Module model
|
+Union[Callable, nn.Module] model
|
||||||
+str device
|
+str device
|
||||||
+compute_loss(batch) Tensor
|
+compute_loss(batch) Tensor
|
||||||
}
|
}
|
||||||
|
|
@ -332,7 +338,7 @@ classDiagram
|
||||||
class StrategyFactory {
|
class StrategyFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(model, train_type, device, **kwargs) BaseStrategy
|
+create(train_type, model, device, **kwargs) BaseStrategy
|
||||||
}
|
}
|
||||||
|
|
||||||
class SEQStrategy {
|
class SEQStrategy {
|
||||||
|
|
@ -400,7 +406,7 @@ classDiagram
|
||||||
|
|
||||||
class GradientClippingCallback {
|
class GradientClippingCallback {
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
+on_step_end(context)
|
+on_step_begin(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
|
|
@ -459,10 +465,7 @@ classDiagram
|
||||||
+TaskManager _task_mgr
|
+TaskManager _task_mgr
|
||||||
+bool _running
|
+bool _running
|
||||||
+Thread _loop_thread
|
+Thread _loop_thread
|
||||||
+int max_batch_size
|
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+int max_prompt_len
|
|
||||||
+int page_size
|
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
+remove_task(task_id)
|
+remove_task(task_id)
|
||||||
+start()
|
+start()
|
||||||
|
|
@ -500,10 +503,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class Storage {
|
class Storage {
|
||||||
+int n_layers
|
|
||||||
+int page_size
|
+int page_size
|
||||||
+int head_dim
|
|
||||||
+int n_kv_heads
|
|
||||||
+Tensor k_cache
|
+Tensor k_cache
|
||||||
+Tensor v_cache
|
+Tensor v_cache
|
||||||
+write(layer_id, page_table, start_pos, k, v)
|
+write(layer_id, page_table, start_pos, k, v)
|
||||||
|
|
@ -675,7 +675,6 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class AnthropicHandler {
|
class AnthropicHandler {
|
||||||
+List[str] stop_sequences
|
|
||||||
+build_prompt() str
|
+build_prompt() str
|
||||||
+create_response_id() str
|
+create_response_id() str
|
||||||
+on_token(ctx, token, stop_checker) Optional[str]
|
+on_token(ctx, token, stop_checker) Optional[str]
|
||||||
|
|
@ -704,7 +703,7 @@ classDiagram
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class Functions {
|
class Functions {
|
||||||
+spawn_parallel_fn(fn, nprocs)
|
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||||
+get_current_device() str
|
+get_current_device() str
|
||||||
+get_world_size() int
|
+get_world_size() int
|
||||||
|
|
@ -878,4 +877,4 @@ classDiagram
|
||||||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
|
|
||||||
> Document Update Time: 2026-05-15
|
> Document Update Time: 2026-05-16
|
||||||
|
|
|
||||||
|
|
@ -10,14 +10,14 @@
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
| `--param_path` | Model parameters or checkpoint path | required |
|
| `--param_path` | Model parameters or checkpoint path | required |
|
||||||
| `--n_epoch` | Total training epochs | 1 |
|
| `--n_epoch` | Total training epochs | 1 |
|
||||||
| `--batch_size` | Batch size | 1 |
|
| `--batch_per_device` | Batch size per device | 1 |
|
||||||
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||||
|
|
||||||
### Learning Rate Scheduling
|
### Learning Rate Scheduling
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--warmup_steps` | Warmup steps | 1000 |
|
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
|
||||||
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||||
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||||
|
|
||||||
|
|
@ -69,90 +69,29 @@
|
||||||
### Usage Example
|
### Usage Example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/train.py \
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
--train_type seq \
|
|
||||||
--data_root_path /path/to/dataset \
|
nohup python scripts/tools/train.py \
|
||||||
--param_path /path/to/model \
|
--nprocs=4 \
|
||||||
--n_epoch 3 \
|
--train_type=sft \
|
||||||
--batch_size 4 \
|
--data_root_path=/path/to/dataset \
|
||||||
--accumulation_steps 8 \
|
--param_path=/path/to/model \
|
||||||
--max_lr 3e-4 \
|
--batch_per_device=4 \
|
||||||
--warmup_steps 2000 \
|
--grad_accum_steps=8 \
|
||||||
--max_grad_norm 1.0 \
|
--warmup_ratio=0.05 \
|
||||||
--ckpt_interval 5000 \
|
--max_lr=1e-4 \
|
||||||
--ckpt_dir ./checkpoints \
|
--max_grad_norm=1.0 \
|
||||||
--num_workers 4 \
|
--adamw_beta1=0.99 \
|
||||||
--nprocs 1 \
|
--adamw_beta2=0.95 \
|
||||||
--device_type cuda
|
--adamw_weight_decay=1e-5 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.1 \
|
||||||
|
> out.log 2> err.log &
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Generation Parameters
|
> Document Update Time: 2026-05-16
|
||||||
|
|
||||||
### GenerationRequest Parameters
|
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
|
||||||
|-----------|-------------|---------------|
|
|
||||||
| `messages` | List of message dictionaries (role, content) | required |
|
|
||||||
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
|
||||||
| `top_p` | Nucleus sampling threshold | 1.0 |
|
|
||||||
| `top_k` | Top-k sampling count | 50 |
|
|
||||||
| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) |
|
|
||||||
| `stream` | Whether to stream output | False |
|
|
||||||
|
|
||||||
### Usage Example
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
from astrai.inference import InferenceEngine, GenerationRequest
|
|
||||||
|
|
||||||
# Load model using AutoModel
|
|
||||||
model = AutoModel.from_pretrained("your_model_dir")
|
|
||||||
|
|
||||||
# Load tokenizer
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
|
|
||||||
|
|
||||||
# Create engine with separate model and tokenizer
|
|
||||||
engine = InferenceEngine(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build request with messages format
|
|
||||||
request = GenerationRequest(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
],
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50,
|
|
||||||
max_tokens=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate (streaming)
|
|
||||||
for token in engine.generate_with_request(request):
|
|
||||||
print(token, end="", flush=True)
|
|
||||||
|
|
||||||
# Or use simple generate interface
|
|
||||||
result = engine.generate(
|
|
||||||
prompt="Hello",
|
|
||||||
stream=False,
|
|
||||||
max_tokens=1024,
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Generation Modes
|
|
||||||
|
|
||||||
| Mode | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `stream=True` | Streaming output, yields token by token |
|
|
||||||
| `stream=False` | Non-streaming output, returns complete result |
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-15
|
|
||||||
|
|
@ -65,24 +65,24 @@ The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per posi
|
||||||
|
|
||||||
## Training Loop
|
## Training Loop
|
||||||
|
|
||||||
Nested loop: **epoch** → **step** (accumulation window) → **batch**.
|
Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_steps` batches.
|
||||||
|
|
||||||
```
|
```
|
||||||
on_train_begin
|
on_train_begin
|
||||||
on_epoch_begin
|
on_epoch_begin
|
||||||
for steps in batched(dataloader, accumulation_steps):
|
for batch in dataloader:
|
||||||
on_step_begin
|
on_batch_begin
|
||||||
step_batch_nums = len(steps)
|
loss = strategy(batch)
|
||||||
for batch in steps:
|
(loss / grad_accum_steps).backward()
|
||||||
on_batch_begin
|
iteration += 1
|
||||||
loss = strategy(batch)
|
on_batch_end
|
||||||
(loss / step_batch_nums).backward()
|
|
||||||
iteration += 1
|
if iteration % grad_accum_steps == 0:
|
||||||
on_batch_end
|
on_step_begin
|
||||||
on_step_end
|
optimizer.step()
|
||||||
optimizer.step()
|
optimizer.zero_grad()
|
||||||
optimizer.zero_grad()
|
on_step_end
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
```
|
```
|
||||||
|
|
@ -91,9 +91,9 @@ on_train_end
|
||||||
|
|
||||||
| Hook | Fires | Default callback |
|
| Hook | Fires | Default callback |
|
||||||
|------|-------|-----------------|
|
|------|-------|-----------------|
|
||||||
| `on_step_end` | Every accumulation window | `GradientClippingCallback` |
|
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
| `on_train_end` | Training ends | `CheckpointCallback` (final save) |
|
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||||
|
|
||||||
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
||||||
|
|
||||||
|
|
@ -162,7 +162,7 @@ Checkpoint(state_dict, epoch, iteration, extra)
|
||||||
└── load(save_dir) broadcasts metadata from rank-0
|
└── load(save_dir) broadcasts metadata from rank-0
|
||||||
```
|
```
|
||||||
|
|
||||||
Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data.
|
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||||
|
|
||||||
## TrainContextBuilder (Builder Pattern)
|
## TrainContextBuilder (Builder Pattern)
|
||||||
|
|
||||||
|
|
@ -183,17 +183,29 @@ context = (
|
||||||
## Training CLI
|
## Training CLI
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/train.py \
|
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
--train_type seq \
|
|
||||||
--data_root_path /path/to/data \
|
nohup python scripts/tools/train.py \
|
||||||
--param_path /path/to/model \
|
--nprocs=4 \
|
||||||
--batch_size 4 \
|
--train_type=sft \
|
||||||
--accumulation_steps 8 \
|
--data_root_path=/path/to/dataset \
|
||||||
--max_lr 3e-4 \
|
--param_path=/path/to/model \
|
||||||
--warmup_steps 1000 \
|
--batch_per_device=4 \
|
||||||
--n_epoch 1
|
--grad_accum_steps=8 \
|
||||||
|
--warmup_ratio=0.05 \
|
||||||
|
--max_lr=1e-4 \
|
||||||
|
--max_grad_norm=1.0 \
|
||||||
|
--adamw_beta1=0.99 \
|
||||||
|
--adamw_beta2=0.95 \
|
||||||
|
--adamw_weight_decay=1e-5 \
|
||||||
|
--window_size=2048 \
|
||||||
|
--ckpt_interval=10000 \
|
||||||
|
--ckpt_dir=./checkpoint \
|
||||||
|
--random_seed=3407 \
|
||||||
|
--label_smoothing=0.1 \
|
||||||
|
> out.log 2> err.log &
|
||||||
```
|
```
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-15
|
> Document Update Time: 2026-05-16
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,10 @@ class TrainConfig:
|
||||||
default=None, metadata={"help": "Scheduler factory for training."}
|
default=None, metadata={"help": "Scheduler factory for training."}
|
||||||
)
|
)
|
||||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||||
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
|
batch_per_device: int = field(
|
||||||
accumulation_steps: int = field(
|
default=4, metadata={"help": "Batch size per device."}
|
||||||
|
)
|
||||||
|
grad_accum_steps: int = field(
|
||||||
default=1, metadata={"help": "Number of iterations between steps."}
|
default=1, metadata={"help": "Number of iterations between steps."}
|
||||||
)
|
)
|
||||||
max_grad_norm: float = field(
|
max_grad_norm: float = field(
|
||||||
|
|
|
||||||
|
|
@ -79,8 +79,7 @@ class GradientClippingCallback(TrainCallback):
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_step_end(self, context: TrainContext):
|
def on_step_begin(self, context: TrainContext):
|
||||||
_ = context
|
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ class TrainContextBuilder:
|
||||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
cfg = self.config
|
cfg = self.config
|
||||||
sampler_offset = context.iteration * cfg.batch_size
|
sampler_offset = context.iteration * cfg.batch_per_device
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
data_source=cfg.dataset,
|
data_source=cfg.dataset,
|
||||||
start_epoch=context.epoch,
|
start_epoch=context.epoch,
|
||||||
|
|
@ -79,7 +79,7 @@ class TrainContextBuilder:
|
||||||
)
|
)
|
||||||
context.dataloader = DataLoader(
|
context.dataloader = DataLoader(
|
||||||
cfg.dataset,
|
cfg.dataset,
|
||||||
batch_size=cfg.batch_size,
|
batch_size=cfg.batch_per_device,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
pin_memory=cfg.pin_memory,
|
pin_memory=cfg.pin_memory,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import logging
|
import logging
|
||||||
from itertools import batched
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from astrai.config import TrainConfig
|
from astrai.config import TrainConfig
|
||||||
|
|
@ -33,11 +32,6 @@ class Trainer:
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
|
||||||
return (
|
|
||||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
method = getattr(callback, method_name, None)
|
method = getattr(callback, method_name, None)
|
||||||
|
|
@ -45,49 +39,47 @@ class Trainer:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
config = self.train_config
|
cfg = self.train_config
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(
|
||||||
self._train_impl,
|
self._train_impl,
|
||||||
backend=config.backend,
|
backend=cfg.backend,
|
||||||
world_size=config.nprocs,
|
world_size=cfg.nprocs,
|
||||||
master_addr=config.master_addr,
|
master_addr=cfg.master_addr,
|
||||||
master_port=config.master_port,
|
master_port=cfg.master_port,
|
||||||
device_type=config.device_type,
|
device_type=cfg.device_type,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
context = self._build_context(checkpoint)
|
cfg = self.train_config
|
||||||
|
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||||
self._call_callbacks("on_train_begin", context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.model.train()
|
context.model.train()
|
||||||
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
grad_accum_steps = cfg.grad_accum_steps
|
||||||
|
|
||||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
for epoch in range(context.epoch, cfg.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for steps in batched(context.dataloader, accumulation_steps):
|
for batch in context.dataloader:
|
||||||
self._call_callbacks("on_step_begin", context)
|
self._call_callbacks("on_batch_begin", context)
|
||||||
|
loss = context.strategy(batch)
|
||||||
|
context.loss = loss.item()
|
||||||
|
stand_loss = loss / grad_accum_steps
|
||||||
|
stand_loss.backward()
|
||||||
|
context.iteration += 1
|
||||||
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
step_batch_nums = len(steps)
|
if context.iteration % grad_accum_steps == 0:
|
||||||
for batch in steps:
|
self._call_callbacks("on_step_begin", context)
|
||||||
self._call_callbacks("on_batch_begin", context)
|
context.optimizer.step()
|
||||||
loss = context.strategy(batch)
|
context.optimizer.zero_grad()
|
||||||
context.loss = loss.item()
|
self._call_callbacks("on_step_end", context)
|
||||||
context.iteration += 1
|
|
||||||
|
|
||||||
stand_loss = loss / step_batch_nums
|
if context.scheduler:
|
||||||
stand_loss.backward()
|
context.scheduler.step()
|
||||||
self._call_callbacks("on_batch_end", context)
|
|
||||||
|
|
||||||
self._call_callbacks("on_step_end", context)
|
|
||||||
context.optimizer.step()
|
|
||||||
context.optimizer.zero_grad()
|
|
||||||
|
|
||||||
if context.scheduler:
|
|
||||||
context.scheduler.step()
|
|
||||||
|
|
||||||
self._call_callbacks("on_epoch_end", context)
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,18 +42,20 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||||
)
|
)
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--accumulation_steps",
|
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grad_accum_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="Number of iterations between each optimizer step.",
|
help="Number of iterations between each optimizer step.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--warmup_steps",
|
"--warmup_ratio",
|
||||||
type=int,
|
type=float,
|
||||||
default=1000,
|
default=0.05,
|
||||||
help="Number of warmup steps for LR scheduler.",
|
help="Fraction of total steps used for LR warmup.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||||
|
|
@ -177,24 +179,25 @@ def create_scheduler(
|
||||||
return SchedulerFactory.create(optimizer, **kwargs)
|
return SchedulerFactory.create(optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def ceil_div(a: int, b: int) -> int:
|
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||||
return (a + b - 1) // b
|
return model.module.state_dict()
|
||||||
|
|
||||||
|
|
||||||
def compute_total_steps(
|
def compute_total_steps(
|
||||||
dataset_len: int,
|
dataset_len: int,
|
||||||
n_epoch: int,
|
n_epoch: int,
|
||||||
batch_size: int,
|
batch_per_device: int,
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
accumulation_steps: int,
|
grad_accum_steps: int,
|
||||||
) -> int:
|
) -> int:
|
||||||
|
|
||||||
|
def ceil_div(a: int, b: int) -> int:
|
||||||
|
return (a + b - 1) // b
|
||||||
|
|
||||||
samples_per_replica = ceil_div(dataset_len, nprocs)
|
samples_per_replica = ceil_div(dataset_len, nprocs)
|
||||||
batches_per_replica = ceil_div(samples_per_replica, batch_size)
|
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
|
||||||
return ceil_div(batches_per_replica, accumulation_steps) * n_epoch
|
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
|
||||||
|
return total_steps
|
||||||
|
|
||||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
|
||||||
return model.module.state_dict()
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
|
|
@ -203,11 +206,11 @@ def train(
|
||||||
data_root_path: str,
|
data_root_path: str,
|
||||||
max_lr: float,
|
max_lr: float,
|
||||||
n_epoch: int,
|
n_epoch: int,
|
||||||
batch_size: int,
|
batch_per_device: int,
|
||||||
start_epoch: int,
|
start_epoch: int,
|
||||||
start_batch: int,
|
start_batch: int,
|
||||||
accumulation_steps: int,
|
grad_accum_steps: int,
|
||||||
warmup_steps: int,
|
warmup_ratio: float,
|
||||||
ckpt_interval: int,
|
ckpt_interval: int,
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
|
|
@ -277,8 +280,10 @@ def train(
|
||||||
)
|
)
|
||||||
|
|
||||||
total_steps = compute_total_steps(
|
total_steps = compute_total_steps(
|
||||||
len(dataset), n_epoch, batch_size, nprocs, accumulation_steps
|
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
|
||||||
)
|
)
|
||||||
|
warmup_steps = int(warmup_ratio * total_steps)
|
||||||
|
|
||||||
scheduler_fn = partial(
|
scheduler_fn = partial(
|
||||||
create_scheduler,
|
create_scheduler,
|
||||||
**{
|
**{
|
||||||
|
|
@ -296,11 +301,11 @@ def train(
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=ckpt_dir,
|
ckpt_dir=ckpt_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_per_device=batch_per_device,
|
||||||
start_epoch=start_epoch,
|
start_epoch=start_epoch,
|
||||||
start_batch=start_batch,
|
start_batch=start_batch,
|
||||||
ckpt_interval=ckpt_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
accumulation_steps=accumulation_steps,
|
grad_accum_steps=grad_accum_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,8 @@ def create_train_config(
|
||||||
device: str,
|
device: str,
|
||||||
strategy: str = "seq",
|
strategy: str = "seq",
|
||||||
n_epoch: int = 1,
|
n_epoch: int = 1,
|
||||||
batch_size: int = 2,
|
batch_per_device: int = 2,
|
||||||
accumulation_steps: int = 1,
|
grad_accum_steps: int = 1,
|
||||||
max_grad_norm: float = 1.0,
|
max_grad_norm: float = 1.0,
|
||||||
ckpt_interval: int = 5,
|
ckpt_interval: int = 5,
|
||||||
random_seed: int = 42,
|
random_seed: int = 42,
|
||||||
|
|
@ -47,8 +47,8 @@ def create_train_config(
|
||||||
device: Device type ("cuda" or "cpu")
|
device: Device type ("cuda" or "cpu")
|
||||||
strategy: Training strategy type (default: "seq")
|
strategy: Training strategy type (default: "seq")
|
||||||
n_epoch: Number of epochs (default: 1)
|
n_epoch: Number of epochs (default: 1)
|
||||||
batch_size: Batch size (default: 2)
|
batch_per_device: Batch size per device (default: 2)
|
||||||
accumulation_steps: Gradient accumulation steps (default: 1)
|
grad_accum_steps: Gradient accumulation steps (default: 1)
|
||||||
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
|
max_grad_norm: Maximum gradient norm for clipping (default: 1.0)
|
||||||
ckpt_interval: Checkpoint save interval in iterations (default: 5)
|
ckpt_interval: Checkpoint save interval in iterations (default: 5)
|
||||||
random_seed: Random seed for reproducibility (default: 42)
|
random_seed: Random seed for reproducibility (default: 42)
|
||||||
|
|
@ -74,9 +74,9 @@ def create_train_config(
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=test_dir,
|
ckpt_dir=test_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_per_device=batch_per_device,
|
||||||
ckpt_interval=ckpt_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
accumulation_steps=accumulation_steps,
|
grad_accum_steps=grad_accum_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
device_type=device,
|
device_type=device,
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,9 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=3,
|
ckpt_interval=3,
|
||||||
accumulation_steps=1,
|
grad_accum_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"],
|
device_type=base_test_env["device"],
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_size=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=1,
|
ckpt_interval=1,
|
||||||
accumulation_steps=2,
|
grad_accum_steps=2,
|
||||||
random_seed=np.random.randint(1e4),
|
random_seed=np.random.randint(1e4),
|
||||||
device_type=base_test_env["device"],
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,45 +7,45 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
||||||
"""Test training with different batch sizes"""
|
"""Test training with different batch sizes"""
|
||||||
batch_sizes = [1, 2, 4, 8]
|
batch_sizes = [1, 2, 4, 8]
|
||||||
|
|
||||||
for batch_size in batch_sizes:
|
for batch_per_device in batch_sizes:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_size=batch_size,
|
batch_per_device=batch_per_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.batch_size == batch_size
|
assert train_config.batch_per_device == batch_per_device
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
||||||
"""Test training with different gradient accumulation steps"""
|
"""Test training with different gradient accumulation steps"""
|
||||||
accumulation_steps_list = [1, 2, 4]
|
grad_accum_steps_list = [1, 2, 4]
|
||||||
|
|
||||||
for accumulation_steps in accumulation_steps_list:
|
for grad_accum_steps in grad_accum_steps_list:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_size=2,
|
batch_per_device=2,
|
||||||
accumulation_steps=accumulation_steps,
|
grad_accum_steps=grad_accum_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
assert train_config.accumulation_steps == accumulation_steps
|
assert train_config.grad_accum_steps == grad_accum_steps
|
||||||
|
|
||||||
|
|
||||||
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
||||||
"""Test training with memory-efficient configurations"""
|
"""Test training with memory-efficient configurations"""
|
||||||
# Test with smaller batch sizes and gradient checkpointing
|
# Test with smaller batch sizes and gradient checkpointing
|
||||||
small_batch_configs = [
|
small_batch_configs = [
|
||||||
{"batch_size": 1, "accumulation_steps": 8},
|
{"batch_per_device": 1, "grad_accum_steps": 8},
|
||||||
{"batch_size": 2, "accumulation_steps": 4},
|
{"batch_per_device": 2, "grad_accum_steps": 4},
|
||||||
{"batch_size": 4, "accumulation_steps": 2},
|
{"batch_per_device": 4, "grad_accum_steps": 2},
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in small_batch_configs:
|
for config in small_batch_configs:
|
||||||
|
|
@ -54,8 +54,9 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
batch_size=config["batch_size"],
|
batch_per_device=config["batch_per_device"],
|
||||||
accumulation_steps=config["accumulation_steps"],
|
grad_accum_steps=config["grad_accum_steps"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
assert train_config.grad_accum_steps == config["grad_accum_steps"]
|
||||||
|
assert train_config.batch_per_device == config["batch_per_device"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue