Compare commits
4 Commits
bc7c82977e
...
ca4e6b907c
| Author | SHA1 | Date |
|---|---|---|
|
|
ca4e6b907c | |
|
|
db99d8b254 | |
|
|
b98c9cefdc | |
|
|
283bcaf2ff |
52
README.md
52
README.md
|
|
@ -46,7 +46,7 @@
|
||||||
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
|
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
|
||||||
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
|
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
|
||||||
- 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas.
|
- 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas.
|
||||||
- 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets.
|
- 🤗 **HuggingFace-Style API**: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
|
||||||
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
|
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
|
||||||
|
|
||||||
### Quick Start
|
### Quick Start
|
||||||
|
|
@ -68,46 +68,26 @@ pip install -e ".[dev]"
|
||||||
#### Train a Model
|
#### Train a Model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
|
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||||
|
--train_type seq \
|
||||||
|
--data_root_path /path/to/dataset \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--batch_size 4 \
|
||||||
|
--accumulation_steps 8 \
|
||||||
|
--max_lr 3e-4 \
|
||||||
|
--warmup_steps 1000 \
|
||||||
|
--n_epoch 1
|
||||||
```
|
```
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
|
||||||
| `--param_path` | Model / checkpoint path | required |
|
|
||||||
| `--n_epoch` | Training epochs | 1 |
|
|
||||||
| `--batch_size` | Batch size | 1 |
|
|
||||||
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
|
||||||
| `--warmup_steps` | LR warmup steps | 1000 |
|
|
||||||
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
|
|
||||||
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
|
||||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
|
||||||
| `--random_seed` | Random seed | 3407 |
|
|
||||||
| `--num_workers` | DataLoader workers | 4 |
|
|
||||||
| `--window_size` | Max input sequence length | auto |
|
|
||||||
| `--stride` | Sequence stride | auto |
|
|
||||||
| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
|
|
||||||
| `--dpo_beta` | DPO beta | 0.1 |
|
|
||||||
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
|
|
||||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
|
|
||||||
| `--group_size` | GRPO group size | 4 |
|
|
||||||
| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 |
|
|
||||||
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
|
|
||||||
| `--ckpt_dir` | Checkpoint directory | checkpoint |
|
|
||||||
| `--start_epoch` | Start epoch (for resume) | 0 |
|
|
||||||
| `--start_batch` | Start batch (for resume) | 0 |
|
|
||||||
| `--nprocs` | Number of GPUs | 1 |
|
|
||||||
| `--device_type` | Device type | cuda |
|
|
||||||
|
|
||||||
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
|
|
||||||
|
|
||||||
#### Generate Text
|
#### Generate Text
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
python scripts/tools/generate.py \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--input_json_file /path/to/input.json \
|
||||||
|
--output_json_file /path/to/output.json
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
@ -140,8 +120,6 @@ docker compose --profile cpu up -d
|
||||||
|
|
||||||
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
|
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
|
||||||
|
|
||||||
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
|
|
||||||
|
|
||||||
#### Start HTTP Server
|
#### Start HTTP Server
|
||||||
|
|
||||||
Start the inference server with OpenAI and Anthropic-compatible HTTP API:
|
Start the inference server with OpenAI and Anthropic-compatible HTTP API:
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@
|
||||||
- 💡 **易用**: 简洁的 API 与丰富的示例、演示。
|
- 💡 **易用**: 简洁的 API 与丰富的示例、演示。
|
||||||
- 📦 **轻量**: 依赖少,部署简单。
|
- 📦 **轻量**: 依赖少,部署简单。
|
||||||
- 🔬 **研究友好**: 模块化设计,便于实验新想法。
|
- 🔬 **研究友好**: 模块化设计,便于实验新想法。
|
||||||
- 🤗 **HuggingFace 集成**: 兼容 HuggingFace 模型与数据集。
|
- 🤗 **HuggingFace 风格 API**: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。
|
||||||
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。
|
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。
|
||||||
|
|
||||||
### 快速开始
|
### 快速开始
|
||||||
|
|
@ -74,46 +74,26 @@ pip install -e ".[dev]"
|
||||||
#### 训练模型
|
#### 训练模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
|
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||||
|
--train_type seq \
|
||||||
|
--data_root_path /path/to/dataset \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--batch_size 4 \
|
||||||
|
--accumulation_steps 8 \
|
||||||
|
--max_lr 3e-4 \
|
||||||
|
--warmup_steps 1000 \
|
||||||
|
--n_epoch 1
|
||||||
```
|
```
|
||||||
|
|
||||||
| 参数 | 说明 | 默认值 |
|
完整参数列表见[参数说明](./params.md)。
|
||||||
|------|------|--------|
|
|
||||||
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo`) | 必填 |
|
|
||||||
| `--data_root_path` | 数据集根目录 | 必填 |
|
|
||||||
| `--param_path` | 模型参数或断点路径 | 必填 |
|
|
||||||
| `--n_epoch` | 训练轮数 | 1 |
|
|
||||||
| `--batch_size` | 批次大小 | 1 |
|
|
||||||
| `--accumulation_steps` | 梯度累积步数 | 1 |
|
|
||||||
| `--warmup_steps` | 预热步数 | 1000 |
|
|
||||||
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
|
|
||||||
| `--max_grad_norm` | 梯度裁剪最大值 | 1.0 |
|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
|
||||||
| `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 |
|
|
||||||
| `--random_seed` | 随机种子 | 3407 |
|
|
||||||
| `--num_workers` | 数据加载线程数 | 4 |
|
|
||||||
| `--window_size` | 最大输入序列长度 | auto |
|
|
||||||
| `--stride` | 序列步长 | auto |
|
|
||||||
| `--label_smoothing` | 交叉熵标签平滑 | 0.1 |
|
|
||||||
| `--dpo_beta` | DPO beta | 0.1 |
|
|
||||||
| `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 |
|
|
||||||
| `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 |
|
|
||||||
| `--group_size` | GRPO 组大小 | 4 |
|
|
||||||
| `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 |
|
|
||||||
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
|
|
||||||
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
|
|
||||||
| `--start_epoch` | 起始轮次(用于断点续训) | 0 |
|
|
||||||
| `--start_batch` | 起始批次(用于断点续训) | 0 |
|
|
||||||
| `--nprocs` | GPU 数量 | 1 |
|
|
||||||
| `--device_type` | 设备类型 | cuda |
|
|
||||||
|
|
||||||
完整参数列表见[参数说明](./params.md#training-parameters)。
|
|
||||||
|
|
||||||
#### 文本生成
|
#### 文本生成
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
python scripts/tools/generate.py \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--input_json_file /path/to/input.json \
|
||||||
|
--output_json_file /path/to/output.json
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
|
||||||
|
|
@ -9,13 +9,11 @@ AstrAI adopts a modular design with the following main components:
|
||||||
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
||||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
||||||
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
||||||
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
|
||||||
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
||||||
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
||||||
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
|
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
|
||||||
|
|
||||||
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
|
||||||
|
|
||||||
## Data Flow Diagram
|
## Data Flow Diagram
|
||||||
|
|
||||||
```mermaid
|
```mermaid
|
||||||
|
|
@ -23,38 +21,36 @@ flowchart LR
|
||||||
subgraph A[Data Preparation]
|
subgraph A[Data Preparation]
|
||||||
direction TB
|
direction TB
|
||||||
A1[Raw Text] --> A2[AutoTokenizer]
|
A1[Raw Text] --> A2[AutoTokenizer]
|
||||||
A2 --> A3[Serialize to .h5 files]
|
A2 --> A3[Tokenized .h5 files]
|
||||||
A3 --> A4[BaseDataset]
|
A3 --> A4[BaseDataset]
|
||||||
A4 --> A5[ResumableDistributedSampler]
|
A4 --> A5[ResumableDistributedSampler]
|
||||||
A5 --> A6[PyTorch DataLoader]
|
A5 --> A6[DataLoader]
|
||||||
end
|
end
|
||||||
|
|
||||||
subgraph B[Training]
|
subgraph B[Training]
|
||||||
direction TB
|
direction TB
|
||||||
B1[Batch Data] --> B2[TrainContextBuilder]
|
B1[DataLoader] --> B2[BaseStrategy]
|
||||||
B2 --> B3[TrainContext]
|
B2 --> B3[Transformer Forward]
|
||||||
B3 --> B4[BaseStrategy]
|
B3 --> B4[Loss + Backward]
|
||||||
B4 --> B5[Transformer]
|
B4 --> B5[Gradient Accumulation]
|
||||||
B5 --> B6[Compute Loss]
|
B5 -->|every accum_steps| B6[Optimizer Step]
|
||||||
B6 --> B7[Backward]
|
B6 --> B7[LR Scheduler]
|
||||||
B7 --> B8[Optimizer]
|
B7 -->|next batch| B2
|
||||||
B8 --> B9[LRScheduler]
|
B6 --> B8[CheckpointCallback]
|
||||||
B9 --> B10[CheckpointCallback]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
subgraph C[Inference]
|
subgraph C[Inference]
|
||||||
direction TB
|
direction TB
|
||||||
C1[Checkpoint] --> C2[AutoModel]
|
C1[Checkpoint] --> C2[AutoModel]
|
||||||
C2 --> C3[Transformer + Tokenizer]
|
C1 --> C3[AutoTokenizer]
|
||||||
C3 --> C4[GenerationRequest + apply_chat_template]
|
C2 --> C4[InferenceEngine]
|
||||||
C4 --> C5[InferenceEngine]
|
C3 --> C4
|
||||||
C5 --> C6[InferenceScheduler]
|
C4 --> C5[InferenceScheduler]
|
||||||
|
C5 --> C6[Transformer Forward]
|
||||||
C6 --> C7[sample]
|
C6 --> C7[sample]
|
||||||
C7 --> C8[Transformer Forward]
|
C7 --> C8{End?}
|
||||||
C8 --> C9[Paged KV Cache]
|
C8 -->|No| C6
|
||||||
C9 --> C10{End Condition?}
|
C8 -->|Yes| C9[Generated Text]
|
||||||
C10 -->|No| C8
|
|
||||||
C10 -->|Yes| C11[Output Text]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
A --> B
|
A --> B
|
||||||
|
|
@ -65,215 +61,177 @@ flowchart LR
|
||||||
|
|
||||||
### 1. Serialization (`astrai/serialization.py`)
|
### 1. Serialization (`astrai/serialization.py`)
|
||||||
|
|
||||||
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
|
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
|
||||||
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
|
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
|
||||||
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
|
- **`Checkpoint`**: Encapsulates model state dict + epoch + iteration; uses safetensors
|
||||||
|
|
||||||
### 2. Dataset Module
|
### 2. Dataset Module
|
||||||
|
|
||||||
#### 2.1 Dataset (`dataset.py`)
|
#### 2.1 Dataset (`dataset.py`)
|
||||||
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
|
- **`BaseDataset`**: Abstract base class for windowed sequence sampling
|
||||||
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
|
- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range
|
||||||
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
|
- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
|
||||||
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
|
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen_mask"/"rejected_mask"` (DPO), `"masks"` (GRPO)
|
||||||
|
|
||||||
#### 2.2 Sampler (`sampler.py`)
|
#### 2.2 Sampler (`sampler.py`)
|
||||||
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
|
- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last
|
||||||
- Records current epoch and iteration position, enabling training resume from breakpoints
|
|
||||||
- Supports shuffle and drop_last options
|
|
||||||
|
|
||||||
### 3. Model Module
|
### 3. Model Module
|
||||||
|
|
||||||
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|
#### 3.1 Transformer / AutoModel
|
||||||
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
|
- **`AutoModel`**: Base class with `from_pretrained()` / `save_pretrained()`
|
||||||
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
|
- **`Transformer`**: Decoder-only architecture, registered via `@AutoModel.register('transformer')`
|
||||||
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
|
- Embedding → N×DecoderBlock → RMSNorm → Linear lm_head
|
||||||
- Supports weight tying (`tie_weight=True`) to reduce parameter count
|
- RoPE position encoding, optional weight tying
|
||||||
- Uses Rotary Position Embedding (RoPE) to inject position information
|
|
||||||
- Supports loading from safetensors format with automatic model type detection from `config.json`
|
|
||||||
|
|
||||||
#### 3.2 Submodules (`module.py`)
|
#### 3.2 Submodules (`module.py`)
|
||||||
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
|
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
||||||
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
|
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
||||||
- **`GQA`**: Grouped Query Attention implementation
|
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
||||||
- **`MLA`**: Multi-Latent Attention implementation (like Qwen2-VL)
|
- **`RotaryEmbedding`**: RoPE cos/sin cache
|
||||||
- **`MLP`**: Feed-forward network with SiLU activation and gated mechanism
|
- **`RMSNorm`**: Layer normalization
|
||||||
- **`RMSNorm`**: Layer normalization variant
|
|
||||||
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
|
|
||||||
|
|
||||||
### 4. Training Module
|
### 4. Training Module
|
||||||
|
|
||||||
#### 4.1 Training Context (`train_context.py`)
|
#### 4.1 Training Context (`train_context.py`)
|
||||||
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
|
- **`TrainContext`**: Dataclass holding model, optimizer, dataloader, strategy, scheduler, checkpoint state
|
||||||
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
|
- **`TrainContextBuilder`**: Builder pattern — takes checkpoint for resume, builds all components
|
||||||
|
|
||||||
#### 4.2 Trainer (`trainer.py`)
|
#### 4.2 Trainer (`trainer.py`)
|
||||||
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
|
|
||||||
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
|
The training loop is nested: **epoch** → **batch** (with step phase interspersed):
|
||||||
- Training steps include:
|
|
||||||
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
|
```
|
||||||
|
on_train_begin
|
||||||
|
on_epoch_begin
|
||||||
|
for each batch:
|
||||||
|
if iteration % accumulation_steps == 0: ← step phase
|
||||||
|
on_step_begin → optimizer.step() → zero_grad → on_step_end
|
||||||
|
← batch phase
|
||||||
|
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
|
on_epoch_end
|
||||||
|
on_train_end
|
||||||
|
```
|
||||||
|
|
||||||
|
Key points:
|
||||||
|
- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches)
|
||||||
|
- `on_batch_*` wraps loss computation (fires every batch)
|
||||||
|
- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch
|
||||||
|
- `GradientClippingCallback` fires on `on_step_begin`
|
||||||
|
|
||||||
#### 4.3 Strategy (`strategy.py`)
|
#### 4.3 Strategy (`strategy.py`)
|
||||||
- **`BaseStrategy`**: Defines training strategy interface
|
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
||||||
- **`SEQStrategy`**: Standard next-token prediction training
|
- **`SFTStrategy`**: Supervised fine-tuning with loss masking
|
||||||
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
|
- **`DPOStrategy`**: Direct Preference Optimization with reference model
|
||||||
- **`DPOStrategy`**: Direct Preference Optimization
|
- **`GRPOStrategy`**: Group Relative Policy Optimization with clipped ratio
|
||||||
- **`GRPOStrategy`**: Group Relative Policy Optimization
|
|
||||||
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
|
|
||||||
- Created dynamically by `StrategyFactory` according to configuration
|
|
||||||
|
|
||||||
#### 4.4 Scheduler (`schedule.py`)
|
#### 4.4 Scheduler (`schedule.py`)
|
||||||
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
|
- **`CosineScheduler`**: Cosine decay + linear warmup
|
||||||
- **`CosineScheduler`**: Cosine decay scheduler with warmup
|
- **`SGDRScheduler`**: Cosine annealing with warm restarts
|
||||||
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
|
- Created by `SchedulerFactory` and bound to optimizer
|
||||||
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
|
|
||||||
- Scheduler is automatically created according to configuration and bound to optimizer
|
|
||||||
|
|
||||||
#### 4.5 Callbacks (`train_callback.py`)
|
#### 4.5 Callbacks
|
||||||
- **`TrainCallback`**: Protocol interface for trainer callbacks
|
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
||||||
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
|
- **`ProgressBarCallback`**: tqdm progress display
|
||||||
- **`ProgressBarCallback`**: Displays training progress
|
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
||||||
- **`MetricLoggerCallback`**: Logs training metrics to JSON files
|
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin`
|
||||||
- **`GradientClippingCallback`**: Clips gradient norms
|
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
|
||||||
- **`SchedulerCallback`**: Steps learning rate scheduler
|
|
||||||
|
|
||||||
#### 4.6 Metric Utility (`metric_util.py`)
|
### 5. Inference Module
|
||||||
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
|
|
||||||
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
|
|
||||||
|
|
||||||
### 5. Factory Module
|
#### 5.1 Inference Engine (`engine.py`)
|
||||||
|
- **`InferenceEngine`**: Facade over scheduler; provides `generate()`, `generate_with_request()`, `generate_async()`
|
||||||
|
- Accepts `prompt: str | List[str]`, returns generator (stream) or string (non-stream)
|
||||||
|
|
||||||
#### 5.1 Registry and BaseFactory (`factory.py`)
|
#### 5.2 Scheduler 4-Phase Loop (`scheduler.py`)
|
||||||
- **`Registry`**: Flexible registry for component classes with category and priority support
|
|
||||||
- **`BaseFactory`**: Generic factory class for component registration and creation
|
|
||||||
- Supports decorator-based registration pattern for extensible components
|
|
||||||
- Provides methods for registration, retrieval, and listing with filtering
|
|
||||||
|
|
||||||
### 6. Parallel Module
|
Background thread runs continuously:
|
||||||
|
|
||||||
#### 6.1 Setup (`setup.py`)
|
```
|
||||||
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
|
1. Cleanup → Remove finished tasks, free KV cache pages
|
||||||
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
|
2. Refill → Pop from waiting_queue, alloc pages, add to active
|
||||||
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
|
3. Prefill → Group active tasks by prompt_len, run full forward pass
|
||||||
- **`get_rank`**: Returns current process rank in distributed group
|
4. Decode → Pick largest same-position group, run single-token forward
|
||||||
- **`get_world_size`**: Returns total number of processes in distributed group
|
```
|
||||||
- **`get_current_device`**: Returns current device from environment
|
|
||||||
|
|
||||||
#### 6.2 Parallel Layers (`module.py`)
|
- **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED)
|
||||||
- **`ParallelModel`**: Base class for parallel models with process group
|
- **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write
|
||||||
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
|
- **`CacheView`**: Batch view bundling cache + page table for attention layers
|
||||||
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
|
- **`sample()`**: Temperature → top-k → top-p → multinomial
|
||||||
|
|
||||||
### 7. Inference Module
|
#### 5.3 Server (`server.py`)
|
||||||
|
- FastAPI with OpenAI `/v1/chat/completions` and Anthropic `/v1/messages` endpoints
|
||||||
|
- Streaming via SSE, health check at `/health`, stats at `/stats`
|
||||||
|
|
||||||
#### 7.1 Inference Engine (`engine.py`)
|
### 6. Tokenizer Module
|
||||||
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
|
|
||||||
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
|
|
||||||
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
|
|
||||||
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
|
|
||||||
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
|
|
||||||
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
|
|
||||||
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
|
|
||||||
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
|
|
||||||
- Uses separate model and tokenizer initialization for flexibility
|
|
||||||
|
|
||||||
#### 7.2 Cache (`cache.py`)
|
- **`AutoTokenizer`**: Wraps HuggingFace tokenizers (BBPE); `encode`/`decode`/`apply_chat_template`
|
||||||
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
|
- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat
|
||||||
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
|
|
||||||
|
|
||||||
#### 7.3 Scheduler (`scheduler.py`)
|
### 7. Factory & Parallel
|
||||||
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
|
||||||
- **`TaskStatus`**: Task state enumeration
|
|
||||||
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
|
|
||||||
- Uses `PagedCache` for paged KV cache management with page table indirection
|
|
||||||
- Continuous batching: new requests can join at any time, completed requests release pages immediately
|
|
||||||
|
|
||||||
#### 7.4 Server (`server.py`)
|
- **`Registry` / `BaseFactory`**: Decorator-based component registration
|
||||||
- FastAPI-based HTTP inference server
|
- **`spawn_parallel_fn`**: Multi-process DDP launcher with NCCL backend
|
||||||
- OpenAI-compatible `/v1/chat/completions` endpoint
|
- **`ParallelModel` / `ColumnParallelLinear` / `RowParallelLinear`**: Tensor model parallelism
|
||||||
- Health check and statistics endpoints
|
|
||||||
- Supports both streaming and non-streaming responses
|
|
||||||
|
|
||||||
### 8. Tokenizer Module
|
## Training Data Flow — Detailed Steps
|
||||||
|
|
||||||
#### 8.1 Tokenizer (`tokenizer.py`)
|
|
||||||
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
|
|
||||||
- **`AutoTokenizer`**: Auto-loading tokenizer class
|
|
||||||
- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
|
|
||||||
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
|
|
||||||
- Uses `AutoTokenizer` for loading pre-trained tokenizers
|
|
||||||
|
|
||||||
#### 8.2 Chat Template (`chat_template.py`)
|
|
||||||
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
|
|
||||||
- Handles multi-role message formatting (system, user, assistant)
|
|
||||||
- Supports dynamic prompts and generation prompts
|
|
||||||
|
|
||||||
## Training Data Flow - Detailed Steps
|
|
||||||
|
|
||||||
1. **Data Preparation**
|
1. **Data Preparation**
|
||||||
- Raw text is converted to token ID sequences through AutoTokenizer
|
- Raw text → token IDs via `AutoTokenizer.encode()`
|
||||||
- Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
|
- Save as `.h5` files (groups of tensor lists per data key)
|
||||||
- Files can contain multiple segments, each segment corresponds to a tensor
|
|
||||||
|
|
||||||
2. **Dataset Loading**
|
2. **Dataset Loading**
|
||||||
- `BaseDataset`'s `load` method calls `load_h5`, obtaining `segments` dictionary
|
- `BaseDataset.load()` calls `load_h5()`, builds `MultiSegmentFetcher`
|
||||||
- Create `MultiSegmentFetcher` to manage data for multiple keys
|
- Sliding window of `window_size` with `stride` determines sample boundaries
|
||||||
- Calculate total sample count, and determine start/end indices for each sample based on window size and stride
|
|
||||||
|
|
||||||
3. **Sampling and Batch Loading**
|
3. **Sampling & Batching**
|
||||||
- `ResumableDistributedSampler` generates index sequence based on current epoch and iteration position
|
- `ResumableDistributedSampler` produces shuffled index sequences
|
||||||
- PyTorch `DataLoader` uses sampler to get indices, calls dataset's `__getitem__` to get actual data
|
- `DataLoader` fetches `[batch_size, window_size]` tensors via `__getitem__`
|
||||||
- Batch data shape is `[batch_size, window_size]` (or varies according to specific dataset type)
|
|
||||||
|
|
||||||
4. **Strategy Forward and Loss Calculation**
|
4. **Strategy Forward**
|
||||||
- Batch data is passed to strategy (such as `SEQStrategy`)
|
- Strategy receives batch, calls `Transformer.forward()` for logits
|
||||||
- Strategy internally calls `Transformer` model, obtaining logits
|
- Computes task-specific loss (cross-entropy, DPO, GRPO)
|
||||||
- Calculate cross-entropy loss (or DPO loss, etc.) according to task type
|
|
||||||
- Return loss tensor
|
|
||||||
|
|
||||||
5. **Backpropagation and Optimization**
|
5. **Backward & Accumulation**
|
||||||
- Loss is normalized by dividing by accumulation steps, then `loss.backward()` is executed
|
- `loss = raw_loss / accumulation_steps`
|
||||||
- After accumulating `accumulation_steps` batches, optimizer `step()` and `zero_grad()` are executed
|
- `loss.backward()` accumulates gradients
|
||||||
- Learning rate scheduler updates learning rate after each step
|
- Every `accumulation_steps` batches: `optimizer.step()` → `zero_grad()`
|
||||||
|
- Every batch: `scheduler.step()` updates learning rate
|
||||||
|
|
||||||
6. **Checkpoint Saving**
|
6. **Checkpoint**
|
||||||
- `CheckpointCallback` saves checkpoints at set intervals
|
- `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations
|
||||||
- Checkpoints contain model state dict, current epoch, iteration, and other metadata
|
- Does NOT save optimizer/scheduler state (resume resets those)
|
||||||
- Saved in safetensors format, ensuring safety and efficiency
|
|
||||||
|
|
||||||
## Inference Data Flow - Detailed Steps
|
## Inference Data Flow — Detailed Steps
|
||||||
|
|
||||||
1. **Model Loading**
|
1. **Model Loading**
|
||||||
- Load `Transformer` model from checkpoint via `AutoModel.from_pretrained()`
|
- `AutoModel.from_pretrained(path)` loads weights from safetensors
|
||||||
- Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
|
- `torch.inference_mode()` wraps generation
|
||||||
|
|
||||||
2. **Prompt Construction and Encoding**
|
2. **Prompt Construction**
|
||||||
- User messages (list of dict with role and content) are converted to ChatML format string through `apply_chat_template` method in tokenizer
|
- Messages → `apply_chat_template(messages, tokenize=False)` → prompt string
|
||||||
- Tokenizer encodes prompt string to token ID sequence `input_ids`
|
- `tokenizer.encode(prompt)` → token IDs (truncated to `max_prompt_len`)
|
||||||
- For batch generation, use `pad_sequence` for padding
|
|
||||||
|
|
||||||
3. **Autoregressive Generation Loop**
|
3. **Continuous Batching Loop**
|
||||||
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
|
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
|
||||||
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
|
- **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages
|
||||||
- Decode phase: loops until generating `max_len` tokens or encountering stop token:
|
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
|
||||||
- Input last token ID to model, obtain `logits`
|
- **Decode**: Pick position group with most tasks, single-token forward:
|
||||||
- Apply `sample()` (temperature, top-k, top-p) to `logits`
|
- Model forward → `logits` → `sample()` → next token ID
|
||||||
- Sample next token ID from the processed distribution
|
- Append to `output_ids`, update `output_tokens`
|
||||||
- Write new KV entries into paged cache; allocate additional pages as needed
|
- `_maybe_alloc_page()` grows page table as needed
|
||||||
- For streaming generation, yield each token to caller immediately via `stream_callback`
|
- `stream_callback(token)` for streaming clients
|
||||||
|
|
||||||
4. **Decoding and Output**
|
4. **Output**
|
||||||
- Decode generated token ID sequence to text through tokenizer
|
- `tokenizer.decode(output_ids)` → text
|
||||||
- Remove special tokens, return plain text response
|
- Return to caller (streaming: token-by-token; non-streaming: complete string)
|
||||||
|
|
||||||
## Checkpoint and Serialization
|
## Checkpoint & Serialization
|
||||||
|
|
||||||
- **Training Checkpoint**: Saves model parameters, optimizer state, scheduler state, current epoch and iteration
|
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata. Optimizer/scheduler state is NOT persisted.
|
||||||
- **Model Parameters**: Supports safetensors format, automatically handles special logic like weight tying during loading
|
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
|
||||||
- **Dataset Serialization**: HDF5 format supports efficient random access and shared memory, suitable for large-scale pre-training data
|
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
|
||||||
|
|
||||||
## Summary
|
> Document Update Time: 2026-05-09
|
||||||
|
|
||||||
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
|
||||||
|
|
||||||
> Document Update Time: 2026-04-09
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,6 @@ classDiagram
|
||||||
+str master_port
|
+str master_port
|
||||||
+Callable parallel_wrapper
|
+Callable parallel_wrapper
|
||||||
+Callable state_dict_fn
|
+Callable state_dict_fn
|
||||||
+List[int] device_ids
|
|
||||||
+str device_type
|
+str device_type
|
||||||
+dict extra_kwargs
|
+dict extra_kwargs
|
||||||
+validate()
|
+validate()
|
||||||
|
|
@ -99,8 +98,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
class ResumableDistributedSampler {
|
||||||
+int start_epoch
|
+int epoch
|
||||||
+int start_iter
|
+int iter
|
||||||
}
|
}
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
|
|
@ -124,7 +123,7 @@ classDiagram
|
||||||
namespace model {
|
namespace model {
|
||||||
class AutoModel {
|
class AutoModel {
|
||||||
+ModelConfig config
|
+ModelConfig config
|
||||||
+Dict _registry
|
+Registry _registry
|
||||||
+register(model_type) decorator
|
+register(model_type) decorator
|
||||||
+get_model_class(model_type) Type
|
+get_model_class(model_type) Type
|
||||||
+from_pretrained(path, disable_random_init) nn.Module
|
+from_pretrained(path, disable_random_init) nn.Module
|
||||||
|
|
@ -139,7 +138,7 @@ classDiagram
|
||||||
+ModuleList layers
|
+ModuleList layers
|
||||||
+RMSNorm norm
|
+RMSNorm norm
|
||||||
+Linear lm_head
|
+Linear lm_head
|
||||||
+forward(input_ids, input_mask, persistent_key_values, start_pos) Dict
|
+forward(input_ids, input_mask, paged_cache, start_pos) Dict
|
||||||
+load_state_dict(state_dict)
|
+load_state_dict(state_dict)
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
@ -149,7 +148,7 @@ classDiagram
|
||||||
+RMSNorm input_norm
|
+RMSNorm input_norm
|
||||||
+MLP mlp
|
+MLP mlp
|
||||||
+RMSNorm post_attention_norm
|
+RMSNorm post_attention_norm
|
||||||
+forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor
|
+forward(x, rotary_emb, attention_mask, paged_cache, start_pos) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class GQA {
|
class GQA {
|
||||||
|
|
@ -158,18 +157,20 @@ classDiagram
|
||||||
+int head_dim
|
+int head_dim
|
||||||
+Linear q_proj, k_proj, v_proj, o_proj
|
+Linear q_proj, k_proj, v_proj, o_proj
|
||||||
+RMSNorm q_norm, k_norm
|
+RMSNorm q_norm, k_norm
|
||||||
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
|
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class MLA {
|
class MLA {
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
+Linear q_a_proj, q_b_proj, q_c_proj
|
+int kv_lora_rank
|
||||||
+Linear kv_a_proj, kv_b_proj, kv_c_proj
|
+int qk_nope_head_dim
|
||||||
|
+int qk_rope_head_dim
|
||||||
|
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||||
+Linear o_proj
|
+Linear o_proj
|
||||||
+RMSNorm q_norm, k_norm
|
+RMSNorm kv_norm
|
||||||
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
|
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class MLP {
|
class MLP {
|
||||||
|
|
@ -204,7 +205,7 @@ classDiagram
|
||||||
|
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+List[str] stop_ids
|
+List[int] stop_ids
|
||||||
+int bos_id
|
+int bos_id
|
||||||
+int eos_id
|
+int eos_id
|
||||||
+int pad_id
|
+int pad_id
|
||||||
|
|
@ -220,7 +221,7 @@ classDiagram
|
||||||
|
|
||||||
class ChatTemplate {
|
class ChatTemplate {
|
||||||
+String template_str
|
+String template_str
|
||||||
+render(messages, add_generation_prompt) str
|
+render(messages, system_prompt, **extra_variables) str
|
||||||
+from_string(template) ChatTemplate
|
+from_string(template) ChatTemplate
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -267,8 +268,6 @@ classDiagram
|
||||||
class TrainContextBuilder {
|
class TrainContextBuilder {
|
||||||
+TrainConfig config
|
+TrainConfig config
|
||||||
+with_checkpoint(checkpoint) TrainContextBuilder
|
+with_checkpoint(checkpoint) TrainContextBuilder
|
||||||
+with_dataloader() TrainContextBuilder
|
|
||||||
+with_strategy() TrainContextBuilder
|
|
||||||
+build() TrainContext
|
+build() TrainContext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -454,7 +453,7 @@ classDiagram
|
||||||
+float arrival_time
|
+float arrival_time
|
||||||
+float finish_time
|
+float finish_time
|
||||||
+Callable stream_callback
|
+Callable stream_callback
|
||||||
+next_pos() int
|
+int next_pos
|
||||||
+is_finished(stop_ids) bool
|
+is_finished(stop_ids) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -506,15 +505,10 @@ classDiagram
|
||||||
+sample(logits, filter_value) Tensor
|
+sample(logits, filter_value) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class Server {
|
|
||||||
+start()
|
|
||||||
+predict(request)
|
|
||||||
}
|
|
||||||
|
|
||||||
class _Result {
|
class _Result {
|
||||||
+List[str] tokens
|
+List[str] tokens
|
||||||
+List[str] results
|
+List[str] results
|
||||||
+List[bool] done_flags
|
+List[bool] _done
|
||||||
+append(token, idx)
|
+append(token, idx)
|
||||||
+get_results() List[str]
|
+get_results() List[str]
|
||||||
+pop_all() List[str]
|
+pop_all() List[str]
|
||||||
|
|
@ -539,9 +533,9 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class ParallelSetup {
|
class ParallelFunctions {
|
||||||
+spawn_parallel_fn(fn, nprocs)
|
+spawn_parallel_fn(fn, nprocs)
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
|
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||||
}
|
}
|
||||||
|
|
||||||
class ParallelModel {
|
class ParallelModel {
|
||||||
|
|
@ -601,24 +595,19 @@ classDiagram
|
||||||
BaseSamplingStrategy <|-- TopKStrategy
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
BaseSamplingStrategy <|-- TopPStrategy
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
SamplingPipeline --> BaseSamplingStrategy : composes
|
SamplingPipeline --> BaseSamplingStrategy : composes
|
||||||
Server --> InferenceEngine : uses
|
|
||||||
Server --> ChatMessage : uses
|
|
||||||
Server --> ChatCompletionRequest : uses
|
|
||||||
ParallelSetup --> Trainer : enables
|
|
||||||
BaseDataset <|-- SEQDataset
|
BaseDataset <|-- SEQDataset
|
||||||
BaseDataset <|-- SFTDataset
|
BaseDataset <|-- SFTDataset
|
||||||
BaseDataset <|-- DPODataset
|
BaseDataset <|-- DPODataset
|
||||||
BaseDataset <|-- GRPODataset
|
BaseDataset <|-- GRPODataset
|
||||||
DatasetFactory ..> BaseDataset : creates
|
DatasetFactory ..> BaseDataset : creates
|
||||||
BaseSegmentFetcher --> MultiSegmentFetcher : used by
|
MultiSegmentFetcher --> BaseSegmentFetcher : uses
|
||||||
MultiSegmentFetcher --> BaseDataset : used by
|
BaseDataset --> MultiSegmentFetcher : uses
|
||||||
AutoModel <|-- Transformer
|
AutoModel <|-- Transformer
|
||||||
AutoModel --> ModelConfig : contains
|
AutoModel --> ModelConfig : contains
|
||||||
Transformer --> DecoderBlock : uses
|
Transformer --> DecoderBlock : uses
|
||||||
Transformer --> RotaryEmbedding : uses
|
Transformer --> RotaryEmbedding : uses
|
||||||
Transformer --> Embedding : uses
|
Transformer --> Embedding : uses
|
||||||
DecoderBlock --> GQA : uses
|
DecoderBlock --> GQA : uses
|
||||||
DecoderBlock --> MLA : uses
|
|
||||||
DecoderBlock --> MLP : uses
|
DecoderBlock --> MLP : uses
|
||||||
DecoderBlock --> RMSNorm : uses
|
DecoderBlock --> RMSNorm : uses
|
||||||
TrainContextBuilder --> ResumableDistributedSampler : creates
|
TrainContextBuilder --> ResumableDistributedSampler : creates
|
||||||
|
|
@ -647,7 +636,7 @@ classDiagram
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
||||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
| **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||||
|
|
||||||
### Design Patterns
|
### Design Patterns
|
||||||
|
|
@ -658,7 +647,7 @@ classDiagram
|
||||||
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
|
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
|
||||||
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
|
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
|
||||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||||
| **Singleton** | `TrainContext` | Training process global state management |
|
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||||
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
|
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
|
||||||
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||||
|
|
@ -672,8 +661,8 @@ classDiagram
|
||||||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||||
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||||
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
|
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
|
||||||
|
|
@ -717,12 +706,6 @@ $$
|
||||||
L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
|
L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
|
||||||
$$
|
$$
|
||||||
|
|
||||||
In this implementation, an off-policy approach is used ($\pi_\theta = \pi_{\text{ref}}$), and the policy loss simplifies to:
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{policy}} = -\mathbb{E}[A]
|
|
||||||
$$
|
|
||||||
|
|
||||||
The KL divergence term uses mean squared error approximation:
|
The KL divergence term uses mean squared error approximation:
|
||||||
|
|
||||||
$$
|
$$
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
### 1. Model Architecture
|
### 1. Model Architecture
|
||||||
|
|
||||||
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
|
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
|
||||||
|
|
||||||
The model now uses the **AutoModel** base class for flexible loading and saving:
|
The model now uses the **AutoModel** base class for flexible loading and saving:
|
||||||
|
|
||||||
|
|
@ -48,14 +48,15 @@ flowchart TB
|
||||||
S --> T[+]
|
S --> T[+]
|
||||||
H --> T
|
H --> T
|
||||||
T --> U[RMSNorm]
|
T --> U[RMSNorm]
|
||||||
U --> V[Linear]
|
U --> V["Linear (gate)"]
|
||||||
V --> W[SiLU]
|
U --> W["Linear (up)"]
|
||||||
V --> X[×]
|
V --> X[SiLU]
|
||||||
W --> X
|
X --> Y[×]
|
||||||
X --> Y[Linear]
|
W --> Y
|
||||||
Y --> Z[+]
|
Y --> Z["Linear (down)"]
|
||||||
T --> Z
|
Z --> AA[+]
|
||||||
Z --> AA[x']
|
T --> AA
|
||||||
|
AA --> BB[x']
|
||||||
end
|
end
|
||||||
|
|
||||||
classDef main fill:#e6f3ff,stroke:#0066cc;
|
classDef main fill:#e6f3ff,stroke:#0066cc;
|
||||||
|
|
@ -168,8 +169,6 @@ from astrai.inference import InferenceEngine, GenerationRequest
|
||||||
engine = InferenceEngine(
|
engine = InferenceEngine(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_batch_size=8,
|
|
||||||
max_seq_len=4096,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use GenerationRequest with messages format
|
# Use GenerationRequest with messages format
|
||||||
|
|
@ -222,12 +221,11 @@ curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|-----------|------|---------|-------------|
|
|-----------|------|---------|-------------|
|
||||||
| `messages` | List[dict] | Required | Chat messages with role and content |
|
| `messages` | List[dict] | Required | Chat messages with role and content |
|
||||||
| `temperature` | float | 0.8 | Sampling temperature (0.0-2.0) |
|
| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) |
|
||||||
| `top_p` | float | 0.95 | Nucleus sampling threshold |
|
| `top_p` | float | 1.0 | Nucleus sampling threshold |
|
||||||
| `top_k` | int | 50 | Top-k sampling parameter |
|
| `top_k` | int | 50 | Top-k sampling parameter |
|
||||||
| `max_tokens` | int | 2048 | Maximum tokens to generate |
|
| `max_tokens` | int | 1024 | Maximum tokens to generate |
|
||||||
| `stream` | bool | false | Enable streaming response |
|
| `stream` | bool | false | Enable streaming response |
|
||||||
| `system_prompt` | str | None | System prompt override |
|
|
||||||
|
|
||||||
**Response (non-streaming):**
|
**Response (non-streaming):**
|
||||||
```json
|
```json
|
||||||
|
|
@ -242,7 +240,12 @@ curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
"message": {"role": "assistant", "content": "Hello! I'm doing well..."},
|
"message": {"role": "assistant", "content": "Hello! I'm doing well..."},
|
||||||
"finish_reason": "stop"
|
"finish_reason": "stop"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 20,
|
||||||
|
"completion_tokens": 15,
|
||||||
|
"total_tokens": 35
|
||||||
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -262,9 +265,6 @@ curl -X POST http://localhost:8000/v1/chat/completions \
|
||||||
|
|
||||||
The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
|
The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
|
||||||
|
|
||||||
### Health Check
|
|
||||||
|
|
||||||
|
|
||||||
### Anthropic-Compatible Endpoint
|
### Anthropic-Compatible Endpoint
|
||||||
|
|
||||||
The server also provides an Anthropic-compatible endpoint at `/v1/messages`:
|
The server also provides an Anthropic-compatible endpoint at `/v1/messages`:
|
||||||
|
|
@ -325,10 +325,10 @@ Monitor server and model status:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl http://localhost:8000/health
|
curl http://localhost:8000/health
|
||||||
# {"status": "ok", "model_loaded": true, "engine_ready": true}
|
# {"status": "ok", "model_loaded": true}
|
||||||
|
|
||||||
curl http://localhost:8000/stats
|
curl http://localhost:8000/stats
|
||||||
# {"requests_total": 10, "tokens_generated": 5000, ...}
|
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
|
||||||
```
|
```
|
||||||
|
|
||||||
> Document Update Time: 2026-04-09
|
> Document Update Time: 2026-04-09
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
| `--param_path` | Model parameters or checkpoint path | required |
|
| `--param_path` | Model parameters or checkpoint path | required |
|
||||||
| `--n_epoch` | Total training epochs | 1 |
|
| `--n_epoch` | Total training epochs | 1 |
|
||||||
|
|
@ -61,6 +61,10 @@
|
||||||
|-----------|-------------|---------|---------|
|
|-----------|-------------|---------|---------|
|
||||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
|
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
|
||||||
|
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||||
|
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||||
|
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||||
|
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
|
||||||
|
|
||||||
### Usage Example
|
### Usage Example
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -74,9 +74,6 @@ class TrainConfig:
|
||||||
)
|
)
|
||||||
|
|
||||||
# others
|
# others
|
||||||
device_ids: Optional[List[int]] = field(
|
|
||||||
default=None, metadata={"help": "Device ids for distributed training."}
|
|
||||||
)
|
|
||||||
device_type: str = field(
|
device_type: str = field(
|
||||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
Layers:
|
Layers:
|
||||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
||||||
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
|
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
|
||||||
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
|
- cache.py: PagedCache (page-table-indirected KV cache with alloc/free)
|
||||||
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
|
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -408,13 +408,14 @@ class InferenceEngine:
|
||||||
Single string for one prompt, list of strings for batch.
|
Single string for one prompt, list of strings for batch.
|
||||||
"""
|
"""
|
||||||
result = _Result(count=len(prompts))
|
result = _Result(count=len(prompts))
|
||||||
|
task_ids = []
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
for i, p in enumerate(prompts):
|
||||||
|
|
||||||
def make_cb(idx):
|
def make_cb(idx):
|
||||||
return lambda tok: result.append(tok, idx)
|
return lambda tok: result.append(tok, idx)
|
||||||
|
|
||||||
self.scheduler.add_task(
|
task_id = self.scheduler.add_task(
|
||||||
prompt=p,
|
prompt=p,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
|
@ -422,8 +423,14 @@ class InferenceEngine:
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=make_cb(i),
|
stream_callback=make_cb(i),
|
||||||
)
|
)
|
||||||
|
task_ids.append(task_id)
|
||||||
|
|
||||||
|
while result._completed < result._total:
|
||||||
|
result.wait(timeout=1.0)
|
||||||
|
|
||||||
|
for task_id in task_ids:
|
||||||
|
self.scheduler.remove_task(task_id)
|
||||||
|
|
||||||
result.wait()
|
|
||||||
res = result.get_results()
|
res = result.get_results()
|
||||||
return res if is_batch else res[0]
|
return res if is_batch else res[0]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,7 @@ class Task:
|
||||||
self.arrival_time = time.time()
|
self.arrival_time = time.time()
|
||||||
self.finish_time: Optional[float] = None
|
self.finish_time: Optional[float] = None
|
||||||
self.stream_callback = stream_callback
|
self.stream_callback = stream_callback
|
||||||
|
self._pages_freed: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next_pos(self) -> int:
|
def next_pos(self) -> int:
|
||||||
|
|
@ -104,7 +105,9 @@ class InferenceScheduler:
|
||||||
n_kv_heads = config.n_kv_heads
|
n_kv_heads = config.n_kv_heads
|
||||||
head_dim = config.dim // config.n_heads
|
head_dim = config.dim // config.n_heads
|
||||||
n_layers = config.n_layers
|
n_layers = config.n_layers
|
||||||
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
|
n_pages = (
|
||||||
|
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||||
|
) // page_size
|
||||||
|
|
||||||
self.page_cache = PagedCache(
|
self.page_cache = PagedCache(
|
||||||
n_layers,
|
n_layers,
|
||||||
|
|
@ -167,9 +170,11 @@ class InferenceScheduler:
|
||||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||||
|
|
||||||
for task in removed_active:
|
for task in removed_active:
|
||||||
self._free_pages(task.page_table)
|
if not task._pages_freed:
|
||||||
task.page_table.clear()
|
self._free_pages(task.page_table)
|
||||||
task.n_pages = 0
|
task.page_table.clear()
|
||||||
|
task.n_pages = 0
|
||||||
|
task._pages_freed = True
|
||||||
|
|
||||||
def _free_pages(self, indices: List[int]) -> None:
|
def _free_pages(self, indices: List[int]) -> None:
|
||||||
for idx in indices:
|
for idx in indices:
|
||||||
|
|
@ -185,9 +190,11 @@ class InferenceScheduler:
|
||||||
self._total_tokens += task.output_tokens
|
self._total_tokens += task.output_tokens
|
||||||
|
|
||||||
for task in finished:
|
for task in finished:
|
||||||
self._free_pages(task.page_table)
|
if not task._pages_freed:
|
||||||
task.page_table.clear()
|
self._free_pages(task.page_table)
|
||||||
task.n_pages = 0
|
task.page_table.clear()
|
||||||
|
task.n_pages = 0
|
||||||
|
task._pages_freed = True
|
||||||
|
|
||||||
self.active_tasks = [
|
self.active_tasks = [
|
||||||
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
||||||
|
|
@ -274,6 +281,9 @@ class InferenceScheduler:
|
||||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
|
for t in tasks:
|
||||||
|
self._maybe_alloc_page(t, start_pos)
|
||||||
|
|
||||||
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
|
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,7 @@ class AutoModel(nn.Module):
|
||||||
cls,
|
cls,
|
||||||
path: Union[str, Path],
|
path: Union[str, Path],
|
||||||
disable_random_init: bool = True,
|
disable_random_init: bool = True,
|
||||||
|
strict: bool = True,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
|
|
||||||
model_path = Path(path)
|
model_path = Path(path)
|
||||||
|
|
@ -106,7 +107,7 @@ class AutoModel(nn.Module):
|
||||||
weights_path = model_path / "model.safetensors"
|
weights_path = model_path / "model.safetensors"
|
||||||
if weights_path.exists():
|
if weights_path.exists():
|
||||||
state_dict = st.load_file(str(weights_path))
|
state_dict = st.load_file(str(weights_path))
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
@ -34,7 +34,6 @@ def setup_parallel(
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None,
|
|
||||||
):
|
):
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
|
@ -45,15 +44,10 @@ def setup_parallel(
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
if device_ids is None:
|
device_id = torch.device(device_type, rank)
|
||||||
device_ids = [i for i in range(world_size)]
|
|
||||||
|
|
||||||
rank = device_ids[rank % len(device_ids)]
|
|
||||||
device_id = torch.device(device_type, device_ids[rank])
|
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
os.environ["MASTER_PORT"] = master_port
|
os.environ["MASTER_PORT"] = master_port
|
||||||
|
|
||||||
os.environ["LOCAL_RANK"] = str(rank)
|
os.environ["LOCAL_RANK"] = str(rank)
|
||||||
os.environ["WORLD_SIZE"] = str(world_size)
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
@ -103,7 +97,6 @@ def wrapper_spawn_func(
|
||||||
master_addr: str,
|
master_addr: str,
|
||||||
master_port: str,
|
master_port: str,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
device_ids: List[int],
|
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
):
|
):
|
||||||
|
|
@ -115,7 +108,6 @@ def wrapper_spawn_func(
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
device_ids=device_ids,
|
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
|
|
@ -131,7 +123,6 @@ def spawn_parallel_fn(
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# clear environment variables
|
# clear environment variables
|
||||||
|
|
@ -147,8 +138,9 @@ def spawn_parallel_fn(
|
||||||
del os.environ[key]
|
del os.environ[key]
|
||||||
|
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
device_ids = device_ids or [0]
|
device_id = torch.device(device_type, 0)
|
||||||
device_id = torch.device(device_type, device_ids[0])
|
os.environ["LOCAL_RANK"] = "0"
|
||||||
|
os.environ["WORLD_SIZE"] = "1"
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
@ -160,7 +152,6 @@ def spawn_parallel_fn(
|
||||||
master_addr,
|
master_addr,
|
||||||
master_port,
|
master_port,
|
||||||
device_type,
|
device_type,
|
||||||
device_ids,
|
|
||||||
func,
|
func,
|
||||||
kwargs,
|
kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
|
|
@ -54,10 +54,12 @@ class Checkpoint:
|
||||||
state_dict: Dict[str, Any],
|
state_dict: Dict[str, Any],
|
||||||
epoch: int = 0,
|
epoch: int = 0,
|
||||||
iteration: int = 0,
|
iteration: int = 0,
|
||||||
|
extra: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
self.state_dict = state_dict
|
self.state_dict = state_dict
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
self.iteration = iteration
|
self.iteration = iteration
|
||||||
|
self.extra = extra or {}
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
|
|
@ -77,6 +79,8 @@ class Checkpoint:
|
||||||
json.dump(meta, f, indent=2)
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||||
|
if self.extra:
|
||||||
|
torch.save(self.extra, save_path / "extra.pt")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
|
|
@ -99,8 +103,14 @@ class Checkpoint:
|
||||||
|
|
||||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||||
|
|
||||||
|
extra = None
|
||||||
|
extra_path = save_path / "extra.pt"
|
||||||
|
if extra_path.exists():
|
||||||
|
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=meta["epoch"],
|
epoch=meta["epoch"],
|
||||||
iteration=meta["iteration"],
|
iteration=meta["iteration"],
|
||||||
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,11 @@ class AutoTokenizer:
|
||||||
save_path: Path to save the tokenizer
|
save_path: Path to save the tokenizer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||||
|
)
|
||||||
|
|
||||||
save_path = Path(save_path)
|
save_path = Path(save_path)
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -121,11 +121,13 @@ class CheckpointCallback(TrainCallback):
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||||
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
self.state_dict_fn = state_dict_fn
|
||||||
|
self.save_extra_fn = save_extra_fn
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -139,8 +141,12 @@ class CheckpointCallback(TrainCallback):
|
||||||
else context.model.state_dict()
|
else context.model.state_dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extra = self.save_extra_fn(context) if self.save_extra_fn else None
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
|
state_dict=state_dict,
|
||||||
|
epoch=context.epoch,
|
||||||
|
iteration=context.iteration,
|
||||||
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Self
|
from typing import Callable, Optional, Self
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -32,9 +32,14 @@ class TrainContext:
|
||||||
|
|
||||||
|
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
def __init__(self, config: TrainConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: TrainConfig,
|
||||||
|
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||||
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._checkpoint: Optional[Checkpoint] = None
|
self._checkpoint: Optional[Checkpoint] = None
|
||||||
|
self._load_extra_fn = load_extra_fn
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
self._checkpoint = checkpoint
|
self._checkpoint = checkpoint
|
||||||
|
|
@ -66,6 +71,9 @@ class TrainContextBuilder:
|
||||||
context.optimizer = self.config.optimizer_fn(context.model)
|
context.optimizer = self.config.optimizer_fn(context.model)
|
||||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
|
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
|
||||||
|
self._load_extra_fn(self._checkpoint.extra, context)
|
||||||
|
|
||||||
cfg = self.config
|
cfg = self.config
|
||||||
sampler_offset = context.iteration * cfg.batch_size
|
sampler_offset = context.iteration * cfg.batch_size
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,6 @@ class Trainer:
|
||||||
master_addr=config.master_addr,
|
master_addr=config.master_addr,
|
||||||
master_port=config.master_port,
|
master_port=config.master_port,
|
||||||
device_type=config.device_type,
|
device_type=config.device_type,
|
||||||
device_ids=config.device_ids,
|
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -68,8 +67,9 @@ class Trainer:
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
|
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
if context.iteration % accumulation_steps == 0:
|
||||||
# 2. step
|
# 2. step
|
||||||
self._call_callbacks("on_step_begin", context)
|
self._call_callbacks("on_step_begin", context)
|
||||||
context.optimizer.step()
|
context.optimizer.step()
|
||||||
|
|
@ -83,7 +83,7 @@ class Trainer:
|
||||||
context.iteration += 1
|
context.iteration += 1
|
||||||
|
|
||||||
# to make the loss normalized by accumulation steps
|
# to make the loss normalized by accumulation steps
|
||||||
stand_loss = loss / self.train_config.accumulation_steps
|
stand_loss = loss / accumulation_steps
|
||||||
stand_loss.backward()
|
stand_loss.backward()
|
||||||
|
|
||||||
self._call_callbacks("on_batch_end", context)
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,9 @@ from typing import Any, Dict
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.config import ModelConfig
|
||||||
from astrai.inference.cache import PagedCache
|
from astrai.inference.cache import PagedCache
|
||||||
from astrai.model.transformer import ModelConfig, Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def processor(
|
def processor(
|
||||||
model_dir: str,
|
param_path: str,
|
||||||
input_json_file: str,
|
input_json_file: str,
|
||||||
output_json_file: str,
|
output_json_file: str,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
|
|
@ -20,8 +20,8 @@ def processor(
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
):
|
):
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
model = AutoModel.from_pretrained(model_dir)
|
model = AutoModel.from_pretrained(param_path)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Create inference engine
|
# Create inference engine
|
||||||
|
|
@ -72,7 +72,7 @@ if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_dir", type=str, required=True, help="Path to the model directory."
|
"--param_path", type=str, required=True, help="Path to the model directory."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input_json_file",
|
"--input_json_file",
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||||
)
|
)
|
||||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--accumulation_steps",
|
"--accumulation_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -53,7 +53,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--warmup_steps",
|
"--warmup_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=1000,
|
default=1000,
|
||||||
help="Number of iters between warnings.",
|
help="Number of warmup steps for LR scheduler.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||||
|
|
@ -98,23 +98,19 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--window_size",
|
"--window_size",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="the max length of the input sequence.",
|
help="Max length of the input sequence.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--stride", type=int, default=None, help="the step size of the input sequence."
|
"--stride", type=int, default=None, help="Step size of the input sequence."
|
||||||
)
|
)
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--on_policy",
|
"--grpo_clip_eps", type=float, default=0.2, help="GRPO clipping epsilon."
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Enable on-policy GRPO mode.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
|
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
|
||||||
)
|
)
|
||||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--label_smoothing",
|
"--label_smoothing",
|
||||||
type=float,
|
type=float,
|
||||||
|
|
@ -134,7 +130,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
default="checkpoint",
|
default="checkpoint",
|
||||||
help="Directory to save checkpoints.",
|
help="Directory to save checkpoints.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--grpo_sync_interval",
|
"--grpo_sync_interval",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -160,7 +155,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def ddp_wrap(model: nn.Module):
|
||||||
local_rank = get_rank()
|
local_rank = get_rank()
|
||||||
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
model = model.to(dtype=torch.bfloat16)
|
||||||
ddp_model = DDP(
|
ddp_model = DDP(
|
||||||
model,
|
model,
|
||||||
device_ids=[local_rank],
|
device_ids=[local_rank],
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,7 @@ def test_schedule_factory_random_configs():
|
||||||
|
|
||||||
# Test scheduler step functionality
|
# Test scheduler step functionality
|
||||||
initial_lr = scheduler.get_last_lr()
|
initial_lr = scheduler.get_last_lr()
|
||||||
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
new_lr = scheduler.get_last_lr()
|
new_lr = scheduler.get_last_lr()
|
||||||
|
|
||||||
|
|
@ -112,6 +113,7 @@ def test_schedule_factory_edge_cases():
|
||||||
|
|
||||||
# Test multiple steps
|
# Test multiple steps
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -136,6 +138,7 @@ def test_schedule_factory_state_persistence():
|
||||||
|
|
||||||
# Take a few steps
|
# Take a few steps
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# Save state
|
# Save state
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue