# Training

## Model Architecture

The model is a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). It has 1.0 billion parameters and is bilingual (Chinese and English).

```mermaid
flowchart TB
    subgraph Layers["Transformer Layers"]
        direction TB
        A[Input Embedding] --> B[Transformer Block\nLayer 1]
        B --> C[Transformer Block\nLayer ...]
        C --> D[Transformer Block\nLayer ...]
        D --> E[RMSNorm]
        E --> F[Linear]
        F --> G[SoftMax]
    end

    subgraph TransformerBlock["Transformer Block"]
        direction TB
        H[x] --> I[RMSNorm]
        I --> J[Linear → Q/K/V]
        J --> K[Q]; J --> L[K]; J --> M[V]
        K --> N[RoPE]; L --> O[RoPE]
        N --> P["Q @ K^T / sqrt(d)"]; O --> P
        P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
        R --> S[Linear]; S --> T[+]; H --> T
        T --> U[RMSNorm]
        U --> V["Linear (gate)"]; U --> W["Linear (up)"]
        V --> X[SiLU]; X --> Y[×]; W --> Y
        Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
        AA --> BB[x']
    end
```

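
The feed-forward half of the block diagram (gate and up projections, SiLU, element-wise product, down projection) can be sketched in plain Python. This is a minimal illustration, not the project's implementation: the helper names `silu`, `matvec`, and `gated_ffn` and all weight values are hypothetical, and the RMSNorm and residual connection from the diagram are omitted.

```python
import math

def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def gated_ffn(x, w_gate, w_up, w_down):
    """down(SiLU(gate(x)) * up(x)), following the lower half of the diagram."""
    gate = [silu(g) for g in matvec(w_gate, x)]
    up = matvec(w_up, x)
    hidden = [g * u for g, u in zip(gate, up)]
    return matvec(w_down, hidden)

# Toy weights (hypothetical values): model width 2, hidden width 3.
x = [1.0, -1.0]
w_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_up = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
w_down = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
y = gated_ffn(x, w_gate, w_up, w_down)
```

The gate branch decides how much of each `up` channel passes through, which is why the two projections share the same input but only one goes through SiLU.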
### Autoregression

Given a token sequence, the model predicts a probability distribution over the next token. Each generated token is appended to the input and fed back into the model, repeating until an end-of-sequence token is produced or the maximum length is reached.

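
The generation loop described above can be sketched as follows. This is a minimal sketch that assumes greedy decoding (always taking the most likely token); the actual sampling method is not stated here. `greedy_generate` and `toy_model` are hypothetical names, and `toy_model` is only a stand-in for the real network.

```python
def greedy_generate(next_token_probs, prompt, eos_id, max_len):
    """Append the most likely next token until EOS appears or max_len is reached."""
    tokens = list(prompt)
    while len(tokens) < max_len:
        probs = next_token_probs(tokens)
        next_id = max(range(len(probs)), key=probs.__getitem__)
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens

def toy_model(tokens):
    """Hypothetical stand-in for the model: favours (last token + 1) mod 5."""
    vocab_size = 5
    favoured = (tokens[-1] + 1) % vocab_size
    return [0.6 if i == favoured else 0.1 for i in range(vocab_size)]

# Token 4 plays the role of the end-of-sequence token in this toy setup.
out = greedy_generate(toy_model, prompt=[1], eos_id=4, max_len=10)
```

With this toy model the loop stops at the end-of-sequence token; if no such token is ever produced, the `max_len` check ends generation instead.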
### Causal Mask

During training, the inputs and targets are the same sequence offset by one position:

```
sequence  : [[1, 2, 3, 4, 5, 6]]
input_ids : [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```

An additive lower-triangular mask prevents each position from attending to future positions (`-inf` entries receive zero weight after the softmax):

```
[[0, -inf, -inf, -inf, -inf],
 [0,    0, -inf, -inf, -inf],
 [0,    0,    0, -inf, -inf],
 [0,    0,    0,    0, -inf],
 [0,    0,    0,    0,    0]]
```

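
Both the one-position shift and the additive mask shown above are easy to construct by hand. The sketch below uses plain Python lists; `shift_for_next_token` and `causal_mask` are hypothetical helper names, not functions from this codebase.

```python
NEG_INF = float("-inf")

def shift_for_next_token(sequence):
    """Split one sequence into inputs and targets offset by one position."""
    return sequence[:-1], sequence[1:]

def causal_mask(size):
    """Additive mask: 0 where attention is allowed (col <= row), -inf elsewhere."""
    return [
        [0.0 if col <= row else NEG_INF for col in range(size)]
        for row in range(size)
    ]

input_ids, target_ids = shift_for_next_token([1, 2, 3, 4, 5, 6])
mask = causal_mask(len(input_ids))
```

The mask is added to the attention scores before the softmax, so row `i` can only place weight on columns `0..i`.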
### Rotary Position Embedding (RoPE)

RoPE encodes position by rotating the Q/K vectors, so that the attention score depends only on the relative offset between the two positions:

$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_i^T R_j W_k x_j = x_i^T W_q^T R_{j-i} W_k x_j $$

The complex rotation factors `freqs_cis` are pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` views Q/K as complex numbers and multiplies them by these factors.

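
The relative-position property can be checked numerically with the standard library's complex numbers. This is a sketch under assumptions: `precompute_freqs_cis`, `apply_rotary`, and `score` are hypothetical stand-ins for the real `freqs_cis` and `apply_rotary_emb`, and the frequency base of 10000 is the commonly used default rather than a value confirmed by this document.

```python
import cmath

def precompute_freqs_cis(head_dim, max_pos, base=10000.0):
    """Unit complex numbers exp(i * pos * theta_k) per position and channel pair."""
    thetas = [base ** (-2.0 * k / head_dim) for k in range(head_dim // 2)]
    return [[cmath.exp(1j * pos * t) for t in thetas] for pos in range(max_pos)]

def apply_rotary(vec, freqs_at_pos):
    """Treat consecutive (even, odd) channels as one complex number and rotate it."""
    pairs = [complex(vec[2 * k], vec[2 * k + 1]) for k in range(len(vec) // 2)]
    return [p * f for p, f in zip(pairs, freqs_at_pos)]

def score(q, k):
    """Real dot product of two vectors stored as complex pairs."""
    return sum((a * b.conjugate()).real for a, b in zip(q, k))

freqs_cis = precompute_freqs_cis(head_dim=4, max_pos=16)
q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.5, -0.75, 1.0]

# Same offset (2) at different absolute positions gives the same score.
near = score(apply_rotary(q, freqs_cis[3]), apply_rotary(k, freqs_cis[5]))
far = score(apply_rotary(q, freqs_cis[10]), apply_rotary(k, freqs_cis[12]))
```

Because each rotation is a unit complex factor, the product of a rotated query with the conjugate of a rotated key keeps only the difference of the two angles.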
## Training Loop

The trainer runs a nested loop: **epoch** → **step** (accumulation window) → **batch**. Each batch loss is divided by the number of batches in the window before `backward()`, so the accumulated gradient is the average over the window.

```
on_train_begin
    on_epoch_begin
        for steps in batched(dataloader, accumulation_steps):
            on_step_begin
            step_batch_nums = len(steps)
            for batch in steps:
                on_batch_begin
                loss = strategy(batch)
                (loss / step_batch_nums).backward()
                iteration += 1
                on_batch_end
            on_step_end
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
    on_epoch_end
on_train_end
```

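
The accumulation logic of the loop above can be traced with a scalar model. This is a toy sketch, not the trainer itself: the one-parameter model `y = w * x`, the plain SGD update, and the helper names are all assumptions made for illustration, and the local `batched` helper mimics `itertools.batched` (Python 3.12+) so the example runs on older versions.

```python
from itertools import islice

def batched(iterable, n):
    """Yield lists of up to n items, like itertools.batched in Python 3.12+."""
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk

def grad_of_loss(w, batch):
    """d/dw of the mean squared error of y = w * x over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

def train_epoch(w, dataloader, accumulation_steps, lr):
    iteration = 0
    for steps in batched(dataloader, accumulation_steps):
        accumulated = 0.0  # plays the role of .grad after zero_grad()
        for batch in steps:
            # Equivalent of (loss / step_batch_nums).backward()
            accumulated += grad_of_loss(w, batch) / len(steps)
            iteration += 1
        w -= lr * accumulated  # optimizer.step() with plain SGD
    return w, iteration

# Three batches of one (x, y) pair each; the last window holds a single batch.
dataloader = [[(1.0, 2.0)], [(2.0, 4.0)], [(3.0, 6.0)]]
w, iteration = train_epoch(0.0, dataloader, accumulation_steps=2, lr=0.01)
```

Dividing by `len(steps)` rather than `accumulation_steps` keeps the final, shorter window correctly averaged, which matches the `step_batch_nums = len(steps)` line in the loop.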
### Callback Lifecycle

| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_step_end` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_train_end` | Training ends | `CheckpointCallback` (final save) |

Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.

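
One way such hooks can be dispatched is sketched below. Everything here is hypothetical: the `Callback` base class, the `ctx` dictionary, and the `fire` helper are illustrative only and are not the project's actual callback API. The clipping callback rescales gradients by their global norm at `on_step_end`, which is the point in the loop just before `optimizer.step()`.

```python
import math

class Callback:
    """Hypothetical base class: every hook is a no-op unless overridden."""
    def on_batch_end(self, ctx): pass
    def on_step_end(self, ctx): pass
    def on_train_end(self, ctx): pass

class ClipGradNorm(Callback):
    """Rescale gradients so their global L2 norm is at most max_norm."""
    def __init__(self, max_norm):
        self.max_norm = max_norm

    def on_step_end(self, ctx):
        norm = math.sqrt(sum(g * g for g in ctx["grads"]))
        if norm > self.max_norm:
            scale = self.max_norm / norm
            ctx["grads"] = [g * scale for g in ctx["grads"]]

class CountBatches(Callback):
    """Count how many batches have finished."""
    def __init__(self):
        self.seen = 0

    def on_batch_end(self, ctx):
        self.seen += 1

def fire(callbacks, hook, ctx):
    """Call the named hook on every callback, in registration order."""
    for cb in callbacks:
        getattr(cb, hook)(ctx)

ctx = {"grads": [3.0, 4.0]}
counter = CountBatches()
callbacks = [ClipGradNorm(max_norm=1.0), counter]
fire(callbacks, "on_batch_end", ctx)
fire(callbacks, "on_batch_end", ctx)
fire(callbacks, "on_step_end", ctx)
```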
## Strategies

### SEQ (Pre-training)

Next-token cross-entropy with optional label smoothing:

$$
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$

Keys: `input_ids`, `target_ids`

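
The loss above can be written out for a handful of positions in plain Python. This is a sketch under assumptions: it averages over positions (the formula is written as a sum), and it uses the common definition of label smoothing in which a fraction `label_smoothing` of the target mass is spread uniformly over the vocabulary. `log_softmax` and `next_token_loss` are hypothetical names.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one position's logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def next_token_loss(logits_per_pos, target_ids, label_smoothing=0.0):
    """Mean cross-entropy over positions, with optional uniform label smoothing."""
    total = 0.0
    for logits, target in zip(logits_per_pos, target_ids):
        logp = log_softmax(logits)
        nll = -logp[target]                 # loss against the true token
        uniform = -sum(logp) / len(logp)    # loss against a uniform target
        total += (1.0 - label_smoothing) * nll + label_smoothing * uniform
    return total / len(target_ids)

confident = [[10.0, 0.0, 0.0]]
plain = next_token_loss(confident, [0])
smoothed = next_token_loss(confident, [0], label_smoothing=0.1)
```

Smoothing penalizes over-confident predictions: for the confident logits above, `smoothed` is larger than `plain`.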
### SFT (Supervised Fine-Tuning)

Masked cross-entropy (`ignore_index=-100`) computed over the response tokens only:

$$
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$

Here *P* is the prompt length and *L* is the response length, so prompt tokens do not contribute to the loss.

Keys: `input_ids`, `target_ids`, `loss_mask`

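
The masking can be sketched as follows. How `loss_mask` is combined with `ignore_index` inside the strategy is not spelled out here, so the approach below (overwriting masked targets with `-100` and skipping them) is an assumption, and `mask_targets` and `masked_cross_entropy` are hypothetical names.

```python
import math

IGNORE_INDEX = -100

def mask_targets(target_ids, loss_mask):
    """Replace targets at masked-out positions with IGNORE_INDEX."""
    return [t if keep else IGNORE_INDEX for t, keep in zip(target_ids, loss_mask)]

def masked_cross_entropy(logits_per_pos, labels):
    """Average -log p(label) over positions whose label is not IGNORE_INDEX."""
    losses = []
    for logits, label in zip(logits_per_pos, labels):
        if label == IGNORE_INDEX:
            continue  # prompt position: contributes nothing
        m = max(logits)
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        losses.append(log_z - logits[label])
    return sum(losses) / len(losses) if losses else 0.0

# Two prompt positions (mask 0) followed by one response position (mask 1).
labels = mask_targets([1, 2, 0], [0, 0, 1])
logits = [[9.0, -3.0, 1.0], [0.5, 0.5, 7.0], [0.0, 0.0, 0.0]]
loss = masked_cross_entropy(logits, labels)
```

Only the last position counts, so the result depends on its logits alone, whatever the model predicted on the prompt.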
### DPO (Direct Preference Optimization)

The policy is trained against a frozen reference model. The loss rewards a larger log-ratio for the chosen response $y_w$ than for the rejected response $y_l$:

$$
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
$$

Parameters: `beta=0.1`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.

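
For a single preference pair the loss reduces to a few lines. The sketch below assumes the inputs are already the summed log-probabilities of each full response under the policy and the reference model; the function and argument names are hypothetical.

```python
import math

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """-log(sigmoid(margin)) for one pair, given summed response log-probs."""
    margin = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    # -log(sigmoid(m)) == log(1 + exp(-m)), written in an overflow-safe form.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

# Policy identical to the reference: margin is 0 and the loss is log(2).
at_reference = dpo_loss(-5.0, -7.0, -5.0, -7.0)

# Policy raised the chosen response and lowered the rejected one: lower loss.
improved = dpo_loss(-4.0, -8.0, -5.0, -7.0)
```

The loss starts at `log(2)` when the policy equals the reference and decreases as the policy separates the chosen response from the rejected one.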
### GRPO (Group Relative Policy Optimization)

On-policy, PPO-style clipped objective with group-normalized advantages:

$$
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
$$

$$
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
$$

Here $\mu$ and $\sigma$ are the mean and standard deviation of the rewards within a group. The $\epsilon$ in the advantage is a small constant for numerical stability; the $\epsilon$ in the clip term is the clipping range (`clip_eps`), and $\lambda$ weights the penalty term (`kl_coef`).

Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`.

Keys: `prompts`, `responses`, `masks`, `rewards`.

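
The two building blocks of the objective, group normalization and the clipped surrogate, can be sketched per sample. Assumptions are labelled in the code: the standard deviation is the population form (an unbiased estimator would divide by `n - 1`), and `group_advantages` and `clipped_term` are hypothetical names.

```python
import math

def group_advantages(rewards, eps=1e-8):
    """Normalize the rewards of one group to zero mean and unit spread."""
    mu = sum(rewards) / len(rewards)
    # Assumption: population standard deviation (divides by n, not n - 1).
    sigma = math.sqrt(sum((r - mu) ** 2 for r in rewards) / len(rewards))
    return [(r - mu) / (sigma + eps) for r in rewards]

def clipped_term(ratio, advantage, clip_eps=0.2):
    """min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A) for one sample."""
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return min(ratio * advantage, clipped * advantage)

# One group of four responses (group_size=4) with binary rewards.
advantages = group_advantages([1.0, 0.0, 0.0, 1.0])

capped_gain = clipped_term(1.5, 1.0)      # positive advantage: gain is capped
uncapped_cost = clipped_term(1.5, -1.0)   # negative advantage: cost is not capped
```

Taking the minimum makes the objective pessimistic: a large ratio cannot inflate the reward for a good sample, but it is fully penalized for a bad one.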
## LR Schedulers

| Type | Class | Description |
|------|-------|-------------|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |

Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.

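
The shape of the cosine schedule can be sketched as a multiplier on the peak learning rate. This is an illustration, not `CosineScheduler` itself: the function name, the `total_steps` argument, and the default `min_rate=0.1` are assumptions.

```python
import math

def cosine_with_warmup(step, warmup_steps, total_steps, min_rate=0.1):
    """Multiplier on max_lr: linear warmup to 1.0, then cosine decay to min_rate."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_rate + (1.0 - min_rate) * cosine

# 10 warmup steps followed by 100 decay steps.
schedule = [cosine_with_warmup(s, warmup_steps=10, total_steps=110) for s in range(111)]
```

The multiplier rises linearly during warmup, peaks at 1.0 when warmup ends, and settles at `min_rate` on the final step.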
## Checkpoint

```
Checkpoint(state_dict, epoch, iteration, extra)
├── save(save_dir)   rank-0 only: meta.json + state_dict.safetensors + optional extra.pt
└── load(save_dir)   broadcasts metadata from rank-0
```

Optimizer and scheduler state are not persisted by default; `Checkpoint.extra` can store arbitrary data.

## TrainContextBuilder (Builder Pattern)

```python
context = (
    TrainContextBuilder(config)
    .with_checkpoint(checkpoint)
    .build()
)
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
```

- Loads checkpoint weights if provided
- Wraps model with `parallel_wrapper` if `nprocs > 1`
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)`

## Training CLI

```bash
python scripts/tools/train.py \
    --train_type seq \
    --data_root_path /path/to/data \
    --param_path /path/to/model \
    --batch_size 4 \
    --accumulation_steps 8 \
    --max_lr 3e-4 \
    --warmup_steps 1000 \
    --n_epoch 1
```

Full parameter reference at [params.md](params.md).

> Document Update Time: 2026-05-15