From 3d12a03909c6dedc6de112a4f53e3ecd1d1a2068 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 15 May 2026 23:13:03 +0800 Subject: [PATCH] =?UTF-8?q?docs=20:=20=E6=8B=86=E5=88=86=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E5=B9=B6=E8=A1=A5=E5=85=85=E7=B1=BB=E5=9B=BE=E7=BC=BA=E5=A4=B1?= =?UTF-8?q?=E7=B1=BB=E5=92=8C=E5=85=B3=E7=B3=BB=E7=BA=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 design.md 拆分为 architecture.md / inference.md / training.md - 精简 dataflow.md 为纯数据管道 - 删除 design.md 和 introduction.md - 更新 README.md 和 README-zh-CN.md 链接 - 补充 ChatMessage / AnthropicMessage 等 6 条孤立类关系线 - 补充 BaseModelConfig 和 TaskManager 两个缺失类 --- README.md | 7 +- assets/docs/README-zh-CN.md | 7 +- assets/docs/{design.md => architecture.md} | 146 ++++----- assets/docs/dataflow.md | 246 ++------------- assets/docs/inference.md | 140 +++++++++ assets/docs/introduction.md | 334 --------------------- assets/docs/training.md | 199 ++++++++++++ 7 files changed, 438 insertions(+), 641 deletions(-) rename assets/docs/{design.md => architecture.md} (77%) create mode 100644 assets/docs/inference.md delete mode 100644 assets/docs/introduction.md create mode 100644 assets/docs/training.md diff --git a/README.md b/README.md index 4c58d72..a268082 100644 --- a/README.md +++ b/README.md @@ -208,9 +208,10 @@ Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYH | Document | Description | |----------|-------------| | [Parameter Guide](./assets/docs/params.md) | Training & inference parameters | -| [Design Document](./assets/docs/design.md) | Framework architecture & module design | -| [Data Flow](./assets/docs/dataflow.md) | Data processing pipeline details | -| [Model Introduction](./assets/docs/introduction.md) | Model architecture & technical details | +| [Architecture](./assets/docs/architecture.md) | System architecture, class diagram & design patterns | +| [Training](./assets/docs/training.md) | Training loop, strategies & formulas | +| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API | +| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture | ### Contributing diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index 1b87880..13cf22b 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -214,9 +214,10 @@ python scripts/demo/generate_ar.py | 文档 | 说明 | |------|------| | [参数说明](./params.md) | 训练与推理参数配置 | -| [设计文档](./design.md) | 系统架构与模块设计 | -| [数据流程](./dataflow.md) | 数据处理管道详解 | -| [模型介绍](./introduction.md) | 模型架构与技术细节 | +| [架构文档](./architecture.md) | 系统架构、类图与设计模式 | +| [训练文档](./training.md) | 训练循环、策略与公式 | +| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API | +| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 | ### 贡献 diff --git a/assets/docs/design.md b/assets/docs/architecture.md similarity index 77% rename from assets/docs/design.md rename to assets/docs/architecture.md index b9fc068..805f7c4 100644 --- a/assets/docs/design.md +++ b/assets/docs/architecture.md @@ -1,14 +1,16 @@ -## 1. Why I Created This Project +# AstrAI Architecture -There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible for ordinary developers. I thought: **Can we create a model that is both useful and can run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that achieves complete privatization while maintaining some level of intelligence. - -Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, and the training code is open source! - -## 2. System Architecture +## Class Diagram ```mermaid classDiagram namespace config { + class BaseModelConfig { + +Optional[str] model_type + +load(config_path) Self + +save(config_path) + } + class ModelConfig { +int vocab_size +int dim @@ -565,6 +567,19 @@ classDiagram ABORTED } + class TaskManager { + +AutoTokenizer tokenizer + +Deque waiting_queue + +List active_tasks + +add_task(prompt, **kwargs) str + +remove_task(task_id) List[Task] + +remove_finished_tasks(stop_ids) List[Task] + +pull_candidates(n) List[Task] + +activate(task) + +return_to_waiting(tasks) + +get_active_tasks() List[Task] + } + class GenerationRequest { +List[Dict] messages +int top_k @@ -736,6 +751,7 @@ classDiagram ParallelModel <|-- RowParallelLinear ParallelModel <|-- ColumnParallelLinear AutoModel <|-- Transformer + BaseModelConfig <|-- ModelConfig BaseFactory <|-- AutoModel BaseFactory <|-- AttnFactory BaseFactory <|-- FFNFactory @@ -763,6 +779,8 @@ classDiagram Transformer *-- Embedding DecoderBlock *-- RMSNorm BaseDataset *-- BaseStorage + ChatCompletionRequest *-- ChatMessage + MessagesRequest *-- AnthropicMessage %% --- Aggregation (weak ownership) --- AutoModel o-- ModelConfig @@ -795,6 +813,10 @@ classDiagram KVCache ..> KvcacheView : binds InferenceEngine ..> GenerationRequest : uses InferenceEngine ..> GenerateResult : creates + OpenAIHandler ..> ChatCompletionRequest : receives + AnthropicHandler ..> MessagesRequest : receives + ProtocolHandler ..> StopChecker : creates + ProtocolHandler ..> StreamContext : creates %% --- Association (general usage) --- Trainer --> TrainConfig @@ -809,99 +831,51 @@ classDiagram TaskManager --> AutoTokenizer MultiSegmentFetcher --> BaseSegmentFetcher ResumableDistributedSampler --> BaseDataset + ``` -### Module Overview + +## Module Overview | Module | Components | Description | |--------|------------|-------------| | **astrai.config** | ModelConfig, TrainConfig | Configuration management | -| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5, save_json, load_json, create_storage, detect_format | Dataset loading and management | -| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management | +| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | +| **astrai.serialization** | Checkpoint | Model serialization | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | -| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | -| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample, ChatMessage, ChatCompletionRequest, AnthropicMessage, MessagesRequest, OpenAIHandler, AnthropicHandler, ProtocolHandler, StreamContext, StopChecker, app, run_server | Inference service with continuous batching and paged KV cache | -| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, only_on_rank, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel | -| **astrai.factory** | Registry, BaseFactory[T] | Generic component registration with decorator pattern | +| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–MetricLoggerCallback, CallbackFactory | Training workflow | +| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service | +| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel | +| **astrai.factory** | Registry, BaseFactory[T] | Component registration | -### Design Patterns +## Design Patterns | Pattern | Classes | Purpose | |---------|---------|---------| -| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO | -| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components | -| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks | -| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, gradient clipping, metrics) | -| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint | -| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | -| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction | -| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p | -| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | -| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | -| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | -| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation | -| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | `handle()` template with stream/non-stream branches, protocol-specific format hooks | -| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage`, `_STORAGE_REGISTRY` | Format-agnostic data access with registry-dispatch (HDF5 / JSON) | +| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation | +| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | +| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | +| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | +| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks | +| **Builder** | `TrainContextBuilder` | Chain-building training context | +| **Observer** | `TrainCallback`, callback implementations | Training process monitoring | +| **Context** | `TrainContext` | Unified training state bag | +| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | +| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access | +| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | +| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading | -### Core Relationships +## Core Relationships -1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references -2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss -3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` -4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming -5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer` -6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 and JSON loading via `BaseStorage` (`H5Storage` / `JSONStorage`) with `BaseSegmentFetcher` and `MultiSegmentFetcher` -7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors -8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler) -9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration - -## 3. Training Process - -The common training process for large language models (LLM) typically includes three stages: **Pre-training (SEQ)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (DPO/GRPO)**. This system is designed to support seamless end-to-end flow, achieving efficient switching and state management of different training stages through modular strategies. - -### Core Formulas - -**Pre-training (SEQ):** - -$$ -L_{\text{PT}} = - \sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta) -$$ - -**SFT:** - -$$ -L_{\text{SFT}} = - \sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta) -$$ - -**DPO:** - -$$ -L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right] -$$ - -**GRPO:** - -GRPO (Group Relative Policy Optimization) computes advantages from multiple responses to the same prompt, then optimizes using a PPO-style clipped objective: - -$$ -\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon} -$$ - -Where $r_i$ is the reward for the $i$-th response, $\mu$ and $\sigma$ are the mean and standard deviation of group rewards. - -$$ -L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL} -$$ - -The KL divergence term uses mean squared error approximation: - -$$ -L_{KL} = \lambda \cdot \mathbb{E} \left[ (\log \pi_\theta - \log \pi_{\text{ref}})^2 \right] -$$ - -The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$ - -Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence. +1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn +2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss +3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` +4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, backed by `KVCache` + `SamplingPipeline` +5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP +6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` +7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only) +8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` +9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops > Document Update Time: 2026-05-15 diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index b5b2f18..781e569 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -1,241 +1,57 @@ -# AstrAI Data Flow Documentation +# Data Flow -This document describes the data flow of the AstrAI project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference. +This document describes the data pipeline: from raw text to model input tensors. ## Overview -AstrAI adopts a modular design with the following main components: -- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, storage backends -- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules -- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities -- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation -- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig -- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration -- **Parallel Module** (`astrai/parallel/`): Distributed training support -- **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors - -## Data Flow Diagram - -```mermaid -flowchart LR - subgraph A[Data Preparation] - direction TB - A1[Raw Text] --> A2[AutoTokenizer] - A2 --> A3[Tokenized .h5 files] - A3 --> A4[BaseDataset] - A4 --> A5[ResumableDistributedSampler] - A5 --> A6[DataLoader] - end - - subgraph B[Training] - direction TB - B1[DataLoader] --> B2[BaseStrategy] - B2 --> B3[Transformer Forward] - B3 --> B4[Loss + Backward] - B4 --> B5[Gradient Accumulation] - B5 -->|every accum_steps| B6[Optimizer Step] - B6 --> B7[LR Scheduler] - B7 -->|next batch| B2 - B6 --> B8[CheckpointCallback] - end - - subgraph C[Inference] - direction TB - C1[Checkpoint] --> C2[AutoModel] - C1 --> C3[AutoTokenizer] - C2 --> C4[InferenceEngine] - C3 --> C4 - C4 --> C5[InferenceScheduler] - C5 --> C6[Transformer Forward] - C6 --> C7[sample] - C7 --> C8{End?} - C8 -->|No| C6 - C8 -->|Yes| C9[Generated Text] - end - - A --> B - B --> C +``` +Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference ``` -## Detailed Module Descriptions +## Data Preparation -### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`) +Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups. -- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors -- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory -- **`Checkpoint`**: Encapsulates model state dict + epoch + iteration; uses safetensors - -### 2. Dataset Module - -#### 2.1 Dataset (`dataset.py`) -- **`BaseDataset`**: Abstract base class for windowed sequence sampling -- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range -- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`) -- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen"/"rejected"` (DPO), `"prompts"/"responses"/"masks"/"rewards"` (GRPO) -- Storage backends: HDF5 (`.h5`) or JSON (`.json`/`.jsonl`), auto-detected by `detect_format()` - -#### 2.2 Sampler (`sampler.py`) -- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last - -### 3. Model Module - -#### 3.1 Transformer / AutoModel -- **`AutoModel`**: Base class with `from_pretrained()` / `save_pretrained()` -- **`Transformer`**: Decoder-only architecture, registered via `@AutoModel.register('transformer')` -- Embedding → N×DecoderBlock → RMSNorm → Linear lm_head -- RoPE position encoding, optional weight tying - -#### 3.2 Submodules (`module.py`) -- **`DecoderBlock`**: Pre-LN (norm→attention→residual, norm→MLP→residual), uses `AttnFactory` / `FFNFactory` -- **`GQA`**: Grouped Query Attention (q_heads ÷ kv_heads = n_rep) -- **`MLA`**: Multi-head Latent Attention with KV compression (kv_lora_rank) -- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection -- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis) -- **`RMSNorm`**: Layer normalization - -### 4. Training Module - -#### 4.1 Training Context (`train_context.py`) -- **`TrainContext`**: Dataclass holding model, optimizer, dataloader, strategy, scheduler, checkpoint state -- **`TrainContextBuilder`**: Builder pattern — takes checkpoint for resume, builds all components - -#### 4.2 Trainer (`trainer.py`) - -The training loop is nested: **epoch** → **batch** (with step phase interspersed): +Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: ``` -on_train_begin - on_epoch_begin - for each accumulation window of batches: ← step phase - on_step_begin - for each batch in window: ← batch phase - on_batch_begin → strategy(batch) → loss - (loss / window_size).backward() → on_batch_end - iteration += 1 - on_step_end - optimizer.step() → zero_grad - scheduler.step() ← per step, not per batch - - on_epoch_end -on_train_end +create_storage("h5") → H5Storage +create_storage("json") → JSONStorage ``` -Key points: -- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook -- `on_batch_*` fires every batch, wrapping loss computation -- `GradientClippingCallback` fires on `on_step_end` -- LR scheduler steps inline (no `SchedulerCallback` class), once per optimizer step +Both support shared memory via `.share_memory_()`. -#### 4.3 Strategy (`strategy.py`) -- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing -- **`SFTStrategy`**: Supervised fine-tuning with loss masking -- **`DPOStrategy`**: Direct Preference Optimization with reference model -- **`GRPOStrategy`**: Group Relative Policy Optimization with clipped ratio +## Data Keys by Training Type -#### 4.4 Scheduler (`schedule.py`) -- **`CosineScheduler`**: Cosine decay + linear warmup -- **`SGDRScheduler`**: Cosine annealing with warm restarts -- Created by `SchedulerFactory` and bound to optimizer +| Type | Storage Keys | +|------|-------------| +| `seq` | `sequence` (→ input_ids, target_ids via offset-by-1) | +| `sft` | `sequence`, `loss_mask` | +| `dpo` | `chosen`, `rejected`, `chosen_mask`, `rejected_mask` | +| `grpo` | `prompts`, `responses`, `masks`, `rewards` | -#### 4.5 Callbacks -- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations -- **`ProgressBarCallback`**: tqdm progress display -- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/` -- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end` - -### 5. Inference Module - -#### 5.1 Inference Engine (`engine.py`) -- **`InferenceEngine`**: Facade over scheduler; provides `generate()`, `generate_with_request()`, `generate_async()` -- Accepts `prompt: str | List[str]`, returns generator (stream) or string (non-stream) - -#### 5.2 Scheduler 4-Phase Loop (`scheduler.py`) - -Background thread runs continuously: +## Dataset Architecture ``` -1. Cleanup → Remove finished tasks, free KV cache pages -2. Refill → Pop from waiting_queue, alloc pages, add to active -3. Prefill → Group active tasks by prompt_len, run full forward pass -4. Decode → Pick largest same-position group, run single-token forward +DatasetFactory.load(train_type, path, window_size, stride) + → create_storage(detect_format(path)) + → MultiSegmentFetcher(BaseSegmentFetcher per key) + → BaseDataset.__getitem__(idx) + → sliding window [begin, end) via get_index(idx) ``` -- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED) -- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache -- **`KvcacheView`**: Batch view bundling cache + page table for attention layers -- **`sample()`**: Temperature → top-k → top-p → multinomial +`window_size` = max input length, `stride` = step between consecutive samples. -#### 5.3 Server (`server.py`) -- FastAPI with OpenAI `/v1/chat/completions` and Anthropic `/v1/messages` endpoints -- Streaming via SSE, health check at `/health`, stats at `/stats` +## Sampler -### 6. Tokenizer Module +`ResumableDistributedSampler` supports checkpoint-aware distributed sampling: -- **`AutoTokenizer`**: Wraps HuggingFace `tokenizers.Tokenizer` (not `transformers`); `encode`/`decode`/`apply_chat_template` -- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat +- Tracks `start_epoch` / `start_iter` for resume +- Shuffle via `torch.Generator(seed + epoch)` +- Per-replica index slicing for DDP -### 7. Factory & Parallel +## DataLoader -- **`Registry` / `BaseFactory`**: Decorator-based component registration -- **`spawn_parallel_fn`**: Multi-process DDP launcher with NCCL backend -- **`ParallelModel` / `ColumnParallelLinear` / `RowParallelLinear`**: Tensor model parallelism - -## Training Data Flow — Detailed Steps - -1. **Data Preparation** - - Raw text → token IDs via `AutoTokenizer.encode()` - - Save as `.h5` files (groups of tensor lists per data key) - -2. **Dataset Loading** - - `BaseDataset.load()` calls `load_h5()`, builds `MultiSegmentFetcher` - - Sliding window of `window_size` with `stride` determines sample boundaries - -3. **Sampling & Batching** - - `ResumableDistributedSampler` produces shuffled index sequences - - `DataLoader` fetches `[batch_size, window_size]` tensors via `__getitem__` - -4. **Strategy Forward** - - Strategy receives batch, calls `Transformer.forward()` for logits - - Computes task-specific loss (cross-entropy, DPO, GRPO) - -5. **Backward & Accumulation** - - `stand_loss = loss / step_batch_nums` (divide by actual batch count in this window) - - `stand_loss.backward()` accumulates gradients - - Every `accumulation_steps` batches: `optimizer.step()` → `zero_grad()` - - Every optimizer step: `scheduler.step()` updates learning rate - -6. **Checkpoint** - - `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations - - Does NOT save optimizer/scheduler state by default; `Checkpoint.extra` or `save_extra_fn` can store arbitrary additional data - -## Inference Data Flow — Detailed Steps - -1. **Model Loading** - - `AutoModel.from_pretrained(path)` loads weights from safetensors - - `torch.inference_mode()` wraps generation - -2. **Prompt Construction** - - Messages → `apply_chat_template(messages, tokenize=False)` → prompt string - - `tokenizer.encode(prompt)` → token IDs (truncated to `max_prompt_len`) - -3. **Continuous Batching Loop** - - **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages - - **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages - - **Prefill**: Group by prompt length, run full forward with `start_pos=0` - - **Decode**: Pick position group with most tasks, single-token forward: - - Model forward → `logits` → `sample()` → next token ID - - Append to `output_ids`, update `output_tokens` - - `PagePool.task_alloc()` allocates pages as needed - - `stream_callback(token)` for streaming clients - -4. **Output** - - `tokenizer.decode(output_ids)` → text - - Return to caller (streaming: token-by-token; non-streaming: complete string) - -## Checkpoint & Serialization - -- **Training Checkpoint**: safetensors weights + epoch/iteration metadata + optional extras. Optimizer/scheduler state is NOT persisted by default but can be stored via `extra`. -- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format. -- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data. +Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. > Document Update Time: 2026-05-15 diff --git a/assets/docs/inference.md b/assets/docs/inference.md new file mode 100644 index 0000000..24e1a05 --- /dev/null +++ b/assets/docs/inference.md @@ -0,0 +1,140 @@ +# Inference + +## KV Cache + +At decode time, only the last query token matters. All previous K/V are cached to avoid recomputation: + +$$ +o_n = \sum_j \text{softmax}\left(\frac{q_n k_j}{\sqrt{d_k}}\right) v_j +$$ + +RoPE is applied **before** KV cache write, not after — otherwise position encoding drift occurs. + +## KVCache System + +Six classes working together: + +``` +KVCache (facade) + ├── Allocator bitmask-based page allocator + ref-count + LRU eviction + ├── PrefixCache hash-based prefix matching (page_hash via rolling hash) + ├── PagePool orchestrates Allocator + PrefixCache + ├── TaskTable maps task_id → page_table + cached token count + ├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim) + └── KvcacheView bundles Storage + page_table + total_len for attention layers +``` + +`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`. + +## Continuous Batching + +`InferenceScheduler` runs a daemon thread with a 4-phase loop: + +``` +1. Cleanup → Remove finished tasks, free KV pages +2. Refill → Pop from waiting_queue, task_alloc pages, activate +3. Prefill → Group by (prompt_len, start_pos), run full forward +4. Decode → Pick largest same-position group, single-token forward +``` + +## Sampling (Strategy Pattern) + +``` +BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy +``` + +`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. +`sample()` is a convenience shortcut for one-shot usage. + +## Protocol Handlers (Template Method) + +```python +class ProtocolHandler(ABC): + def handle(self): + ctx = StreamContext(...) + agen = engine.generate_async(prompt, ...) + if stream: self._handle_stream(agen, ctx) + else: self._handle_non_stream(agen, ctx) +``` + +Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`. + +`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`. + +## Engine & GenerateResult + +``` +InferenceEngine + ├── generate(prompt, stream, ...) → str | List[str] | Generator + ├── generate_with_request(req) → same + └── generate_async(prompt, ...) → AsyncGenerator +``` + +`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`. + +## HTTP API + +``` +POST /v1/chat/completions OpenAI +POST /v1/messages Anthropic +GET /health {"status":"ok","model_loaded":true} +GET /stats scheduler statistics +``` + +### OpenAI + +```bash +curl -X POST http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}' +``` + +Response: +```json +{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15} +} +``` + +Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]` + +### Anthropic + +```bash +curl -X POST http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}' +``` + +Supports `stop_sequences` and streaming via `event: content_block_delta`. + +### GenerationRequest Parameters + +| Param | Type | Default | Description | +|-------|------|---------|-------------| +| `messages` | List[dict] | required | Chat messages (role, content) | +| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) | +| `top_p` | float | 1.0 | Nucleus threshold | +| `top_k` | int | 50 | Top-k count | +| `max_tokens` | int | None | Max generation length | +| `stream` | bool | False | Stream output | + +## Engine API + +```python +# Non-streaming +engine.generate("Hello", stream=False) # -> str +engine.generate(["A", "B"], stream=False) # -> List[str] + +# Streaming +engine.generate("Hello", stream=True) # -> Generator[str] +engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]] + +# Async +await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] +``` + +> Document Update Time: 2026-05-15 diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md deleted file mode 100644 index 82a570a..0000000 --- a/assets/docs/introduction.md +++ /dev/null @@ -1,334 +0,0 @@ -## Model Introduction - -### 1. Model Architecture - -This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token. - -The model now uses the **AutoModel** base class for flexible loading and saving: - -```python -from astrai.model import AutoModel - -# Load model from checkpoint -model = AutoModel.from_pretrained("path/to/model") - -# Save model to new directory -model.save_pretrained("path/to/save") -``` - -The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types. - -```mermaid -flowchart TB - subgraph Layers["Transformer Layers"] - direction TB - A[Input Embedding] --> B[Transformer Block\nLayer 1] - B --> C[Transformer Block\nLayer ...] - C --> D[Transformer Block\nLayer ...] - D --> E[RMSNorm] - E --> F[Linear] - F --> G[SoftMax] - end - - subgraph TransformerBlock["Transformer Block"] - direction TB - H[x] --> I[RMSNorm] - I --> J[Linear → Q/K/V] - J --> K[Q] - J --> L[K] - J --> M[V] - K --> N[RoPE] - L --> O[RoPE] - N --> P["Q @ K^T / sqrt(d)"] - O --> P - P --> Q[Masked SoftMax] - Q --> R[S @ V] - M --> R - R --> S[Linear] - S --> T[+] - H --> T - T --> U[RMSNorm] - U --> V["Linear (gate)"] - U --> W["Linear (up)"] - V --> X[SiLU] - X --> Y[×] - W --> Y - Y --> Z["Linear (down)"] - Z --> AA[+] - T --> AA - AA --> BB[x'] - end - - classDef main fill:#e6f3ff,stroke:#0066cc; - classDef block fill:#fff2e6,stroke:#cc6600; - class Layers main; - class TransformerBlock block; -``` - -What is an autoregressive model? After splitting a sentence into tokens, the model predicts the probability distribution of the next token. This means the model calculates the probability of the next possible token and its corresponding probability based on the given context (the sequence of tokens that have already appeared). - -#### 1. Autoregression - -In autoregressive modeling, when a sentence is tokenized into a sequence of tokens, the model learns to predict what comes next. Given a sequence of tokens as input, the model calculates a probability distribution over all possible next tokens. This distribution tells us how likely each potential next token is, given the current context. - -For instance, if the input sequence contains tokens representing a question, the model might predict that certain response tokens have higher probabilities than others. The sampling process then selects one token from this distribution—controlled by parameters like top_k, top_p, and temperature—to serve as the next token in the sequence. - -Once a token is selected, it is appended to the input sequence, and the model repeats this process. The updated sequence is then fed back into the model to predict the next token. This iterative process continues until either a special end-of-sequence token is generated, or the maximum sequence length is reached. These control tokens are essential because without them, the model would continue generating tokens indefinitely, eventually exhausting available memory. - -#### 2. Causal Mask - -Transformers use attention mechanism. The input shape is generally [bsz, seq_len], and the output is [bsz, seq_len, n_dim]. To predict the next token, the model's input and output must be offset by one position. The target predicted by the model must be offset by one position, and during training we also use the offset-by-one method: - -``` -sequence : [[1, 2, 3, 4, 5, 6]] -input_ids: [[1, 2, 3, 4, 5]] -target_ids: [[2, 3, 4, 5, 6]] -``` - -The attention score calculation formula is: - -$$ s_{ij} = softmax(\frac{q_i^Tk_j}{\sqrt{d_k}}) $$ -$$ s_{ij} := s_{ij} + mask_{ij} $$ - -Here, the attention score represents the degree to which the model attends to the similarity between two tokens. - -For decoder-only structure models, to prevent the model from "stealing" information from future positions, a mask needs to be added during attention calculation. We need to apply a mask before attention score calculation. This mask is typically a lower triangular matrix, and for a sequence of length n, its shape is [n, n]. Below is an example of how to create such a causal mask matrix for a sequence of length 5: - -``` -[[0, -inf, -inf, -inf, -inf], - [0, 0, -inf, -inf, -inf], - [0, 0, 0, -inf, -inf], - [0, 0, 0, 0, -inf], - [0, 0, 0, 0, 0]] -``` - -In this matrix, 0 represents positions that can be attended to, while -inf represents positions that should be masked (i.e., should not be attended to). Because this matrix ensures that after the softmax, the parts of the attention scores where $j > i$ change from `inf` to 0, meaning the model cannot see future information. - -#### 3. Rotary Position Embedding - -Rotary Position Embedding (RoPE) is a position encoding method designed to solve the problem of lacking direct modeling of sequence position information in Transformer models. Unlike traditional position encodings (such as sine and cosine function position encodings), RoPE embeds position information directly into the Query (Q) and Key (K) vectors, allowing the model to more naturally handle relative position relationships in sequences. - -$$ q_i = R_i W_q x_i $$ -$$ k_j = R_j W_k x_j $$ -$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$ - -The $R_{i-j}$ controls the attenuation of attention for different tokens at different relative distances. When the absolute value of $i - j$ is larger, the degree of attenuation is stronger. This approach allows the model to learn relative position relationships, enabling the model to scale and adapt to longer sequences. - -## KV Cache Implementation - -According to the attention calculation formula: - -$$ -\begin{align*} -o_i &= \sum_j s_{ij} v_{j} \newline -s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right) -\end{align*} -$$ - -Since the model is an autoregressive model, we only need to calculate for the last part of the sequence, meaning the index $i$ is fixed as the last element of the sequence, and we compute $o_{n}$: - -$$ -\begin{align*} -o_n &= \sum_j s_{j}v_{j} \newline -s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right) -\end{align*} -$$ - -If we expand the expression: - -$$ -o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j} -$$ - -In the above expression, only k and v have length indices, while $q$ does not. Therefore, during the calculation process, the input of $q$ is fixed as the last token from the previous input, while $k$ and $v$ need to be cached for parts of different lengths. Also, when caching, note that position encoding calculation should be performed before KV cache computation, otherwise there will be position encoding calculation errors. - -### 4. AutoModel Loading - -The project now uses the **AutoModel** base class for flexible model loading and saving: - -```python -from astrai.model import AutoModel - -# Load model from checkpoint -model = AutoModel.from_pretrained("path/to/model") - -# Save model to new directory -model.save_pretrained("path/to/save") -``` - -The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types. The `from_pretrained` method automatically loads the `config.json` to determine the model type and uses safetensors format for weights. - -### 5. Continuous Batching Inference - -The inference engine supports **continuous batching** for efficient batch processing: - -```python -from astrai.inference import InferenceEngine, GenerationRequest - -# Create inference engine with continuous batching -engine = InferenceEngine( - model=model, - tokenizer=tokenizer, -) - -# Use GenerationRequest with messages format -request = GenerationRequest( - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}, - ], - temperature=0.8, - top_p=0.95, - top_k=50, - max_tokens=None, - stream=True, -) - -# Generate with streaming -for token in engine.generate_with_request(request): - print(token, end="", flush=True) -``` - -The continuous batching feature allows dynamic batch composition where new requests can join at any time and completed requests are released immediately. - -## HTTP API Usage - -The inference server provides HTTP endpoints for remote inference. Start the server first: - -```bash -python -m scripts.tools.server --port 8000 -``` - -### OpenAI-Compatible Endpoint - -The server provides an OpenAI-compatible chat completion endpoint at `/v1/chat/completions`: - -```bash -curl -X POST http://localhost:8000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} - ], - "temperature": 0.8, - "max_tokens": 2048, - "stream": false - }' -``` - -**Request Parameters:** -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `messages` | List[dict] | Required | Chat messages with role and content | -| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) | -| `top_p` | float | 1.0 | Nucleus sampling threshold | -| `top_k` | int | 50 | Top-k sampling parameter | -| `max_tokens` | int | 2048 | Maximum tokens to generate | -| `stream` | bool | false | Enable streaming response | - -**Response (non-streaming):** -```json -{ - "id": "chatcmpl-1234567890", - "object": "chat.completion", - "created": 1234567890, - "model": "astrai", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Hello! I'm doing well..."}, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 20, - "completion_tokens": 15, - "total_tokens": 35 - } -} -``` - -### Streaming Response - -Enable streaming for real-time token-by-token output: - -```bash -curl -X POST http://localhost:8000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "messages": [{"role": "user", "content": "Write a story"}], - "stream": true, - "max_tokens": 500 - }' -``` - -The server uses Server-Sent Events (SSE) with content type `text/event-stream`. - -### Anthropic-Compatible Endpoint - -The server also provides an Anthropic-compatible endpoint at `/v1/messages`: - -```bash -curl -X POST http://localhost:8000/v1/messages \ - -H "Content-Type: application/json" \ - -d '{ - "model": "astrai", - "system": "You are a helpful assistant.", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "max_tokens": 2048 - }' -``` - -Response: -```json -{ - "id": "msg_abc123...", - "type": "message", - "role": "assistant", - "model": "astrai", - "content": [{"type": "text", "text": "Hello! I am doing well..."}], - "stop_reason": "end_turn", - "stop_sequence": null, - "usage": {"input_tokens": 20, "output_tokens": 15} -} -``` - -Streaming: -```bash -curl -X POST http://localhost:8000/v1/messages \ - -H "Content-Type: application/json" \ - -d '{ - "model": "astrai", - "system": "You are a helpful assistant.", - "messages": [{"role": "user", "content": "Write a short poem"}], - "max_tokens": 500, - "stream": true - }' -``` - -Supports `stop_sequences` for early termination: -```bash -curl -X POST http://localhost:8000/v1/messages \ - -H "Content-Type: application/json" \ - -d '{ - "model": "astrai", - "messages": [{"role": "user", "content": "Write a story"}], - "max_tokens": 500, - "stop_sequences": ["The end", "THE END"] - }' -``` - -### Health Check - -Monitor server and model status: - -```bash -curl http://localhost:8000/health -# {"status": "ok", "model_loaded": true} - -curl http://localhost:8000/stats -# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0} -``` - -> Document Update Time: 2026-05-15 \ No newline at end of file diff --git a/assets/docs/training.md b/assets/docs/training.md new file mode 100644 index 0000000..d17736d --- /dev/null +++ b/assets/docs/training.md @@ -0,0 +1,199 @@ +# Training + +## Model Architecture + +The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). 1.0 billion parameters, Chinese–English bilingual. + +```mermaid +flowchart TB + subgraph Layers["Transformer Layers"] + direction TB + A[Input Embedding] --> B[Transformer Block\nLayer 1] + B --> C[Transformer Block\nLayer ...] + C --> D[Transformer Block\nLayer ...] + D --> E[RMSNorm] + E --> F[Linear] + F --> G[SoftMax] + end + + subgraph TransformerBlock["Transformer Block"] + direction TB + H[x] --> I[RMSNorm] + I --> J[Linear → Q/K/V] + J --> K[Q]; J --> L[K]; J --> M[V] + K --> N[RoPE]; L --> O[RoPE] + N --> P["Q @ K^T / sqrt(d)"]; O --> P + P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R + R --> S[Linear]; S --> T[+]; H --> T + T --> U[RMSNorm] + U --> V["Linear (gate)"]; U --> W["Linear (up)"] + V --> X[SiLU]; X --> Y[×]; W --> Y + Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA + AA --> BB[x'] + end +``` + +### Autoregression + +Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length. + +### Causal Mask + +``` +sequence : [[1, 2, 3, 4, 5, 6]] +input_ids: [[1, 2, 3, 4, 5]] +target_ids: [[2, 3, 4, 5, 6]] +``` + +Lower-triangular mask prevents attending to future positions: + +``` +[[0, -inf, -inf, -inf, -inf], + [0, 0, -inf, -inf, -inf], + [0, 0, 0, -inf, -inf], + [0, 0, 0, 0, -inf], + [0, 0, 0, 0, 0]] +``` + +### Rotary Position Embedding (RoPE) + +RoPE embeds position into Q/K vectors via complex rotation: + +$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_{i-j} W_k x_j $$ + +The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` multiplies Q/K as complex numbers. + +## Training Loop + +Nested loop: **epoch** → **step** (accumulation window) → **batch**. + +``` +on_train_begin + on_epoch_begin + for steps in batched(dataloader, accumulation_steps): + on_step_begin + step_batch_nums = len(steps) + for batch in steps: + on_batch_begin + loss = strategy(batch) + (loss / step_batch_nums).backward() + iteration += 1 + on_batch_end + on_step_end + optimizer.step() + optimizer.zero_grad() + scheduler.step() + on_epoch_end +on_train_end +``` + +### Callback Lifecycle + +| Hook | Fires | Default callback | +|------|-------|-----------------| +| `on_step_end` | Every accumulation window | `GradientClippingCallback` | +| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | +| `on_train_end` | Training ends | `CheckpointCallback` (final save) | + +Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`. + +## Strategies + +### SEQ (Pre-training) + +Next-token cross-entropy with optional label smoothing: + +$$ +L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta) +$$ + +Keys: `input_ids`, `target_ids` + +### SFT (Supervised Fine-Tuning) + +Masked cross-entropy (`ignore_index=-100`) over response tokens: + +$$ +L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta) +$$ + +Keys: `input_ids`, `target_ids`, `loss_mask` + +### DPO (Direct Preference Optimization) + +Frozen reference model, preference margin via log-ratio: + +$$ +L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right] +$$ + +Parameters: `beta=0.1`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`. + +### GRPO (Group Relative Policy Optimization) + +On-policy PPO with group-normalized advantages: + +$$ +\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon} +$$ + +$$ +L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right] +$$ + +Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`. + +Keys: `prompts`, `responses`, `masks`, `rewards`. + +## LR Schedulers + +| Type | Class | Description | +|------|-------|-------------| +| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` | +| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) | + +Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. + +## Checkpoint + +``` +Checkpoint(state_dict, epoch, iteration, extra) + ├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt + └── load(save_dir) broadcasts metadata from rank-0 +``` + +Optimizer/scheduler state NOT persisted by default; `Checkpoint.extra` can store arbitrary data. + +## TrainContextBuilder (Builder Pattern) + +```python +context = ( + TrainContextBuilder(config) + .with_checkpoint(checkpoint) + .build() +) +# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint +``` + +- Loads checkpoint weights if provided +- Wraps model with `parallel_wrapper` if `nprocs > 1` +- Creates `ResumableDistributedSampler` for shuffle+resume +- Builds strategy via `StrategyFactory.create(train_type, ...)` + +## Training CLI + +```bash +python scripts/tools/train.py \ + --train_type seq \ + --data_root_path /path/to/data \ + --param_path /path/to/model \ + --batch_size 4 \ + --accumulation_steps 8 \ + --max_lr 3e-4 \ + --warmup_steps 1000 \ + --n_epoch 1 +``` + +Full parameter reference at [params.md](params.md). + +> Document Update Time: 2026-05-15