fix: 修复文档多处不准确 + inference scheduler 越界 bug + SchedulerCallback 回调阶段修正
文档 (6 个文件): - design.md: 15+ 处修正 — persistent_key_values→paged_cache, MLA 字段重写, Server/ParallelSetup 不存在类移除, 关系箭头方向修复, SchedulerCallback 阶段修正等 - dataflow.md: 重写数据流图和描述, 修复训练回调顺序、 数据键名、MLA 归属、MetricTracker 等错误 - introduction.md: 层数 32→24, MLP 图双 Linear 修正, 默认值/响应字段/health 端点修复 - params.md: 补充 grpo 及 4 个 GRPO 参数 - README.md / README-zh-CN.md: generate.py 补全必需参数, 删除重复注释, HuggingFace 声明修正 代码 (2 个文件): - scheduler.py: n_pages 池加 page_size 余量防止越界; decode 前预分配页 - train_callback.py: SchedulerCallback 从 on_step_end 改 回 on_batch_end (按 batch 步进学习率)
This commit is contained in:
parent
b98c9cefdc
commit
db99d8b254
|
|
@ -46,7 +46,7 @@
|
|||
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
|
||||
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
|
||||
- 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas.
|
||||
- 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets.
|
||||
- 🤗 **HuggingFace-Style API**: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
|
||||
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
|
||||
|
||||
### Quick Start
|
||||
|
|
@ -84,7 +84,10 @@ Full reference at [Parameter Guide](assets/docs/params.md).
|
|||
#### Generate Text
|
||||
|
||||
```bash
|
||||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
||||
python scripts/tools/generate.py \
|
||||
--param_path /path/to/model \
|
||||
--input_json_file /path/to/input.json \
|
||||
--output_json_file /path/to/output.json
|
||||
```
|
||||
|
||||
#### Docker
|
||||
|
|
@ -117,8 +120,6 @@ docker compose --profile cpu up -d
|
|||
|
||||
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
|
||||
|
||||
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
|
||||
|
||||
#### Start HTTP Server
|
||||
|
||||
Start the inference server with OpenAI and Anthropic-compatible HTTP API:
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@
|
|||
- 💡 **易用**: 简洁的 API 与丰富的示例、演示。
|
||||
- 📦 **轻量**: 依赖少,部署简单。
|
||||
- 🔬 **研究友好**: 模块化设计,便于实验新想法。
|
||||
- 🤗 **HuggingFace 集成**: 兼容 HuggingFace 模型与数据集。
|
||||
- 🤗 **HuggingFace 风格 API**: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。
|
||||
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。
|
||||
|
||||
### 快速开始
|
||||
|
|
@ -90,7 +90,10 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
|||
#### 文本生成
|
||||
|
||||
```bash
|
||||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
||||
python scripts/tools/generate.py \
|
||||
--param_path /path/to/model \
|
||||
--input_json_file /path/to/input.json \
|
||||
--output_json_file /path/to/output.json
|
||||
```
|
||||
|
||||
#### Docker
|
||||
|
|
|
|||
|
|
@ -9,13 +9,11 @@ AstrAI adopts a modular design with the following main components:
|
|||
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
||||
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
||||
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
||||
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
|
||||
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
||||
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
||||
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
|
||||
|
||||
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
||||
|
||||
## Data Flow Diagram
|
||||
|
||||
```mermaid
|
||||
|
|
@ -23,38 +21,36 @@ flowchart LR
|
|||
subgraph A[Data Preparation]
|
||||
direction TB
|
||||
A1[Raw Text] --> A2[AutoTokenizer]
|
||||
A2 --> A3[Serialize to .h5 files]
|
||||
A2 --> A3[Tokenized .h5 files]
|
||||
A3 --> A4[BaseDataset]
|
||||
A4 --> A5[ResumableDistributedSampler]
|
||||
A5 --> A6[PyTorch DataLoader]
|
||||
A5 --> A6[DataLoader]
|
||||
end
|
||||
|
||||
subgraph B[Training]
|
||||
direction TB
|
||||
B1[Batch Data] --> B2[TrainContextBuilder]
|
||||
B2 --> B3[TrainContext]
|
||||
B3 --> B4[BaseStrategy]
|
||||
B4 --> B5[Transformer]
|
||||
B5 --> B6[Compute Loss]
|
||||
B6 --> B7[Backward]
|
||||
B7 --> B8[Optimizer]
|
||||
B8 --> B9[LRScheduler]
|
||||
B9 --> B10[CheckpointCallback]
|
||||
B1[DataLoader] --> B2[BaseStrategy]
|
||||
B2 --> B3[Transformer Forward]
|
||||
B3 --> B4[Loss + Backward]
|
||||
B4 --> B5[Gradient Accumulation]
|
||||
B5 -->|every accum_steps| B6[Optimizer Step]
|
||||
B6 --> B7[LR Scheduler]
|
||||
B7 -->|next batch| B2
|
||||
B6 --> B8[CheckpointCallback]
|
||||
end
|
||||
|
||||
subgraph C[Inference]
|
||||
direction TB
|
||||
C1[Checkpoint] --> C2[AutoModel]
|
||||
C2 --> C3[Transformer + Tokenizer]
|
||||
C3 --> C4[GenerationRequest + apply_chat_template]
|
||||
C4 --> C5[InferenceEngine]
|
||||
C5 --> C6[InferenceScheduler]
|
||||
C1 --> C3[AutoTokenizer]
|
||||
C2 --> C4[InferenceEngine]
|
||||
C3 --> C4
|
||||
C4 --> C5[InferenceScheduler]
|
||||
C5 --> C6[Transformer Forward]
|
||||
C6 --> C7[sample]
|
||||
C7 --> C8[Transformer Forward]
|
||||
C8 --> C9[Paged KV Cache]
|
||||
C9 --> C10{End Condition?}
|
||||
C10 -->|No| C8
|
||||
C10 -->|Yes| C11[Output Text]
|
||||
C7 --> C8{End?}
|
||||
C8 -->|No| C6
|
||||
C8 -->|Yes| C9[Generated Text]
|
||||
end
|
||||
|
||||
A --> B
|
||||
|
|
@ -65,215 +61,177 @@ flowchart LR
|
|||
|
||||
### 1. Serialization (`astrai/serialization.py`)
|
||||
|
||||
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
|
||||
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
|
||||
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
|
||||
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
|
||||
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
|
||||
- **`Checkpoint`**: Encapsulates model state dict + epoch + iteration; uses safetensors
|
||||
|
||||
### 2. Dataset Module
|
||||
|
||||
#### 2.1 Dataset (`dataset.py`)
|
||||
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
|
||||
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
|
||||
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
|
||||
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
|
||||
- **`BaseDataset`**: Abstract base class for windowed sequence sampling
|
||||
- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range
|
||||
- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
|
||||
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen_mask"/"rejected_mask"` (DPO), `"masks"` (GRPO)
|
||||
|
||||
#### 2.2 Sampler (`sampler.py`)
|
||||
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
|
||||
- Records current epoch and iteration position, enabling training resume from breakpoints
|
||||
- Supports shuffle and drop_last options
|
||||
- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last
|
||||
|
||||
### 3. Model Module
|
||||
|
||||
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|
||||
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
|
||||
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
|
||||
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
|
||||
- Supports weight tying (`tie_weight=True`) to reduce parameter count
|
||||
- Uses Rotary Position Embedding (RoPE) to inject position information
|
||||
- Supports loading from safetensors format with automatic model type detection from `config.json`
|
||||
#### 3.1 Transformer / AutoModel
|
||||
- **`AutoModel`**: Base class with `from_pretrained()` / `save_pretrained()`
|
||||
- **`Transformer`**: Decoder-only architecture, registered via `@AutoModel.register('transformer')`
|
||||
- Embedding → N×DecoderBlock → RMSNorm → Linear lm_head
|
||||
- RoPE position encoding, optional weight tying
|
||||
|
||||
#### 3.2 Submodules (`module.py`)
|
||||
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
|
||||
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
|
||||
- **`GQA`**: Grouped Query Attention implementation
|
||||
- **`MLA`**: Multi-Latent Attention implementation (like Qwen2-VL)
|
||||
- **`MLP`**: Feed-forward network with SiLU activation and gated mechanism
|
||||
- **`RMSNorm`**: Layer normalization variant
|
||||
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
|
||||
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
||||
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
||||
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
||||
- **`RotaryEmbedding`**: RoPE cos/sin cache
|
||||
- **`RMSNorm`**: Layer normalization
|
||||
|
||||
### 4. Training Module
|
||||
|
||||
#### 4.1 Training Context (`train_context.py`)
|
||||
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
|
||||
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
|
||||
- **`TrainContext`**: Dataclass holding model, optimizer, dataloader, strategy, scheduler, checkpoint state
|
||||
- **`TrainContextBuilder`**: Builder pattern — takes checkpoint for resume, builds all components
|
||||
|
||||
#### 4.2 Trainer (`trainer.py`)
|
||||
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
|
||||
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
|
||||
- Training steps include:
|
||||
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
|
||||
|
||||
The training loop is nested: **epoch** → **batch** (with step phase interspersed):
|
||||
|
||||
```
|
||||
on_train_begin
|
||||
on_epoch_begin
|
||||
for each batch:
|
||||
if iteration % accumulation_steps == 0: ← step phase
|
||||
on_step_begin → optimizer.step() → zero_grad → on_step_end
|
||||
← batch phase
|
||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
||||
iteration += 1
|
||||
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
||||
Key points:
|
||||
- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches)
|
||||
- `on_batch_*` wraps loss computation (fires every batch)
|
||||
- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch
|
||||
- `GradientClippingCallback` fires on `on_step_begin`
|
||||
|
||||
#### 4.3 Strategy (`strategy.py`)
|
||||
- **`BaseStrategy`**: Defines training strategy interface
|
||||
- **`SEQStrategy`**: Standard next-token prediction training
|
||||
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
|
||||
- **`DPOStrategy`**: Direct Preference Optimization
|
||||
- **`GRPOStrategy`**: Group Relative Policy Optimization
|
||||
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
|
||||
- Created dynamically by `StrategyFactory` according to configuration
|
||||
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
||||
- **`SFTStrategy`**: Supervised fine-tuning with loss masking
|
||||
- **`DPOStrategy`**: Direct Preference Optimization with reference model
|
||||
- **`GRPOStrategy`**: Group Relative Policy Optimization with clipped ratio
|
||||
|
||||
#### 4.4 Scheduler (`schedule.py`)
|
||||
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
|
||||
- **`CosineScheduler`**: Cosine decay scheduler with warmup
|
||||
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
|
||||
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
|
||||
- Scheduler is automatically created according to configuration and bound to optimizer
|
||||
- **`CosineScheduler`**: Cosine decay + linear warmup
|
||||
- **`SGDRScheduler`**: Cosine annealing with warm restarts
|
||||
- Created by `SchedulerFactory` and bound to optimizer
|
||||
|
||||
#### 4.5 Callbacks (`train_callback.py`)
|
||||
- **`TrainCallback`**: Protocol interface for trainer callbacks
|
||||
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
|
||||
- **`ProgressBarCallback`**: Displays training progress
|
||||
- **`MetricLoggerCallback`**: Logs training metrics to JSON files
|
||||
- **`GradientClippingCallback`**: Clips gradient norms
|
||||
- **`SchedulerCallback`**: Steps learning rate scheduler
|
||||
#### 4.5 Callbacks
|
||||
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
||||
- **`ProgressBarCallback`**: tqdm progress display
|
||||
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
||||
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin`
|
||||
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
|
||||
|
||||
#### 4.6 Metric Utility (`metric_util.py`)
|
||||
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
|
||||
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
|
||||
### 5. Inference Module
|
||||
|
||||
### 5. Factory Module
|
||||
#### 5.1 Inference Engine (`engine.py`)
|
||||
- **`InferenceEngine`**: Facade over scheduler; provides `generate()`, `generate_with_request()`, `generate_async()`
|
||||
- Accepts `prompt: str | List[str]`, returns generator (stream) or string (non-stream)
|
||||
|
||||
#### 5.1 Registry and BaseFactory (`factory.py`)
|
||||
- **`Registry`**: Flexible registry for component classes with category and priority support
|
||||
- **`BaseFactory`**: Generic factory class for component registration and creation
|
||||
- Supports decorator-based registration pattern for extensible components
|
||||
- Provides methods for registration, retrieval, and listing with filtering
|
||||
#### 5.2 Scheduler 4-Phase Loop (`scheduler.py`)
|
||||
|
||||
### 6. Parallel Module
|
||||
Background thread runs continuously:
|
||||
|
||||
#### 6.1 Setup (`setup.py`)
|
||||
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
|
||||
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
|
||||
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
|
||||
- **`get_rank`**: Returns current process rank in distributed group
|
||||
- **`get_world_size`**: Returns total number of processes in distributed group
|
||||
- **`get_current_device`**: Returns current device from environment
|
||||
```
|
||||
1. Cleanup → Remove finished tasks, free KV cache pages
|
||||
2. Refill → Pop from waiting_queue, alloc pages, add to active
|
||||
3. Prefill → Group active tasks by prompt_len, run full forward pass
|
||||
4. Decode → Pick largest same-position group, run single-token forward
|
||||
```
|
||||
|
||||
#### 6.2 Parallel Layers (`module.py`)
|
||||
- **`ParallelModel`**: Base class for parallel models with process group
|
||||
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
|
||||
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
|
||||
- **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED)
|
||||
- **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write
|
||||
- **`CacheView`**: Batch view bundling cache + page table for attention layers
|
||||
- **`sample()`**: Temperature → top-k → top-p → multinomial
|
||||
|
||||
### 7. Inference Module
|
||||
#### 5.3 Server (`server.py`)
|
||||
- FastAPI with OpenAI `/v1/chat/completions` and Anthropic `/v1/messages` endpoints
|
||||
- Streaming via SSE, health check at `/health`, stats at `/stats`
|
||||
|
||||
#### 7.1 Inference Engine (`engine.py`)
|
||||
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
|
||||
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
|
||||
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
|
||||
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
|
||||
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
|
||||
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
|
||||
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
|
||||
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
|
||||
- Uses separate model and tokenizer initialization for flexibility
|
||||
### 6. Tokenizer Module
|
||||
|
||||
#### 7.2 Cache (`cache.py`)
|
||||
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
|
||||
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
|
||||
- **`AutoTokenizer`**: Wraps HuggingFace tokenizers (BBPE); `encode`/`decode`/`apply_chat_template`
|
||||
- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat
|
||||
|
||||
#### 7.3 Scheduler (`scheduler.py`)
|
||||
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
||||
- **`TaskStatus`**: Task state enumeration
|
||||
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
|
||||
- Uses `PagedCache` for paged KV cache management with page table indirection
|
||||
- Continuous batching: new requests can join at any time, completed requests release pages immediately
|
||||
### 7. Factory & Parallel
|
||||
|
||||
#### 7.4 Server (`server.py`)
|
||||
- FastAPI-based HTTP inference server
|
||||
- OpenAI-compatible `/v1/chat/completions` endpoint
|
||||
- Health check and statistics endpoints
|
||||
- Supports both streaming and non-streaming responses
|
||||
- **`Registry` / `BaseFactory`**: Decorator-based component registration
|
||||
- **`spawn_parallel_fn`**: Multi-process DDP launcher with NCCL backend
|
||||
- **`ParallelModel` / `ColumnParallelLinear` / `RowParallelLinear`**: Tensor model parallelism
|
||||
|
||||
### 8. Tokenizer Module
|
||||
|
||||
#### 8.1 Tokenizer (`tokenizer.py`)
|
||||
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
|
||||
- **`AutoTokenizer`**: Auto-loading tokenizer class
|
||||
- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
|
||||
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
|
||||
- Uses `AutoTokenizer` for loading pre-trained tokenizers
|
||||
|
||||
#### 8.2 Chat Template (`chat_template.py`)
|
||||
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
|
||||
- Handles multi-role message formatting (system, user, assistant)
|
||||
- Supports dynamic prompts and generation prompts
|
||||
|
||||
## Training Data Flow - Detailed Steps
|
||||
## Training Data Flow — Detailed Steps
|
||||
|
||||
1. **Data Preparation**
|
||||
- Raw text is converted to token ID sequences through AutoTokenizer
|
||||
- Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
|
||||
- Files can contain multiple segments, each segment corresponds to a tensor
|
||||
- Raw text → token IDs via `AutoTokenizer.encode()`
|
||||
- Save as `.h5` files (groups of tensor lists per data key)
|
||||
|
||||
2. **Dataset Loading**
|
||||
- `BaseDataset`'s `load` method calls `load_h5`, obtaining `segments` dictionary
|
||||
- Create `MultiSegmentFetcher` to manage data for multiple keys
|
||||
- Calculate total sample count, and determine start/end indices for each sample based on window size and stride
|
||||
- `BaseDataset.load()` calls `load_h5()`, builds `MultiSegmentFetcher`
|
||||
- Sliding window of `window_size` with `stride` determines sample boundaries
|
||||
|
||||
3. **Sampling and Batch Loading**
|
||||
- `ResumableDistributedSampler` generates index sequence based on current epoch and iteration position
|
||||
- PyTorch `DataLoader` uses sampler to get indices, calls dataset's `__getitem__` to get actual data
|
||||
- Batch data shape is `[batch_size, window_size]` (or varies according to specific dataset type)
|
||||
3. **Sampling & Batching**
|
||||
- `ResumableDistributedSampler` produces shuffled index sequences
|
||||
- `DataLoader` fetches `[batch_size, window_size]` tensors via `__getitem__`
|
||||
|
||||
4. **Strategy Forward and Loss Calculation**
|
||||
- Batch data is passed to strategy (such as `SEQStrategy`)
|
||||
- Strategy internally calls `Transformer` model, obtaining logits
|
||||
- Calculate cross-entropy loss (or DPO loss, etc.) according to task type
|
||||
- Return loss tensor
|
||||
4. **Strategy Forward**
|
||||
- Strategy receives batch, calls `Transformer.forward()` for logits
|
||||
- Computes task-specific loss (cross-entropy, DPO, GRPO)
|
||||
|
||||
5. **Backpropagation and Optimization**
|
||||
- Loss is normalized by dividing by accumulation steps, then `loss.backward()` is executed
|
||||
- After accumulating `accumulation_steps` batches, optimizer `step()` and `zero_grad()` are executed
|
||||
- Learning rate scheduler updates learning rate after each step
|
||||
5. **Backward & Accumulation**
|
||||
- `loss = raw_loss / accumulation_steps`
|
||||
- `loss.backward()` accumulates gradients
|
||||
- Every `accumulation_steps` batches: `optimizer.step()` → `zero_grad()`
|
||||
- Every batch: `scheduler.step()` updates learning rate
|
||||
|
||||
6. **Checkpoint Saving**
|
||||
- `CheckpointCallback` saves checkpoints at set intervals
|
||||
- Checkpoints contain model state dict, current epoch, iteration, and other metadata
|
||||
- Saved in safetensors format, ensuring safety and efficiency
|
||||
6. **Checkpoint**
|
||||
- `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations
|
||||
- Does NOT save optimizer/scheduler state (resume resets those)
|
||||
|
||||
## Inference Data Flow - Detailed Steps
|
||||
## Inference Data Flow — Detailed Steps
|
||||
|
||||
1. **Model Loading**
|
||||
- Load `Transformer` model from checkpoint via `AutoModel.from_pretrained()`
|
||||
- Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
|
||||
- `AutoModel.from_pretrained(path)` loads weights from safetensors
|
||||
- `torch.inference_mode()` wraps generation
|
||||
|
||||
2. **Prompt Construction and Encoding**
|
||||
- User messages (list of dict with role and content) are converted to ChatML format string through `apply_chat_template` method in tokenizer
|
||||
- Tokenizer encodes prompt string to token ID sequence `input_ids`
|
||||
- For batch generation, use `pad_sequence` for padding
|
||||
2. **Prompt Construction**
|
||||
- Messages → `apply_chat_template(messages, tokenize=False)` → prompt string
|
||||
- `tokenizer.encode(prompt)` → token IDs (truncated to `max_prompt_len`)
|
||||
|
||||
3. **Autoregressive Generation Loop**
|
||||
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
|
||||
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
|
||||
- Decode phase: loops until generating `max_len` tokens or encountering stop token:
|
||||
- Input last token ID to model, obtain `logits`
|
||||
- Apply `sample()` (temperature, top-k, top-p) to `logits`
|
||||
- Sample next token ID from the processed distribution
|
||||
- Write new KV entries into paged cache; allocate additional pages as needed
|
||||
- For streaming generation, yield each token to caller immediately via `stream_callback`
|
||||
3. **Continuous Batching Loop**
|
||||
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
|
||||
- **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages
|
||||
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
|
||||
- **Decode**: Pick position group with most tasks, single-token forward:
|
||||
- Model forward → `logits` → `sample()` → next token ID
|
||||
- Append to `output_ids`, update `output_tokens`
|
||||
- `_maybe_alloc_page()` grows page table as needed
|
||||
- `stream_callback(token)` for streaming clients
|
||||
|
||||
4. **Decoding and Output**
|
||||
- Decode generated token ID sequence to text through tokenizer
|
||||
- Remove special tokens, return plain text response
|
||||
4. **Output**
|
||||
- `tokenizer.decode(output_ids)` → text
|
||||
- Return to caller (streaming: token-by-token; non-streaming: complete string)
|
||||
|
||||
## Checkpoint and Serialization
|
||||
## Checkpoint & Serialization
|
||||
|
||||
- **Training Checkpoint**: Saves model parameters, optimizer state, scheduler state, current epoch and iteration
|
||||
- **Model Parameters**: Supports safetensors format, automatically handles special logic like weight tying during loading
|
||||
- **Dataset Serialization**: HDF5 format supports efficient random access and shared memory, suitable for large-scale pre-training data
|
||||
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata. Optimizer/scheduler state is NOT persisted.
|
||||
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
|
||||
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
|
||||
|
||||
## Summary
|
||||
|
||||
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
||||
|
||||
> Document Update Time: 2026-04-09
|
||||
> Document Update Time: 2026-05-09
|
||||
|
|
|
|||
|
|
@ -50,7 +50,6 @@ classDiagram
|
|||
+str master_port
|
||||
+Callable parallel_wrapper
|
||||
+Callable state_dict_fn
|
||||
+List[int] device_ids
|
||||
+str device_type
|
||||
+dict extra_kwargs
|
||||
+validate()
|
||||
|
|
@ -99,8 +98,8 @@ classDiagram
|
|||
}
|
||||
|
||||
class ResumableDistributedSampler {
|
||||
+int start_epoch
|
||||
+int start_iter
|
||||
+int epoch
|
||||
+int iter
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
|
|
@ -124,7 +123,7 @@ classDiagram
|
|||
namespace model {
|
||||
class AutoModel {
|
||||
+ModelConfig config
|
||||
+Dict _registry
|
||||
+Registry _registry
|
||||
+register(model_type) decorator
|
||||
+get_model_class(model_type) Type
|
||||
+from_pretrained(path, disable_random_init) nn.Module
|
||||
|
|
@ -139,7 +138,7 @@ classDiagram
|
|||
+ModuleList layers
|
||||
+RMSNorm norm
|
||||
+Linear lm_head
|
||||
+forward(input_ids, input_mask, persistent_key_values, start_pos) Dict
|
||||
+forward(input_ids, input_mask, paged_cache, start_pos) Dict
|
||||
+load_state_dict(state_dict)
|
||||
+state_dict()
|
||||
}
|
||||
|
|
@ -149,7 +148,7 @@ classDiagram
|
|||
+RMSNorm input_norm
|
||||
+MLP mlp
|
||||
+RMSNorm post_attention_norm
|
||||
+forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor
|
||||
+forward(x, rotary_emb, attention_mask, paged_cache, start_pos) Tensor
|
||||
}
|
||||
|
||||
class GQA {
|
||||
|
|
@ -158,18 +157,20 @@ classDiagram
|
|||
+int head_dim
|
||||
+Linear q_proj, k_proj, v_proj, o_proj
|
||||
+RMSNorm q_norm, k_norm
|
||||
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
|
||||
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
|
||||
}
|
||||
|
||||
class MLA {
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+int head_dim
|
||||
+Linear q_a_proj, q_b_proj, q_c_proj
|
||||
+Linear kv_a_proj, kv_b_proj, kv_c_proj
|
||||
+int kv_lora_rank
|
||||
+int qk_nope_head_dim
|
||||
+int qk_rope_head_dim
|
||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||
+Linear o_proj
|
||||
+RMSNorm q_norm, k_norm
|
||||
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
|
||||
+RMSNorm kv_norm
|
||||
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
|
||||
}
|
||||
|
||||
class MLP {
|
||||
|
|
@ -204,7 +205,7 @@ classDiagram
|
|||
|
||||
namespace tokenize {
|
||||
class AutoTokenizer {
|
||||
+List[str] stop_ids
|
||||
+List[int] stop_ids
|
||||
+int bos_id
|
||||
+int eos_id
|
||||
+int pad_id
|
||||
|
|
@ -220,7 +221,7 @@ classDiagram
|
|||
|
||||
class ChatTemplate {
|
||||
+String template_str
|
||||
+render(messages, add_generation_prompt) str
|
||||
+render(messages, system_prompt, **extra_variables) str
|
||||
+from_string(template) ChatTemplate
|
||||
}
|
||||
}
|
||||
|
|
@ -267,8 +268,6 @@ classDiagram
|
|||
class TrainContextBuilder {
|
||||
+TrainConfig config
|
||||
+with_checkpoint(checkpoint) TrainContextBuilder
|
||||
+with_dataloader() TrainContextBuilder
|
||||
+with_strategy() TrainContextBuilder
|
||||
+build() TrainContext
|
||||
}
|
||||
|
||||
|
|
@ -454,7 +453,7 @@ classDiagram
|
|||
+float arrival_time
|
||||
+float finish_time
|
||||
+Callable stream_callback
|
||||
+next_pos() int
|
||||
+int next_pos
|
||||
+is_finished(stop_ids) bool
|
||||
}
|
||||
|
||||
|
|
@ -506,15 +505,10 @@ classDiagram
|
|||
+sample(logits, filter_value) Tensor
|
||||
}
|
||||
|
||||
class Server {
|
||||
+start()
|
||||
+predict(request)
|
||||
}
|
||||
|
||||
class _Result {
|
||||
+List[str] tokens
|
||||
+List[str] results
|
||||
+List[bool] done_flags
|
||||
+List[bool] _done
|
||||
+append(token, idx)
|
||||
+get_results() List[str]
|
||||
+pop_all() List[str]
|
||||
|
|
@ -539,9 +533,9 @@ classDiagram
|
|||
}
|
||||
|
||||
namespace parallel {
|
||||
class ParallelSetup {
|
||||
class ParallelFunctions {
|
||||
+spawn_parallel_fn(fn, nprocs)
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||
}
|
||||
|
||||
class ParallelModel {
|
||||
|
|
@ -601,24 +595,19 @@ classDiagram
|
|||
BaseSamplingStrategy <|-- TopKStrategy
|
||||
BaseSamplingStrategy <|-- TopPStrategy
|
||||
SamplingPipeline --> BaseSamplingStrategy : composes
|
||||
Server --> InferenceEngine : uses
|
||||
Server --> ChatMessage : uses
|
||||
Server --> ChatCompletionRequest : uses
|
||||
ParallelSetup --> Trainer : enables
|
||||
BaseDataset <|-- SEQDataset
|
||||
BaseDataset <|-- SFTDataset
|
||||
BaseDataset <|-- DPODataset
|
||||
BaseDataset <|-- GRPODataset
|
||||
DatasetFactory ..> BaseDataset : creates
|
||||
BaseSegmentFetcher --> MultiSegmentFetcher : used by
|
||||
MultiSegmentFetcher --> BaseDataset : used by
|
||||
MultiSegmentFetcher --> BaseSegmentFetcher : uses
|
||||
BaseDataset --> MultiSegmentFetcher : uses
|
||||
AutoModel <|-- Transformer
|
||||
AutoModel --> ModelConfig : contains
|
||||
Transformer --> DecoderBlock : uses
|
||||
Transformer --> RotaryEmbedding : uses
|
||||
Transformer --> Embedding : uses
|
||||
DecoderBlock --> GQA : uses
|
||||
DecoderBlock --> MLA : uses
|
||||
DecoderBlock --> MLP : uses
|
||||
DecoderBlock --> RMSNorm : uses
|
||||
TrainContextBuilder --> ResumableDistributedSampler : creates
|
||||
|
|
@ -647,7 +636,7 @@ classDiagram
|
|||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||
| **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||
|
||||
### Design Patterns
|
||||
|
|
@ -658,7 +647,7 @@ classDiagram
|
|||
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
|
||||
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
|
||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||
| **Singleton** | `TrainContext` | Training process global state management |
|
||||
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
|
||||
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||
|
|
@ -672,8 +661,8 @@ classDiagram
|
|||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
|
||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
|
||||
|
|
@ -717,12 +706,6 @@ $$
|
|||
L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
|
||||
$$
|
||||
|
||||
In this implementation, an off-policy approach is used ($\pi_\theta = \pi_{\text{ref}}$), and the policy loss simplifies to:
|
||||
|
||||
$$
|
||||
L_{\text{policy}} = -\mathbb{E}[A]
|
||||
$$
|
||||
|
||||
The KL divergence term uses mean squared error approximation:
|
||||
|
||||
$$
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
### 1. Model Architecture
|
||||
|
||||
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
|
||||
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
|
||||
|
||||
The model now uses the **AutoModel** base class for flexible loading and saving:
|
||||
|
||||
|
|
@ -48,14 +48,15 @@ flowchart TB
|
|||
S --> T[+]
|
||||
H --> T
|
||||
T --> U[RMSNorm]
|
||||
U --> V[Linear]
|
||||
V --> W[SiLU]
|
||||
V --> X[×]
|
||||
W --> X
|
||||
X --> Y[Linear]
|
||||
Y --> Z[+]
|
||||
T --> Z
|
||||
Z --> AA[x']
|
||||
U --> V["Linear (gate)"]
|
||||
U --> W["Linear (up)"]
|
||||
V --> X[SiLU]
|
||||
X --> Y[×]
|
||||
W --> Y
|
||||
Y --> Z["Linear (down)"]
|
||||
Z --> AA[+]
|
||||
T --> AA
|
||||
AA --> BB[x']
|
||||
end
|
||||
|
||||
classDef main fill:#e6f3ff,stroke:#0066cc;
|
||||
|
|
@ -168,8 +169,6 @@ from astrai.inference import InferenceEngine, GenerationRequest
|
|||
engine = InferenceEngine(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_batch_size=8,
|
||||
max_seq_len=4096,
|
||||
)
|
||||
|
||||
# Use GenerationRequest with messages format
|
||||
|
|
@ -222,12 +221,11 @@ curl -X POST http://localhost:8000/v1/chat/completions \
|
|||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `messages` | List[dict] | Required | Chat messages with role and content |
|
||||
| `temperature` | float | 0.8 | Sampling temperature (0.0-2.0) |
|
||||
| `top_p` | float | 0.95 | Nucleus sampling threshold |
|
||||
| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) |
|
||||
| `top_p` | float | 1.0 | Nucleus sampling threshold |
|
||||
| `top_k` | int | 50 | Top-k sampling parameter |
|
||||
| `max_tokens` | int | 2048 | Maximum tokens to generate |
|
||||
| `max_tokens` | int | 1024 | Maximum tokens to generate |
|
||||
| `stream` | bool | false | Enable streaming response |
|
||||
| `system_prompt` | str | None | System prompt override |
|
||||
|
||||
**Response (non-streaming):**
|
||||
```json
|
||||
|
|
@ -242,7 +240,12 @@ curl -X POST http://localhost:8000/v1/chat/completions \
|
|||
"message": {"role": "assistant", "content": "Hello! I'm doing well..."},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -262,9 +265,6 @@ curl -X POST http://localhost:8000/v1/chat/completions \
|
|||
|
||||
The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
|
||||
|
||||
### Health Check
|
||||
|
||||
|
||||
### Anthropic-Compatible Endpoint
|
||||
|
||||
The server also provides an Anthropic-compatible endpoint at `/v1/messages`:
|
||||
|
|
@ -325,10 +325,10 @@ Monitor server and model status:
|
|||
|
||||
```bash
|
||||
curl http://localhost:8000/health
|
||||
# {"status": "ok", "model_loaded": true, "engine_ready": true}
|
||||
# {"status": "ok", "model_loaded": true}
|
||||
|
||||
curl http://localhost:8000/stats
|
||||
# {"requests_total": 10, "tokens_generated": 5000, ...}
|
||||
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
|
||||
```
|
||||
|
||||
> Document Update Time: 2026-04-09
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
|
||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
||||
| `--data_root_path` | Dataset root directory | required |
|
||||
| `--param_path` | Model parameters or checkpoint path | required |
|
||||
| `--n_epoch` | Total training epochs | 1 |
|
||||
|
|
@ -61,6 +61,10 @@
|
|||
|-----------|-------------|---------|---------|
|
||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
|
||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
|
||||
|
||||
### Usage Example
|
||||
|
||||
|
|
|
|||
|
|
@ -105,7 +105,9 @@ class InferenceScheduler:
|
|||
n_kv_heads = config.n_kv_heads
|
||||
head_dim = config.dim // config.n_heads
|
||||
n_layers = config.n_layers
|
||||
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
|
||||
n_pages = (
|
||||
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||
) // page_size
|
||||
|
||||
self.page_cache = PagedCache(
|
||||
n_layers,
|
||||
|
|
@ -279,6 +281,9 @@ class InferenceScheduler:
|
|||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||
batch_sz = len(tasks)
|
||||
|
||||
for t in tasks:
|
||||
self._maybe_alloc_page(t, start_pos)
|
||||
|
||||
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
||||
for i, t in enumerate(tasks):
|
||||
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class SchedulerCallback(TrainCallback):
|
|||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue