Parameter Documentation

Training Parameters

Basic Parameters

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
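
Because `--batch_size` and `--accumulation_steps` interact, it can help to compute how many samples feed one optimizer step. The helper below is a hypothetical illustration, not part of AstrAI; it assumes `--batch_size` is counted per process and that gradients are accumulated over `--accumulation_steps` micro-batches before each optimizer step.

```python
def effective_batch_size(batch_size: int, accumulation_steps: int, nprocs: int = 1) -> int:
    """Samples contributing to one optimizer step.

    Assumption: batch_size is per process and the processes train data-parallel.
    """
    return batch_size * accumulation_steps * nprocs


# Values taken from the usage example in this document.
print(effective_batch_size(batch_size=4, accumulation_steps=8, nprocs=1))  # 32
```

If memory is tight, lowering `--batch_size` and raising `--accumulation_steps` by the same factor keeps this product unchanged.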

Learning Rate Scheduling

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--warmup_steps` | Number of learning-rate warmup steps | 1000 |
| `--max_lr` | Peak learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
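
The following is a minimal sketch of how these three parameters typically behave, not AstrAI's actual scheduler. It assumes linear warmup, cosine decay down to a hypothetical `min_lr` over a hypothetical `total_steps`, and clipping by the global L2 norm; `lr_at` and `clip_by_global_norm` are illustrative names.

```python
import math


def lr_at(step, max_lr=3e-4, warmup_steps=1000, total_steps=10000, min_lr=0.0):
    """Learning rate at a given optimizer step (0-indexed)."""
    if step < warmup_steps:
        # Linear warmup: reaches max_lr on the last warmup step.
        return max_lr * (step + 1) / warmup_steps
    # Cosine decay from max_lr to min_lr over the remaining steps.
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale gradients so their global L2 norm does not exceed max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total + eps))
    return [g * scale for g in grads]


for step in (0, 999, 5500, 10000):
    print(step, lr_at(step))
```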

Optimizer (AdamW)

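| Parameter | Description | Default |
|-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 (decay rate of the first-moment estimate) | 0.9 |
| `--adamw_beta2` | AdamW beta2 (decay rate of the second-moment estimate) | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay (decoupled from the gradient update) | 0.01 |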
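
To show where each of these values enters the update, here is a sketch of one AdamW step on a single scalar parameter. It follows the standard decoupled-weight-decay formulation rather than AstrAI's own code; the function name, `lr`, and `eps` are assumptions made for the illustration.

```python
import math


def adamw_step(p, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95,
               weight_decay=0.01, eps=1e-8):
    """One AdamW update for scalar parameter p with gradient g at step t (1-indexed)."""
    m = beta1 * m + (1 - beta1) * g          # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * g * g      # second-moment (uncentered variance) estimate
    m_hat = m / (1 - beta1 ** t)             # bias correction
    v_hat = v / (1 - beta2 ** t)
    p = p * (1 - lr * weight_decay)          # decoupled weight decay
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


p, m, v = adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1)
print(p)
```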

Data Loading

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--window_size` | Max input sequence length | model config `max_len` |
| `--stride` | Stride for sliding window over sequences | None |
| `--random_seed` | Random seed for reproducibility | 3407 |
| `--num_workers` | DataLoader worker processes | 4 |
| `--no_pin_memory` | Disable `pin_memory` (enabled by default) | (flag) |
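
The interaction between `--window_size` and `--stride` can be sketched as follows. This is an illustration only: it assumes that a stride of `None` falls back to non-overlapping windows and that trailing tokens which do not fill a full window are dropped, neither of which is confirmed for AstrAI's dataset code. `sliding_windows` is a hypothetical helper.

```python
from typing import List, Optional, Sequence


def sliding_windows(tokens: Sequence[int], window_size: int,
                    stride: Optional[int] = None) -> List[List[int]]:
    """Split a token sequence into windows of at most window_size tokens."""
    step = window_size if stride is None else stride  # assumption: None means no overlap
    # Last start index that still yields a full window; a sequence shorter
    # than window_size is returned as one short window.
    last_start = max(len(tokens) - window_size, 0)
    return [list(tokens[s:s + window_size]) for s in range(0, last_start + 1, step)]


tokens = list(range(10))
print(sliding_windows(tokens, window_size=4, stride=2))
print(sliding_windows(tokens, window_size=4))
```

A smaller stride produces more, overlapping training windows from the same data; a stride equal to the window size produces disjoint ones.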

Checkpoint & Resume

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | `checkpoint` |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
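
One way to picture the resume parameters is as a starting offset into the nested epoch/batch loop. The sketch below assumes `--start_batch` is a batch index within `--start_epoch` and that a checkpoint is written whenever the global iteration count is a positive multiple of `--ckpt_interval`; both are assumptions, and the function names are hypothetical.

```python
def remaining_iterations(n_epoch, batches_per_epoch, start_epoch=0, start_batch=0):
    """Yield the (epoch, batch) pairs still to be run after resuming."""
    for epoch in range(start_epoch, n_epoch):
        # Only the first resumed epoch skips already-completed batches.
        first = start_batch if epoch == start_epoch else 0
        for batch in range(first, batches_per_epoch):
            yield epoch, batch


def should_checkpoint(iteration, ckpt_interval=5000):
    """True when the global iteration count hits a checkpoint boundary."""
    return iteration > 0 and iteration % ckpt_interval == 0


print(list(remaining_iterations(n_epoch=2, batches_per_epoch=3, start_epoch=1, start_batch=1)))
print(should_checkpoint(10000))
```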

Distributed Training

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--device_type` | Device type | `cuda` |

Strategy-specific

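| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta (scales the preference margin relative to the reference model) | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |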
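
To make the role of these two values concrete, the sketch below writes out the standard DPO loss for one preference pair and a label-smoothed cross-entropy for one token. These are textbook formulations in plain Python, not AstrAI's implementation; the function names and arguments are assumptions, and the inputs to `dpo_loss` are summed sequence log-probabilities under the policy (`pi_*`) and the frozen reference model (`ref_*`).

```python
import math


def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """-log(sigmoid(beta * margin)) for one (chosen, rejected) pair."""
    margin = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    # softplus(-margin), written in a numerically stable form
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


def smoothed_cross_entropy(logits, target, label_smoothing=0.1):
    """Cross-entropy where label_smoothing of the target mass is spread uniformly."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    nll = -log_probs[target]                       # loss against the true label
    uniform = -sum(log_probs) / len(log_probs)     # loss against a uniform target
    return (1 - label_smoothing) * nll + label_smoothing * uniform


print(dpo_loss(-10.0, -12.0, -11.0, -11.0))
print(smoothed_cross_entropy([2.0, 0.5, -1.0], target=0))
```

A larger `beta` penalizes a wrong preference ordering more sharply; a larger `label_smoothing` pulls the target distribution further toward uniform.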

Usage Example

python scripts/tools/train.py \
  --train_type seq \
  --data_root_path /path/to/dataset \
  --param_path /path/to/model \
  --n_epoch 3 \
  --batch_size 4 \
  --accumulation_steps 8 \
  --max_lr 3e-4 \
  --warmup_steps 2000 \
  --max_grad_norm 1.0 \
  --ckpt_interval 5000 \
  --ckpt_dir ./checkpoints \
  --num_workers 4 \
  --nprocs 1 \
  --device_type cuda

Generation Parameters

GenerationRequest Parameters

| Parameter | Description | Default |
|-----------|-------------|---------|
| `messages` | List of message dictionaries (`role`, `content`) | required |
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_len` | Maximum generation length | 1024 |
| `stream` | Whether to stream output | False |

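
As a rough picture of how `temperature`, `top_k`, and `top_p` shape the next-token distribution, the sketch below applies them in one common order: temperature scaling, then top-k, then nucleus filtering. The order and the helper name `filter_probs` are assumptions, and AstrAI's sampler may differ.

```python
import math


def filter_probs(logits, temperature=1.0, top_k=50, top_p=1.0):
    """Return {token_id: probability} over the tokens that survive filtering.

    temperature must be > 0.
    """
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-k: keep the k most likely tokens.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize over the surviving tokens.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


logits = [2.0, 1.0, 0.0, -1.0]
print(filter_probs(logits, top_k=2))
print(filter_probs(logits, top_p=0.5))
```

With the defaults (`temperature=1.0`, `top_p=1.0`), only the `top_k` cut changes the distribution.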
Usage Example

import torch
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest

# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")

# Create engine with separate model and tokenizer
engine = InferenceEngine(
    model=model,
    tokenizer=tokenizer,
)

# Build a request in messages format; stream=True makes the engine yield tokens as they are produced
request = GenerationRequest(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    max_len=1024,
    stream=True,
)

# Generate (streaming): print each token as it arrives
for token in engine.generate_with_request(request):
    print(token, end="", flush=True)

# Or use the simple generate interface; with stream=False it returns the complete result
result = engine.generate(
    prompt="Hello",
    stream=False,
    max_tokens=1024,
    temperature=0.8,
    top_p=0.95,
    top_k=50,
)
print(result)

Generation Modes

| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |

Document Update Time: 2026-04-09