docs: update documentation to match the paged KV cache and other code refactors
parent f81e2b4a73
commit 9d96b0431d
|
|
@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
|
|||
AstrAI adopts a modular design with the following main components:
|
||||
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
|
||||
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
|
||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
||||
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
||||
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
||||
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
||||
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
||||
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
|
||||
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
|
||||
|
||||
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
||||
|
||||
|
|
@ -49,9 +49,9 @@ flowchart LR
|
|||
C3 --> C4[GenerationRequest + apply_chat_template]
|
||||
C4 --> C5[InferenceEngine]
|
||||
C5 --> C6[InferenceScheduler]
|
||||
C6 --> C7[apply_sampling_strategies]
|
||||
C6 --> C7[sample]
|
||||
C7 --> C8[Transformer Forward]
|
||||
C8 --> C9[KV Cache + Prefix Cache]
|
||||
C8 --> C9[Paged KV Cache]
|
||||
C9 --> C10{End Condition?}
|
||||
C10 -->|No| C8
|
||||
C10 -->|Yes| C11[Output Text]
|
||||
|
|
@ -63,27 +63,28 @@ flowchart LR
|
|||
|
||||
## Detailed Module Descriptions
|
||||
|
||||
### 1. Dataset Module
|
||||
### 1. Serialization (`astrai/serialization.py`)
|
||||
|
||||
#### 1.1 Serialization (`serialization.py`)
|
||||
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
|
||||
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
|
||||
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
|
||||
|
||||
#### 1.2 Dataset (`dataset.py`)
|
||||
### 2. Dataset Module
|
||||
|
||||
#### 2.1 Dataset (`dataset.py`)
|
||||
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
|
||||
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
|
||||
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
|
||||
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
|
||||
|
||||
#### 1.3 Sampler (`sampler.py`)
|
||||
#### 2.2 Sampler (`sampler.py`)
|
||||
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
|
||||
- Records current epoch and iteration position, enabling training resume from breakpoints
|
||||
- Supports shuffle and drop_last options
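
The resume behaviour described for `ResumableDistributedSampler` can be illustrated with a minimal single-process sketch. The class name `ResumableSamplerSketch`, its `state_dict`/`load_state_dict` methods, and the per-epoch seeding rule are assumptions made for illustration, not the actual AstrAI API; rank sharding and `drop_last` are left out.

```python
import random
from typing import Iterator, List


class ResumableSamplerSketch:
    """Hypothetical stand-in for a resumable sampler (not AstrAI's class)."""

    def __init__(self, size: int, seed: int = 0, shuffle: bool = True) -> None:
        self.size = size
        self.seed = seed
        self.shuffle = shuffle
        self.epoch = 0   # current epoch
        self.cursor = 0  # number of samples already yielded in this epoch

    def _order(self) -> List[int]:
        indices = list(range(self.size))
        if self.shuffle:
            # Seeding with (seed + epoch) makes each epoch's order reproducible,
            # which is what allows a restart to land on the same permutation.
            random.Random(self.seed + self.epoch).shuffle(indices)
        return indices

    def __iter__(self) -> Iterator[int]:
        order = self._order()
        while self.cursor < self.size:
            idx = order[self.cursor]
            self.cursor += 1
            yield idx
        # Epoch finished: advance the epoch and reset the position.
        self.epoch += 1
        self.cursor = 0

    def state_dict(self) -> dict:
        return {"epoch": self.epoch, "cursor": self.cursor}

    def load_state_dict(self, state: dict) -> None:
        self.epoch = state["epoch"]
        self.cursor = state["cursor"]
```

Saving `state_dict()` alongside a checkpoint and restoring it before iterating would let a restarted run continue from the next unseen index rather than from the start of the epoch.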
|
||||
|
||||
### 2. Model Module
|
||||
### 3. Model Module
|
||||
|
||||
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|
||||
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|
||||
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
|
||||
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
|
||||
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
|
||||
|
|
@ -91,7 +92,7 @@ flowchart LR
|
|||
- Uses Rotary Position Embedding (RoPE) to inject position information
|
||||
- Supports loading from safetensors format with automatic model type detection from `config.json`
|
||||
|
||||
#### 2.2 Submodules (`module.py`)
|
||||
#### 3.2 Submodules (`module.py`)
|
||||
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
|
||||
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
|
||||
- **`GQA`**: Grouped Query Attention implementation
|
||||
|
|
@ -100,19 +101,19 @@ flowchart LR
|
|||
- **`RMSNorm`**: Layer normalization variant
|
||||
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
|
||||
|
||||
### 3. Training Module
|
||||
### 4. Training Module
|
||||
|
||||
#### 3.1 Training Context (`train_context.py`)
|
||||
#### 4.1 Training Context (`train_context.py`)
|
||||
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
|
||||
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
|
||||
|
||||
#### 3.2 Trainer (`trainer.py`)
|
||||
#### 4.2 Trainer (`trainer.py`)
|
||||
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
|
||||
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
|
||||
- Training steps include:
|
||||
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
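
The hook order listed above can be traced with a small sketch that records events instead of training a model. The function `run_training_sketch` and the `accum_steps` parameter are hypothetical names; the sketch also assumes that batches left over at the end of an epoch (when the batch count is not a multiple of `accum_steps`) do not trigger an optimizer update, which may differ from the real `Trainer`.

```python
from typing import List


def run_training_sketch(n_epochs: int, n_batches: int, accum_steps: int) -> List[str]:
    """Return the sequence of hook names a trainer with this ordering would fire."""
    events: List[str] = ["on_train_begin"]
    for _ in range(n_epochs):
        events.append("on_epoch_begin")
        for batch_idx in range(n_batches):
            events.append("on_batch_begin")
            events.append("forward_and_loss")  # forward pass and loss calculation
            events.append("on_batch_end")
            # Gradients accumulate; the optimizer only steps every accum_steps batches.
            if (batch_idx + 1) % accum_steps == 0:
                events.append("on_step_begin")
                events.append("optimizer_update")
                events.append("on_step_end")
        events.append("on_epoch_end")
    return events


if __name__ == "__main__":
    for name in run_training_sketch(n_epochs=1, n_batches=4, accum_steps=2):
        print(name)
```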
|
||||
|
||||
#### 3.3 Strategy (`strategy.py`)
|
||||
#### 4.3 Strategy (`strategy.py`)
|
||||
- **`BaseStrategy`**: Defines training strategy interface
|
||||
- **`SEQStrategy`**: Standard next-token prediction training
|
||||
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
|
||||
|
|
@ -121,14 +122,14 @@ flowchart LR
|
|||
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
|
||||
- Created dynamically by `StrategyFactory` according to configuration
|
||||
|
||||
#### 3.4 Scheduler (`schedule.py`)
|
||||
#### 4.4 Scheduler (`schedule.py`)
|
||||
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
|
||||
- **`CosineScheduler`**: Cosine decay scheduler with warmup
|
||||
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
|
||||
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
|
||||
- Scheduler is automatically created according to configuration and bound to optimizer
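
As an illustration of what "cosine decay with warmup" usually means, here is one common formulation written as a learning-rate multiplier. The function name, the linear warmup shape, and the `min_ratio` floor are assumptions; the diff does not show how `CosineScheduler` is actually parameterised.

```python
import math


def cosine_with_warmup(step: int, warmup_steps: int, total_steps: int,
                       min_ratio: float = 0.0) -> float:
    """Return a multiplier in [min_ratio, 1] to apply to the base learning rate."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup: ramps from 1/warmup_steps up to 1.
        return (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_ratio + (1.0 - min_ratio) * cosine
```

A multiplier of this form could be passed to `torch.optim.lr_scheduler.LambdaLR` as its `lr_lambda`, which is one way a scheduler can be bound to an optimizer.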
|
||||
|
||||
#### 3.5 Callbacks (`train_callback.py`)
|
||||
#### 4.5 Callbacks (`train_callback.py`)
|
||||
- **`TrainCallback`**: Protocol interface for trainer callbacks
|
||||
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
|
||||
- **`ProgressBarCallback`**: Displays training progress
|
||||
|
|
@ -136,17 +137,21 @@ flowchart LR
|
|||
- **`GradientClippingCallback`**: Clips gradient norms
|
||||
- **`SchedulerCallback`**: Steps learning rate scheduler
|
||||
|
||||
### 4. Factory Module
|
||||
#### 4.6 Metric Utility (`metric_util.py`)
|
||||
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
|
||||
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
|
||||
|
||||
#### 4.1 Registry and BaseFactory (`factory.py`)
|
||||
### 5. Factory Module
|
||||
|
||||
#### 5.1 Registry and BaseFactory (`factory.py`)
|
||||
- **`Registry`**: Flexible registry for component classes with category and priority support
|
||||
- **`BaseFactory`**: Generic factory class for component registration and creation
|
||||
- Supports decorator-based registration pattern for extensible components
|
||||
- Provides methods for registration, retrieval, and listing with filtering
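
A decorator-based registry with category and priority support could look roughly like the sketch below. `RegistrySketch`, `list_names`, and the "higher priority first" ordering are illustrative assumptions rather than the real `Registry`/`BaseFactory` interface.

```python
from typing import Callable, Dict, List, Optional, Tuple, Type


class RegistrySketch:
    """Hypothetical registry: maps a name to (class, category, priority)."""

    def __init__(self) -> None:
        self._entries: Dict[str, Tuple[Type, Optional[str], int]] = {}

    def register(self, name: str, category: Optional[str] = None,
                 priority: int = 0) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            self._entries[name] = (cls, category, priority)
            return cls  # the class itself is returned unchanged
        return decorator

    def get(self, name: str) -> Type:
        return self._entries[name][0]

    def list_names(self, category: Optional[str] = None) -> List[str]:
        # Filter by category, then order by descending priority (ties by name).
        items = [(prio, name) for name, (_, cat, prio) in self._entries.items()
                 if category is None or cat == category]
        return [name for _, name in sorted(items, key=lambda x: (-x[0], x[1]))]


registry = RegistrySketch()


@registry.register("cosine", category="scheduler", priority=10)
class CosineSketch:
    pass


@registry.register("sgdr", category="scheduler")
class SGDRSketch:
    pass
```

With this shape, a factory's `create(name, ...)` reduces to `registry.get(name)(...)`, and new components become available simply by importing the module that defines them.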
|
||||
|
||||
### 5. Parallel Module
|
||||
### 6. Parallel Module
|
||||
|
||||
#### 5.1 Setup (`setup.py`)
|
||||
#### 6.1 Setup (`setup.py`)
|
||||
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
|
||||
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
|
||||
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
|
||||
|
|
@ -154,47 +159,51 @@ flowchart LR
|
|||
- **`get_world_size`**: Returns total number of processes in distributed group
|
||||
- **`get_current_device`**: Returns current device from environment
|
||||
|
||||
#### 5.2 Parallel Layers (`module.py`)
|
||||
#### 6.2 Parallel Layers (`module.py`)
|
||||
- **`ParallelModel`**: Base class for parallel models with process group
|
||||
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
|
||||
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
|
||||
|
||||
### 6. Inference Module
|
||||
### 7. Inference Module
|
||||
|
||||
#### 6.1 Inference Engine (`engine.py`)
|
||||
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
|
||||
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
|
||||
#### 7.1 Inference Engine (`engine.py`)
|
||||
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
|
||||
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
|
||||
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
|
||||
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
|
||||
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
|
||||
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
|
||||
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
|
||||
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
|
||||
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
|
||||
- Uses separate model and tokenizer initialization for flexibility
|
||||
|
||||
#### 6.2 Scheduler (`scheduler.py`)
|
||||
#### 7.2 Cache (`cache.py`)
|
||||
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
|
||||
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
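
The bitmask allocation and page-table indirection mentioned for `PagedCache` can be sketched without any tensors. `PagePoolSketch` and `locate` are hypothetical names; only the method names `alloc`, `alloc_n`, and `free` mirror the class diagram in this commit, and reference counting (`_refs`) is omitted. Note that Python integers are arbitrary precision, so the constant-time claim strictly holds only while the mask fits in a machine word.

```python
from typing import List, Tuple


class PagePoolSketch:
    """Tracks free pages in an integer bitmask: bit i set means page i is free."""

    def __init__(self, n_pages: int) -> None:
        self.n_pages = n_pages
        self._free_mask = (1 << n_pages) - 1  # all pages start free

    def alloc(self) -> int:
        if self._free_mask == 0:
            raise MemoryError("no free pages")
        lowest = self._free_mask & -self._free_mask  # isolate the lowest set bit
        self._free_mask ^= lowest                    # mark that page as used
        return lowest.bit_length() - 1               # bit position = page index

    def alloc_n(self, n: int) -> List[int]:
        if bin(self._free_mask).count("1") < n:
            raise MemoryError("not enough free pages")
        return [self.alloc() for _ in range(n)]

    def free(self, idx: int) -> None:
        self._free_mask |= 1 << idx


def locate(page_table: List[int], page_size: int, pos: int) -> Tuple[int, int]:
    """Map a logical token position to (physical page, offset within the page)."""
    return page_table[pos // page_size], pos % page_size
```

Because a sequence only ever addresses its cache through its page table, its pages do not need to be contiguous, and a finished sequence can hand them back individually.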
|
||||
|
||||
#### 7.3 Scheduler (`scheduler.py`)
|
||||
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
||||
- **`TaskStatus`**: Task state enumeration
|
||||
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
|
||||
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
|
||||
- **`_RadixNode`**: Tree node structure for prefix caching
|
||||
- Continuous batching: new requests can join at any time, completed requests are released immediately
|
||||
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
|
||||
- Uses `PagedCache` for paged KV cache management with page table indirection
|
||||
- Continuous batching: new requests can join at any time, completed requests release pages immediately
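
To make the idea of a composable sampling pipeline concrete, the sketch below chains temperature, top-k, and top-p as plain functions over a list of logits. It uses only the standard library and does not reproduce the signatures of `sample` or `SamplingPipeline` in `sampling.py`; all names here are illustrative assumptions. Ties at the top-k cutoff are all kept in this version.

```python
import math
import random
from typing import Callable, List

Logits = List[float]
NEG_INF = float("-inf")


def temperature(t: float) -> Callable[[Logits], Logits]:
    # Dividing by t < 1 sharpens the distribution, t > 1 flattens it.
    return lambda logits: [x / t for x in logits]


def top_k(k: int) -> Callable[[Logits], Logits]:
    def apply(logits: Logits) -> Logits:
        if k <= 0 or k >= len(logits):
            return list(logits)
        cutoff = sorted(logits, reverse=True)[k - 1]
        return [x if x >= cutoff else NEG_INF for x in logits]
    return apply


def softmax(logits: Logits) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # masked entries become 0.0
    total = sum(exps)
    return [e / total for e in exps]


def top_p(p: float) -> Callable[[Logits], Logits]:
    def apply(logits: Logits) -> Logits:
        probs = softmax(logits)
        order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
        keep, cumulative = set(), 0.0
        for i in order:
            keep.add(i)
            cumulative += probs[i]
            if cumulative >= p:  # smallest prefix whose mass reaches p
                break
        return [x if i in keep else NEG_INF for i, x in enumerate(logits)]
    return apply


def sample_sketch(logits: Logits, steps: List[Callable[[Logits], Logits]],
                  rng: random.Random) -> int:
    for step in steps:  # each step is an independent logit transform
        logits = step(logits)
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Keeping each transform independent means a request that sets only `temperature` can simply run a shorter list of steps.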
|
||||
|
||||
#### 6.3 Server (`server.py`)
|
||||
#### 7.4 Server (`server.py`)
|
||||
- FastAPI-based HTTP inference server
|
||||
- OpenAI-compatible `/v1/chat/completions` endpoint
|
||||
- Health check and statistics endpoints
|
||||
- Supports both streaming and non-streaming responses
|
||||
|
||||
### 7. Tokenizer Module
|
||||
### 8. Tokenizer Module
|
||||
|
||||
#### 7.1 Tokenizer (`tokenizer.py`)
|
||||
#### 8.1 Tokenizer (`tokenizer.py`)
|
||||
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
|
||||
- **`AutoTokenizer`**: Auto-loading tokenizer class
|
||||
- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
|
||||
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
|
||||
- Uses `AutoTokenizer` for loading pre-trained tokenizers
|
||||
|
||||
#### 7.2 Chat Template (`chat_template.py`)
|
||||
#### 8.2 Chat Template (`chat_template.py`)
|
||||
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
|
||||
- Handles multi-role message formatting (system, user, assistant)
|
||||
- Supports dynamic prompts and generation prompts
|
||||
|
|
@ -244,13 +253,14 @@ flowchart LR
|
|||
- For batch generation, use `pad_sequence` for padding
|
||||
|
||||
3. **Autoregressive Generation Loop**
|
||||
- Initialize KV cache (optional) and prefix cache
|
||||
- Loop until generating `max_len` tokens or encountering stop token:
|
||||
- Input current `input_ids` (or cached new token) to model, obtain `logits`
|
||||
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
|
||||
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
|
||||
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
|
||||
- Decode phase: loops until generating `max_len` tokens or encountering stop token:
|
||||
- Input last token ID to model, obtain `logits`
|
||||
- Apply `sample()` (temperature, top-k, top-p) to `logits`
|
||||
- Sample next token ID from the processed distribution
|
||||
- Append new token to `input_ids`, while updating KV cache
|
||||
- For streaming generation, yield each token to caller immediately
|
||||
- Write new KV entries into paged cache; allocate additional pages as needed
|
||||
- For streaming generation, yield each token to caller immediately via `stream_callback`
|
||||
|
||||
4. **Decoding and Output**
|
||||
- Decode generated token ID sequence to text through tokenizer
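
The prefill and decode phases of step 3 can be sketched as a loop that tracks only page accounting. Everything here is a simplification under stated assumptions: `step_fn` is a hypothetical stand-in for "model forward plus sampling", `free_pages` is a plain list instead of a `PagedCache`, and no KV tensors are written.

```python
from typing import Callable, List, Set, Tuple


def pages_needed(n_tokens: int, page_size: int) -> int:
    return -(-n_tokens // page_size)  # ceiling division


def generate_sketch(prompt_ids: List[int],
                    step_fn: Callable[[List[int]], int],
                    page_size: int,
                    max_len: int,
                    stop_ids: Set[int],
                    free_pages: List[int]) -> Tuple[List[int], int]:
    # Allocate enough pages to hold the prompt's KV entries.
    page_table = [free_pages.pop(0)
                  for _ in range(pages_needed(len(prompt_ids), page_size))]
    cached_len = len(prompt_ids)  # KV entries written during prefill
    output: List[int] = []
    token = step_fn(prompt_ids)   # prefill: full prompt in, first token out
    while True:
        output.append(token)
        if token in stop_ids or len(output) >= max_len:
            break
        # The new token's KV entry needs a slot; grow the page table when full.
        if cached_len == len(page_table) * page_size:
            page_table.append(free_pages.pop(0))
        cached_len += 1
        token = step_fn([token])  # decode: only the last token is fed
    n_pages = len(page_table)
    free_pages.extend(page_table)  # release pages as soon as the task ends
    return output, n_pages
```

In a streaming setup, the point where `output.append(token)` runs is where a per-token callback would be invoked.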
|
||||
|
|
@ -264,6 +274,6 @@ flowchart LR
|
|||
|
||||
## Summary
|
||||
|
||||
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
||||
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
||||
|
||||
> Document Update Time: 2026-04-09
|
||||
|
|
@ -109,7 +109,9 @@ classDiagram
|
|||
+create(train_type, window_size, stride) BaseDataset
|
||||
+load(train_type, load_path, window_size, stride) BaseDataset
|
||||
}
|
||||
}
|
||||
|
||||
namespace serialization {
|
||||
class Checkpoint {
|
||||
+dict state_dict
|
||||
+int epoch
|
||||
|
|
@ -390,10 +392,9 @@ classDiagram
|
|||
+InferenceScheduler scheduler
|
||||
+int max_batch_size
|
||||
+Optional int max_seq_len
|
||||
+int max_prompt_len
|
||||
+int cache_capacity
|
||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
||||
+get_stats() Dict
|
||||
+shutdown()
|
||||
}
|
||||
|
|
@ -401,10 +402,11 @@ classDiagram
|
|||
class InferenceScheduler {
|
||||
+nn.Module model
|
||||
+AutoTokenizer tokenizer
|
||||
+Tuple kv_cache
|
||||
+Tensor seq_mask
|
||||
+PrefixCacheManager prefix_cache
|
||||
+SlotAllocator slot_allocator
|
||||
+PagedCache page_cache
|
||||
+int max_batch_size
|
||||
+int max_seq_len
|
||||
+int max_prompt_len
|
||||
+int page_size
|
||||
+List waiting_queue
|
||||
+List active_tasks
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
|
|
@ -414,23 +416,26 @@ classDiagram
|
|||
+get_stats() Dict
|
||||
}
|
||||
|
||||
class PrefixCacheManager {
|
||||
+_RadixNode root
|
||||
+int max_capacity
|
||||
+OrderedDict _lru
|
||||
+insert(token_ids, slot, slot_ver)
|
||||
+find(token_ids) Tuple[int, int, int]
|
||||
+pin(token_ids)
|
||||
+release(token_ids)
|
||||
+copy_kv(token_ids, target_slot, kv_cache, n_layers)
|
||||
class PagedCache {
|
||||
+int page_size
|
||||
+int _free_mask
|
||||
+List[int] _refs
|
||||
+Tensor k_cache
|
||||
+Tensor v_cache
|
||||
+alloc() int
|
||||
+alloc_n(n) List[int]
|
||||
+free(idx)
|
||||
+bind(page_table, total_len) CacheView
|
||||
+write(layer_id, page_table, start_pos, k, v)
|
||||
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
|
||||
}
|
||||
|
||||
class _RadixNode {
|
||||
+Dict children
|
||||
+int slot
|
||||
+int slot_ver
|
||||
+int ref_count
|
||||
+float last_access
|
||||
class CacheView {
|
||||
+PagedCache _cache
|
||||
+Tensor _page_table
|
||||
+int _total_len
|
||||
+write(layer_id, start_pos, k, v)
|
||||
+gather(layer_id) Tuple[Tensor, Tensor]
|
||||
}
|
||||
|
||||
class Task {
|
||||
|
|
@ -444,11 +449,12 @@ classDiagram
|
|||
+List output_ids
|
||||
+int input_tokens
|
||||
+int output_tokens
|
||||
+int slot
|
||||
+int prefix_len
|
||||
+List[int] page_table
|
||||
+int n_pages
|
||||
+float arrival_time
|
||||
+float finish_time
|
||||
+Callable stream_callback
|
||||
+next_pos() int
|
||||
+is_finished(stop_ids) bool
|
||||
}
|
||||
|
||||
|
|
@ -474,17 +480,6 @@ classDiagram
|
|||
+int max_tokens
|
||||
}
|
||||
|
||||
class SlotAllocator {
|
||||
+int _max_slots
|
||||
+int _free_mask
|
||||
+List _versions
|
||||
+alloc() int
|
||||
+free(idx)
|
||||
+occupy(idx)
|
||||
+is_free(idx) bool
|
||||
+version(idx) int
|
||||
}
|
||||
|
||||
class BaseSamplingStrategy {
|
||||
<<abstract>>
|
||||
+apply(logits, filter_value) Tensor
|
||||
|
|
@ -508,6 +503,7 @@ classDiagram
|
|||
class SamplingPipeline {
|
||||
+List strategies
|
||||
+apply(logits, filter_value) Tensor
|
||||
+sample(logits, filter_value) Tensor
|
||||
}
|
||||
|
||||
class Server {
|
||||
|
|
@ -521,6 +517,8 @@ classDiagram
|
|||
+List[bool] done_flags
|
||||
+append(token, idx)
|
||||
+get_results() List[str]
|
||||
+pop_all() List[str]
|
||||
+wait(timeout) bool
|
||||
}
|
||||
|
||||
class ChatMessage {
|
||||
|
|
@ -535,15 +533,8 @@ classDiagram
|
|||
+int top_k
|
||||
+int max_tokens
|
||||
+bool stream
|
||||
+Optional[str] system_prompt
|
||||
}
|
||||
|
||||
class CompletionResponse {
|
||||
+str id
|
||||
+str object
|
||||
+int created
|
||||
+str model
|
||||
+List[Dict] choices
|
||||
+Optional[str] stop
|
||||
+Optional[int] n
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -583,6 +574,7 @@ classDiagram
|
|||
Trainer --> TrainContextBuilder : builds
|
||||
Trainer --> TrainCallback : manages
|
||||
TrainContextBuilder --> TrainContext : creates
|
||||
Checkpoint ..> Checkpoint : saves/loads
|
||||
TrainContext --> Checkpoint : manages
|
||||
TrainContext --> BaseStrategy : uses
|
||||
TrainContext --> BaseScheduler : uses
|
||||
|
|
@ -601,7 +593,7 @@ classDiagram
|
|||
InferenceScheduler --> Task : manages
|
||||
Task --> TaskStatus : uses
|
||||
InferenceScheduler --> TaskStatus : uses
|
||||
InferenceScheduler --> SlotAllocator : uses
|
||||
InferenceScheduler --> PagedCache : uses
|
||||
InferenceScheduler --> Transformer : uses
|
||||
InferenceEngine --> Transformer : uses
|
||||
InferenceEngine --> _Result : uses
|
||||
|
|
@ -612,7 +604,6 @@ classDiagram
|
|||
Server --> InferenceEngine : uses
|
||||
Server --> ChatMessage : uses
|
||||
Server --> ChatCompletionRequest : uses
|
||||
Server --> CompletionResponse : uses
|
||||
ParallelSetup --> Trainer : enables
|
||||
BaseDataset <|-- SEQDataset
|
||||
BaseDataset <|-- SFTDataset
|
||||
|
|
@ -635,9 +626,6 @@ classDiagram
|
|||
ParallelModel <|-- RowParallelLinear
|
||||
ParallelModel <|-- ColumnParallelLinear
|
||||
AutoTokenizer --> ChatTemplate : uses
|
||||
InferenceScheduler --> PrefixCacheManager : uses
|
||||
PrefixCacheManager --> _RadixNode : composes
|
||||
Checkpoint ..> Checkpoint : saves/loads
|
||||
TrainConfig --> DatasetFactory : selects
|
||||
TrainConfig --> SchedulerFactory : selects
|
||||
TrainConfig --> CallbackFactory : selects
|
||||
|
|
@ -653,11 +641,12 @@ classDiagram
|
|||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
|
||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
|
||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||
|
||||
|
|
@ -671,7 +660,7 @@ classDiagram
|
|||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||
| **Singleton** | `TrainContext` | Training process global state management |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||
| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
|
||||
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
|
||||
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||
|
|
@ -683,7 +672,7 @@ classDiagram
|
|||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||
|
|
|
|||