AstrAI/assets/docs/training.md

Training

Model Architecture

The model is a decoder-only Transformer with GQA (Grouped Query Attention) and optional MLA (Multi-head Latent Attention). It has 1.0 billion parameters and is Chinese-English bilingual.

flowchart TB
    subgraph Layers["Transformer Layers"]
        direction TB
        A[Input Embedding] --> B[Transformer Block\nLayer 1]
        B --> C[Transformer Block\nLayer ...]
        C --> D[Transformer Block\nLayer ...]
        D --> E[RMSNorm]
        E --> F[Linear]
        F --> G[SoftMax]
    end

    subgraph TransformerBlock["Transformer Block"]
        direction TB
        H[x] --> I[RMSNorm]
        I --> J[Linear → Q/K/V]
        J --> K[Q]; J --> L[K]; J --> M[V]
        K --> N[RoPE]; L --> O[RoPE]
        N --> P["Q @ K^T / sqrt(d)"]; O --> P
        P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
        R --> S[Linear]; S --> T[+]; H --> T
        T --> U[RMSNorm]
        U --> V["Linear (gate)"]; U --> W["Linear (up)"]
        V --> X[SiLU]; X --> Y[×]; W --> Y
        Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
        AA --> BB[x']
    end

Autoregression

Given a token sequence, the model predicts the probability distribution of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token is produced or the maximum length is reached.
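The loop described above can be sketched in plain Python. This is a minimal illustration with greedy selection, not AstrAI's generation code: `toy_next_token_probs`, `EOS_ID`, and the 5-token vocabulary are hypothetical stand-ins for the real model and tokenizer.

```python
from typing import Callable, List

EOS_ID = 0  # hypothetical end-of-sequence token id


def toy_next_token_probs(tokens: List[int]) -> List[float]:
    """Stand-in for the model: a distribution over a 5-token vocabulary.

    It always prefers (last_token + 1) % 5, so generation eventually hits EOS_ID.
    """
    vocab_size = 5
    preferred = (tokens[-1] + 1) % vocab_size
    return [0.9 if i == preferred else 0.025 for i in range(vocab_size)]


def generate(prompt: List[int],
             next_token_probs: Callable[[List[int]], List[float]],
             max_length: int) -> List[int]:
    """Greedy autoregressive decoding: append the most likely token and feed it back."""
    tokens = list(prompt)
    while len(tokens) < max_length:
        probs = next_token_probs(tokens)
        next_id = max(range(len(probs)), key=probs.__getitem__)  # greedy pick
        tokens.append(next_id)
        if next_id == EOS_ID:  # stop at end-of-sequence
            break
    return tokens


print(generate([2], toy_next_token_probs, max_length=10))  # [2, 3, 4, 0]
```

Sampling strategies other than greedy selection would replace only the line that picks `next_id`.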

Causal Mask

sequence : [[1, 2, 3, 4, 5, 6]]
input_ids: [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]

A lower-triangular additive mask prevents attention to future positions (0 keeps a score, -inf removes it after the softmax):

[[0, -inf, -inf, -inf, -inf],
 [0,    0, -inf, -inf, -inf],
 [0,    0,    0, -inf, -inf],
 [0,    0,    0,    0, -inf],
 [0,    0,    0,    0,    0]]
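The shift and the mask above can be reproduced with a few lines of plain Python. The helper names here are illustrative, not part of AstrAI; the sketch also shows that a position masked with -inf receives exactly zero attention weight after the softmax.

```python
import math


def shift_for_next_token(sequence):
    """Split a sequence into inputs and next-token targets."""
    return sequence[:-1], sequence[1:]


def causal_mask(size):
    """Additive mask: 0.0 where attention is allowed, -inf for future positions."""
    return [[0.0 if col <= row else -math.inf for col in range(size)]
            for row in range(size)]


def masked_softmax(scores, mask_row):
    """Softmax over one row of attention scores after adding the mask row."""
    masked = [s + m for s, m in zip(scores, mask_row)]
    peak = max(masked)
    exps = [math.exp(v - peak) for v in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


input_ids, target_ids = shift_for_next_token([1, 2, 3, 4, 5, 6])
mask = causal_mask(len(input_ids))
weights = masked_softmax([0.5, 0.5, 0.5, 0.5, 0.5], mask[1])

print(input_ids, target_ids)  # [1, 2, 3, 4, 5] [2, 3, 4, 5, 6]
print(weights)                # [0.5, 0.5, 0.0, 0.0, 0.0]
```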

Rotary Position Embedding (RoPE)

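RoPE embeds position into Q/K vectors via complex rotation, so that the attention score depends only on the relative offset between positions:

 q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_i^T R_j W_k x_j = x_i^T W_q^T R_{j-i} W_k x_j 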

The complex rotation factors freqs_cis are pre-computed once (cos, sin pairs per position). apply_rotary_emb then treats Q and K as complex numbers and multiplies them by freqs_cis.
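A small stdlib-only sketch can check the relative-position property numerically. It is not AstrAI's `apply_rotary_emb`: the pairing of adjacent channels `(2k, 2k+1)`, the base of 10000, and the function names are assumptions chosen for illustration.

```python
import cmath


def precompute_freqs_cis(head_dim, max_len, base=10000.0):
    """Unit complex numbers exp(i * pos * theta_k) per position and channel pair."""
    thetas = [base ** (-2.0 * k / head_dim) for k in range(head_dim // 2)]
    return [[cmath.exp(1j * pos * theta) for theta in thetas]
            for pos in range(max_len)]


def apply_rotary(vec, freqs_at_pos):
    """Rotate a real vector by viewing channel pairs (2k, 2k+1) as complex numbers."""
    pairs = [complex(vec[2 * k], vec[2 * k + 1]) for k in range(len(vec) // 2)]
    out = []
    for pair, freq in zip(pairs, freqs_at_pos):
        rotated = pair * freq  # complex multiplication == 2-D rotation
        out.extend([rotated.real, rotated.imag])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


freqs_cis = precompute_freqs_cis(head_dim=4, max_len=16)
q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]

# Same offset (j - i = -2) at different absolute positions gives the same score.
score_a = dot(apply_rotary(q, freqs_cis[3]), apply_rotary(k, freqs_cis[1]))
score_b = dot(apply_rotary(q, freqs_cis[12]), apply_rotary(k, freqs_cis[10]))
# A different offset (j - i = -1) gives a different score.
score_c = dot(apply_rotary(q, freqs_cis[3]), apply_rotary(k, freqs_cis[2]))

print(abs(score_a - score_b) < 1e-9)  # True
```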

Training Loop

Two-level loop: epoch → batch. The optimizer step fires every grad_accum_steps batches.

on_train_begin
  on_epoch_begin
    for batch in dataloader:
      on_batch_begin
      with executor.accumulate(model):
        loss = strategy(batch)
        (loss / grad_accum_steps).backward()
        iteration += 1
      on_batch_end

      if executor.sync_gradients:
        on_optimizer_step
        optimizer.step()
        optimizer.zero_grad()

      scheduler.step()  # called every iteration
    on_epoch_end
on_train_end
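The effect of dividing the loss by grad_accum_steps can be checked with simple arithmetic. The sketch below assumes that executor.sync_gradients becomes true on every grad_accum_steps-th batch; how a trailing partial window is handled is not stated here, so it is left out. Both function names are illustrative.

```python
def optimizer_step_batches(num_batches, grad_accum_steps):
    """1-based batch indices after which the optimizer steps (full windows only)."""
    return [i for i in range(1, num_batches + 1) if i % grad_accum_steps == 0]


def accumulated_gradient(per_batch_grads, grad_accum_steps):
    """Sum of (grad / grad_accum_steps) over one window, which equals the window mean."""
    total = 0.0
    for grad in per_batch_grads[:grad_accum_steps]:
        total += grad / grad_accum_steps  # mirrors (loss / grad_accum_steps).backward()
    return total


print(optimizer_step_batches(10, 4))                   # [4, 8]
print(accumulated_gradient([1.0, 2.0, 3.0, 6.0], 4))   # 3.0
```

Scaling each micro-batch loss keeps the accumulated gradient at the same magnitude as one large batch, so the learning rate does not need to change with grad_accum_steps.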

Callback Lifecycle

| Hook | Fires | Default callback |
| --- | --- | --- |
| on_train_begin | Before training starts | GradientCheckpointingCallback |
| on_optimizer_step | Every accumulation window | GradientClippingCallback, ValidationCallback |
| on_batch_end | Every batch | CheckpointCallback, MetricLoggerCallback, ProgressBarCallback |
| on_train_end | Training ends | CheckpointCallback, MetricLoggerCallback (final save) |

Default callbacks (in order): gradient_checkpointing (activation checkpointing, optional), checkpoint (safetensors, rank-0), metric_logger (JSONL, rank-0), progress_bar (tqdm), gradient_clipping, validation (periodic validation on val_dataset).

Strategies

SEQ (Pre-training)

Next-token cross-entropy with optional label smoothing:


L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)

Keys: input_ids, target_ids
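As a sketch of the per-token term, the function below computes cross-entropy with uniform label smoothing: the target distribution puts 1 - eps on the true token and spreads eps evenly over the vocabulary. This is the convention of PyTorch's cross_entropy; that AstrAI uses the same convention is an assumption.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_z for v in logits]


def smoothed_cross_entropy(logits, target, label_smoothing=0.0):
    """Cross-entropy for one position with uniform label smoothing."""
    log_probs = log_softmax(logits)
    nll = -log_probs[target]                     # standard next-token term
    uniform = -sum(log_probs) / len(log_probs)   # cross-entropy against the uniform distribution
    return (1.0 - label_smoothing) * nll + label_smoothing * uniform


plain = smoothed_cross_entropy([2.0, 0.0, 0.0], target=0)
smoothed = smoothed_cross_entropy([2.0, 0.0, 0.0], target=0, label_smoothing=0.05)
print(plain < smoothed)  # True: smoothing penalizes over-confident predictions
```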

SFT (Supervised Fine-Tuning)

Masked cross-entropy (ignore_index=-100) over response tokens:


L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)

Keys: input_ids, target_ids, loss_mask
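The masking step can be illustrated as follows. This is a sketch under two assumptions: that loss_mask is 1 on response tokens and 0 elsewhere, and that the loss is averaged over the kept positions (the formula above is written as a sum). The helper names are hypothetical.

```python
import math

IGNORE_INDEX = -100  # positions with this label contribute nothing to the loss


def apply_loss_mask(target_ids, loss_mask):
    """Replace targets outside the response span with IGNORE_INDEX."""
    return [t if m else IGNORE_INDEX for t, m in zip(target_ids, loss_mask)]


def masked_nll(log_probs, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX."""
    kept = [-row[label] for row, label in zip(log_probs, labels)
            if label != IGNORE_INDEX]
    return sum(kept) / len(kept) if kept else 0.0


# Four positions, vocabulary of 3; the first two positions belong to the prompt.
row = [math.log(0.5), math.log(0.25), math.log(0.25)]
log_probs = [row, row, row, row]
labels = apply_loss_mask([1, 2, 0, 1], loss_mask=[0, 0, 1, 1])

print(labels)  # [-100, -100, 0, 1]
loss = masked_nll(log_probs, labels)
```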

DPO (Direct Preference Optimization)

Frozen reference model, preference margin via log-ratio:


L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]

Parameters: beta=0.1. Keys: chosen, rejected, chosen_mask, rejected_mask.
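For a single preference pair, the loss above reduces to a few lines once the sequence log-probabilities are known. The sketch assumes that chosen_mask and rejected_mask select the response tokens whose log-probabilities are summed; the function names are illustrative.

```python
import math


def sequence_logprob(token_logprobs, mask):
    """Sum token log-probabilities over the positions selected by the mask."""
    return sum(lp for lp, m in zip(token_logprobs, mask) if m)


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """-log sigmoid(beta * (chosen log-ratio - rejected log-ratio)) for one pair."""
    margin = beta * ((policy_chosen - ref_chosen) - (policy_rejected - ref_rejected))
    # -log(sigmoid(m)) == log(1 + exp(-m)), written in a numerically stable form
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Policy identical to the reference: the margin is 0 and the loss is log(2).
neutral = dpo_loss(-10.0, -12.0, -10.0, -12.0)
# Policy raises the chosen response and lowers the rejected one: the loss drops.
improved = dpo_loss(-9.0, -13.0, -10.0, -12.0)
print(improved < neutral)  # True
```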

GRPO (Group Relative Policy Optimization)

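On-policy PPO with group-normalized advantages, where \mu and \sigma are the mean and standard deviation of the rewards r_i within a group and \epsilon is a small constant for numerical stability:


\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}

The clipped objective uses a separate clipping range \epsilon_{\text{clip}} (clip_eps) and a penalty weight \lambda (kl_coef); A is the advantage defined above:

L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon_{\text{clip}}, 1+\epsilon_{\text{clip}}\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]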

Parameters: group_size=4, clip_eps=0.2, kl_coef=0.01, sync_interval=200.

Keys: prompts, responses, masks, rewards.
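The two formulas can be traced with a small numeric sketch. It evaluates one sample at a time and uses the population standard deviation within a group; both choices, and the function names, are assumptions for illustration rather than AstrAI's implementation.

```python
import math


def group_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mean = sum(rewards) / len(rewards)
    variance = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(variance)
    return [(r - mean) / (std + eps) for r in rewards]


def grpo_term(logp_policy, logp_ref, advantage, clip_eps=0.2, kl_coef=0.01):
    """Clipped surrogate plus squared log-ratio penalty for one sample."""
    ratio = math.exp(logp_policy - logp_ref)
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    surrogate = min(ratio * advantage, clipped * advantage)
    penalty = (logp_policy - logp_ref) ** 2
    return -surrogate + kl_coef * penalty


advantages = group_advantages([1.0, 0.0, 0.0, 1.0])  # group_size=4
on_policy = grpo_term(-1.0, -1.0, advantage=1.0)      # ratio == 1, no penalty
drifted = grpo_term(-1.0 + math.log(2.0), -1.0, advantage=1.0)  # ratio == 2, clipped to 1.2
print(on_policy)  # -1.0
```

With a positive advantage, the clip caps how much a single sample can reward moving the ratio above 1 + clip_eps, which is what keeps the update close to the reference policy.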

LR Schedulers

| Type | Class | Description |
| --- | --- | --- |
| Cosine | CosineScheduler | Linear warmup → cosine decay to min_rate |
| SGDR | SGDRScheduler | Cosine annealing with warm restarts (t_mult=2) |

Created by SchedulerFactory.create(optimizer, schedule_type, **kwargs).
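The shape of the cosine schedule can be sketched as a learning-rate multiplier. This is not the CosineScheduler source: treating min_rate as a fraction of the peak learning rate, the default of 0.1, and the function signature are all assumptions.

```python
import math


def cosine_with_warmup(step, total_steps, warmup_steps, min_rate=0.1):
    """LR multiplier: linear warmup from 0 to 1, then cosine decay from 1 to min_rate."""
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_rate past the last step
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_rate + (1.0 - min_rate) * cosine


max_lr = 1e-4
for step in (0, 5, 10, 55, 100):
    print(step, max_lr * cosine_with_warmup(step, total_steps=100, warmup_steps=10))
```

With these numbers the multiplier is 0 at step 0, reaches 1 at the end of warmup (step 10), passes 0.55 halfway through the decay (step 55), and ends at min_rate.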

Gradient Checkpointing

Gradient checkpointing trades compute for memory by recomputing activations during the backward pass. Specify the module types to wrap via gradient_checkpointing_modules:

from astrai.model.components.decoder_block import DecoderBlock
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])

The callback wraps each DecoderBlock.forward with torch.utils.checkpoint.checkpoint(use_reentrant=False), which is compatible with torch.compile. It uses nn.Module.apply() for traversal, so it works through DDP wrappers without manual unwrapping. An empty list (the default) makes the callback a no-op.

Checkpoint

Checkpoint(state_dict, epoch, iteration, extra, meta)
  ├── save(save_dir)    rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt
  └── load(save_dir)    broadcasts metadata from rank-0

Optimizer and scheduler state are persisted by default via Checkpoint.extra.
The training config (TrainConfig.to_dict()) is saved into meta.json during training by CheckpointCallback.

TrainContextBuilder (Builder Pattern)

context = (
    TrainContextBuilder(config)
        .with_checkpoint(checkpoint)
        .build()
)
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint

The builder:
  • Loads checkpoint weights if provided
  • Creates executor via ExecutorFactory.create(parallel_mode, **executor_kwargs)
  • Calls executor.prepare(model, optimizer, dataloader, scheduler) for model distribution (e.g. DDP) + gradient accumulation wrappers
  • Creates ResumableDistributedSampler for shuffle+resume
  • Builds strategy via StrategyFactory.create(train_type, ...)

Training CLI

export CUDA_VISIBLE_DEVICES=0,1,2,3

nohup python scripts/tools/train.py \
    --nprocs=4 \
    --train_type=seq \
    --data_root_path=/path/to/dataset \
    --param_path=/path/to/model \
    --batch_per_device=4 \
    --grad_accum_steps=8 \
    --warmup_ratio=0.05 \
    --max_lr=1e-4 \
    --max_grad_norm=1.0 \
    --adamw_beta1=0.9 \
    --adamw_beta2=0.95 \
    --adamw_weight_decay=0.01 \
    --window_size=2048 \
    --ckpt_interval=10000 \
    --ckpt_dir=./checkpoint \
    --random_seed=3407 \
    --label_smoothing=0.05 \
    > out.log 2> err.log &

Full parameter reference at params.md.
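As a quick sanity check on the command above, the effective batch per optimizer step can be worked out from the flags. The sketch assumes pure data parallelism across nprocs and that every sample holds window_size tokens; the function name is illustrative.

```python
def batch_per_optimizer_step(nprocs, batch_per_device, grad_accum_steps, window_size):
    """Return (sequences, tokens) consumed per optimizer step."""
    sequences = nprocs * batch_per_device * grad_accum_steps
    tokens = sequences * window_size
    return sequences, tokens


sequences, tokens = batch_per_optimizer_step(
    nprocs=4, batch_per_device=4, grad_accum_steps=8, window_size=2048
)
print(sequences)  # 128
print(tokens)     # 262144
```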

Document Update Time: 2026-05-24