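docs: fix class names, method signatures, and module attributions that did not match the source code
- CONTRIBUTING.md: run the ruff/pytest commands through conda
- params.md: max_len → max_tokens
- introduction.md: max_len=1024 → max_tokens=None
- dataflow.md: PagedCache/CacheView → KVCache/KvcacheView
- design.md: overhaul the class diagram (PagedCache → Allocator and the other new classes, six in total; remove the erroneous position_ids parameter; fix the BaseDataset fields and 25+ relationship lines; update the Module Overview)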
Parent: 205b40bd28
Commit: d8da2cf17c
|
|
@ -36,10 +36,10 @@ If you encounter a bug or have a feature request, please open an issue on GitHub
|
||||||
|
|
||||||
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
|
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
|
||||||
|
|
||||||
- Run Ruff to format and lint:
|
- Run Ruff to format and lint (requires conda environment `nlp`):
|
||||||
```bash
|
```bash
|
||||||
ruff format .
|
conda run -n nlp ruff format .
|
||||||
ruff check --fix .
|
conda run -n nlp ruff check --fix .
|
||||||
```
|
```
|
||||||
- The project uses **double quotes** for strings and **4‑space indentation** (as configured in `pyproject.toml`).
|
- The project uses **double quotes** for strings and **4‑space indentation** (as configured in `pyproject.toml`).
|
||||||
|
|
||||||
|
|
@ -49,7 +49,7 @@ If you add or modify functionality, please include appropriate tests.
|
||||||
|
|
||||||
- Run the test suite with:
|
- Run the test suite with:
|
||||||
```bash
|
```bash
|
||||||
pytest
|
conda run -n nlp python -u -m pytest
|
||||||
```
|
```
|
||||||
- Ensure all tests pass before submitting your PR.
|
- Ensure all tests pass before submitting your PR.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -156,9 +156,9 @@ Background thread runs continuously:
|
||||||
4. Decode → Pick largest same-position group, run single-token forward
|
4. Decode → Pick largest same-position group, run single-token forward
|
||||||
```
|
```
|
||||||
|
|
||||||
- **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED)
|
- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED)
|
||||||
- **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write
|
- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache
|
||||||
- **`CacheView`**: Batch view bundling cache + page table for attention layers
|
- **`KvcacheView`**: Batch view bundling cache + page table for attention layers
|
||||||
- **`sample()`**: Temperature → top-k → top-p → multinomial
|
- **`sample()`**: Temperature → top-k → top-p → multinomial
|
||||||
|
|
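The `sample()` order listed above (temperature → top-k → top-p → multinomial) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not AstrAI's implementation: the name `sample_token`, its signature, and the use of Python lists instead of tensors are all hypothetical.

```python
import math
import random
from typing import List, Optional


def sample_token(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical sketch: temperature -> top-k -> top-p -> multinomial."""
    rng = rng or random.Random()

    # 1. Temperature: divide the logits; smaller values sharpen the distribution.
    scaled = [x / max(temperature, 1e-6) for x in logits]

    # 2. Top-k: keep the k highest-scoring token ids (0 disables the filter).
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Softmax over the surviving candidates (max subtracted for stability).
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]

    # 3. Top-p: keep the smallest prefix whose cumulative mass reaches top_p.
    kept = []
    cumulative = 0.0
    for token_id, p in zip(order, probs):
        kept.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:
            break

    # 4. Multinomial: draw one token id in proportion to the kept mass.
    ids = [token_id for token_id, _ in kept]
    mass = [p for _, p in kept]
    return rng.choices(ids, weights=mass, k=1)[0]
```

With `top_k=1` the sketch reduces to greedy decoding, which is a convenient sanity check for any real implementation of the same pipeline.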
||||||
#### 5.3 Server (`server.py`)
|
#### 5.3 Server (`server.py`)
|
||||||
|
|
@ -216,12 +216,12 @@ Background thread runs continuously:
|
||||||
|
|
||||||
3. **Continuous Batching Loop**
|
3. **Continuous Batching Loop**
|
||||||
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
|
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
|
||||||
- **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages
|
- **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages
|
||||||
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
|
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
|
||||||
- **Decode**: Pick position group with most tasks, single-token forward:
|
- **Decode**: Pick position group with most tasks, single-token forward:
|
||||||
- Model forward → `logits` → `sample()` → next token ID
|
- Model forward → `logits` → `sample()` → next token ID
|
||||||
- Append to `output_ids`, update `output_tokens`
|
- Append to `output_ids`, update `output_tokens`
|
||||||
- `_maybe_alloc_page()` grows page table as needed
|
- `PagePool.task_alloc()` allocates pages as needed
|
||||||
- `stream_callback(token)` for streaming clients
|
- `stream_callback(token)` for streaming clients
|
||||||
|
|
||||||
4. **Output**
|
4. **Output**
|
||||||
|
|
|
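The decode step described above ("pick position group with most tasks") can be sketched as follows. This is a hedged illustration only: the `Task` class here is a hypothetical stand-in with just the fields the example needs, and treating "position" as prompt length plus generated length is an assumption, not something the documentation states.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Task:
    """Hypothetical stand-in for a scheduler task."""

    task_id: str
    prompt_ids: List[int]
    output_ids: List[int] = field(default_factory=list)

    @property
    def position(self) -> int:
        # Assumed definition: number of tokens already held for this task.
        return len(self.prompt_ids) + len(self.output_ids)


def pick_decode_group(active: List[Task]) -> List[Task]:
    """Return the largest set of active tasks that share the same position."""
    groups: Dict[int, List[Task]] = defaultdict(list)
    for task in active:
        groups[task.position].append(task)
    if not groups:
        return []
    # max() returns the first maximal group, so earlier tasks win ties.
    return max(groups.values(), key=len)
```

Grouping by position lets every task in the batch share one `start_pos` for the single-token forward pass; tasks outside the chosen group simply wait for a later iteration of the loop.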
||||||
|
|
@ -61,8 +61,8 @@ classDiagram
|
||||||
class BaseDataset {
|
class BaseDataset {
|
||||||
+int window_size
|
+int window_size
|
||||||
+int stride
|
+int stride
|
||||||
+MultiSegmentFetcher fetcher
|
+BaseStorage storage
|
||||||
+load(load_path)
|
+load(load_path, storage_type, tokenizer)
|
||||||
+__getitem__(index)
|
+__getitem__(index)
|
||||||
+__len__()
|
+__len__()
|
||||||
}
|
}
|
||||||
|
|
@ -90,6 +90,26 @@ classDiagram
|
||||||
+fetch_data(begin_idx, end_idx) Tensor
|
+fetch_data(begin_idx, end_idx) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BaseStorage {
|
||||||
|
+Dict segments
|
||||||
|
+List keys
|
||||||
|
+load(load_path, tokenizer)
|
||||||
|
+fetch(begin, end, keys)
|
||||||
|
+__len__()
|
||||||
|
}
|
||||||
|
|
||||||
|
class H5Storage {
|
||||||
|
+load(load_path, tokenizer)
|
||||||
|
+fetch(begin, end, keys) Dict
|
||||||
|
+keys() List
|
||||||
|
}
|
||||||
|
|
||||||
|
class JSONStorage {
|
||||||
|
+load(load_path, tokenizer)
|
||||||
|
+fetch(begin, end, keys) Dict
|
||||||
|
+keys() List
|
||||||
|
}
|
||||||
|
|
||||||
class MultiSegmentFetcher {
|
class MultiSegmentFetcher {
|
||||||
+Dict multi_fetchers
|
+Dict multi_fetchers
|
||||||
+List multi_keys
|
+List multi_keys
|
||||||
|
|
@ -148,7 +168,7 @@ classDiagram
|
||||||
+RMSNorm input_norm
|
+RMSNorm input_norm
|
||||||
+MLP mlp
|
+MLP mlp
|
||||||
+RMSNorm post_attention_norm
|
+RMSNorm post_attention_norm
|
||||||
+forward(x, rotary_emb, attention_mask, position_ids, paged_cache) Tensor
|
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class GQA {
|
class GQA {
|
||||||
|
|
@ -157,7 +177,7 @@ classDiagram
|
||||||
+int head_dim
|
+int head_dim
|
||||||
+Linear q_proj, k_proj, v_proj, o_proj
|
+Linear q_proj, k_proj, v_proj, o_proj
|
||||||
+RMSNorm q_norm, k_norm
|
+RMSNorm q_norm, k_norm
|
||||||
+forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
|
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class MLA {
|
class MLA {
|
||||||
|
|
@ -170,7 +190,7 @@ classDiagram
|
||||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||||
+Linear o_proj
|
+Linear o_proj
|
||||||
+RMSNorm kv_norm
|
+RMSNorm kv_norm
|
||||||
+forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
|
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class MLP {
|
class MLP {
|
||||||
|
|
@ -401,7 +421,7 @@ classDiagram
|
||||||
class InferenceScheduler {
|
class InferenceScheduler {
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+PagedCache page_cache
|
+KVCache page_cache
|
||||||
+int max_batch_size
|
+int max_batch_size
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+int max_prompt_len
|
+int max_prompt_len
|
||||||
|
|
@ -415,25 +435,77 @@ classDiagram
|
||||||
+get_stats() Dict
|
+get_stats() Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class PagedCache {
|
class Allocator {
|
||||||
|
+int _free_mask
|
||||||
|
+int refs_count
|
||||||
|
+LRU _lru
|
||||||
|
+alloc() int
|
||||||
|
+free(idx, keep_cached)
|
||||||
|
+inc_ref(idx)
|
||||||
|
+touch(idx)
|
||||||
|
+ref_count(idx) int
|
||||||
|
}
|
||||||
|
|
||||||
|
class PrefixCache {
|
||||||
|
+int _page_size
|
||||||
|
+evict(page_idx)
|
||||||
|
+has_page(idx) bool
|
||||||
|
+lookup(token_ids) List[int]
|
||||||
|
+record(page_idx, token_ids, logical_page_idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
class PagePool {
|
||||||
|
-Allocator _alloc
|
||||||
|
-PrefixCache _prefix
|
||||||
|
+alloc() int
|
||||||
|
+free(idx)
|
||||||
|
+inc_ref(idx)
|
||||||
|
+lookup(token_ids) List[int]
|
||||||
|
+record(page_idx, token_ids, logical_page_idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
class Storage {
|
||||||
|
+int n_layers
|
||||||
+int page_size
|
+int page_size
|
||||||
|
+int head_dim
|
||||||
|
+int n_kv_heads
|
||||||
+Tensor k_cache
|
+Tensor k_cache
|
||||||
+Tensor v_cache
|
+Tensor v_cache
|
||||||
+alloc_n(n) List[int]
|
+write(layer_id, page_table, start_pos, k, v)
|
||||||
+free(idx)
|
|
||||||
+bind(page_table, total_len) CacheView
|
|
||||||
+write(layer_id, page_table, position_ids, k, v)
|
|
||||||
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
|
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
|
||||||
}
|
}
|
||||||
|
|
||||||
class CacheView {
|
class KVCache {
|
||||||
+PagedCache _cache
|
-PagePool _pool
|
||||||
|
-Storage _storage
|
||||||
|
-TaskTable _table
|
||||||
|
+int page_size
|
||||||
|
+task_alloc(task_id, prompt_ids) bool
|
||||||
|
+task_free(task_id)
|
||||||
|
+task_extend(task_id, pos) bool
|
||||||
|
+task_cached(task_id) int
|
||||||
|
+task_record_hashes(task_id, prompt_ids, start_logical_page)
|
||||||
|
+make_table_tensor(task_ids, device) Tensor
|
||||||
|
+bind(page_table, total_len) KvcacheView
|
||||||
|
}
|
||||||
|
|
||||||
|
class KvcacheView {
|
||||||
|
-Storage _storage
|
||||||
+Tensor _page_table
|
+Tensor _page_table
|
||||||
+int _total_len
|
+int _total_len
|
||||||
+write(layer_id, position_ids, k, v)
|
+write(layer_id, k, v)
|
||||||
+gather(layer_id) Tuple[Tensor, Tensor]
|
+gather(layer_id) Tuple[Tensor, Tensor]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TaskTable {
|
||||||
|
+set(task_id, page_table, cached)
|
||||||
|
+get(task_id) List[int]
|
||||||
|
+get_cached(task_id) int
|
||||||
|
+get_ref(task_id) List[int]
|
||||||
|
+pop(task_id) Tuple[List[int], int]
|
||||||
|
+table_tensor(task_ids, device) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
class Task {
|
class Task {
|
||||||
+str task_id
|
+str task_id
|
||||||
+List prompt_ids
|
+List prompt_ids
|
||||||
|
|
@ -445,8 +517,6 @@ classDiagram
|
||||||
+List output_ids
|
+List output_ids
|
||||||
+int input_tokens
|
+int input_tokens
|
||||||
+int output_tokens
|
+int output_tokens
|
||||||
+List[int] page_table
|
|
||||||
+int n_pages
|
|
||||||
+float arrival_time
|
+float arrival_time
|
||||||
+float finish_time
|
+float finish_time
|
||||||
+Callable stream_callback
|
+Callable stream_callback
|
||||||
|
|
@ -464,16 +534,11 @@ classDiagram
|
||||||
|
|
||||||
class GenerationRequest {
|
class GenerationRequest {
|
||||||
+List[Dict] messages
|
+List[Dict] messages
|
||||||
+GenerationParams params
|
|
||||||
+bool stream
|
|
||||||
}
|
|
||||||
|
|
||||||
class GenerationParams {
|
|
||||||
<<value object>>
|
|
||||||
+int top_k
|
+int top_k
|
||||||
+float top_p
|
+float top_p
|
||||||
+float temperature
|
+float temperature
|
||||||
+int max_tokens
|
+Optional[int] max_tokens
|
||||||
|
+bool stream
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseSamplingStrategy {
|
class BaseSamplingStrategy {
|
||||||
|
|
@ -531,9 +596,12 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class ParallelFunctions {
|
class Functions {
|
||||||
+spawn_parallel_fn(fn, nprocs)
|
+spawn_parallel_fn(fn, nprocs)
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||||
|
+get_current_device() str
|
||||||
|
+get_world_size() int
|
||||||
|
+get_rank() int
|
||||||
}
|
}
|
||||||
|
|
||||||
class ParallelModel {
|
class ParallelModel {
|
||||||
|
|
@ -552,9 +620,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
%% Relationships
|
%% Relationships
|
||||||
TrainConfig --> ModelConfig : uses
|
|
||||||
TrainConfig --> BaseDataset : uses
|
TrainConfig --> BaseDataset : uses
|
||||||
TrainConfig --> StrategyFactory : selects
|
TrainConfig ..> BaseStrategy : selects
|
||||||
StrategyFactory ..> BaseStrategy : creates
|
StrategyFactory ..> BaseStrategy : creates
|
||||||
BaseStrategy <|-- SEQStrategy
|
BaseStrategy <|-- SEQStrategy
|
||||||
BaseStrategy <|-- SFTStrategy
|
BaseStrategy <|-- SFTStrategy
|
||||||
|
|
@ -562,11 +629,12 @@ classDiagram
|
||||||
BaseStrategy <|-- GRPOStrategy
|
BaseStrategy <|-- GRPOStrategy
|
||||||
DPOStrategy --> Transformer : uses
|
DPOStrategy --> Transformer : uses
|
||||||
GRPOStrategy --> Transformer : uses
|
GRPOStrategy --> Transformer : uses
|
||||||
Trainer --> TrainConfig : configures
|
Trainer --> TrainConfig : uses
|
||||||
Trainer --> TrainContextBuilder : builds
|
Trainer --> TrainContextBuilder : uses
|
||||||
Trainer --> TrainCallback : manages
|
Trainer --> TrainCallback : manages
|
||||||
TrainContextBuilder --> TrainContext : creates
|
TrainContextBuilder --> TrainContext : creates
|
||||||
Checkpoint ..> Checkpoint : saves/loads
|
TrainContextBuilder --> StrategyFactory : uses
|
||||||
|
Checkpoint ..> Checkpoint : serializes
|
||||||
TrainContext --> Checkpoint : manages
|
TrainContext --> Checkpoint : manages
|
||||||
TrainContext --> BaseStrategy : uses
|
TrainContext --> BaseStrategy : uses
|
||||||
TrainContext --> BaseScheduler : uses
|
TrainContext --> BaseScheduler : uses
|
||||||
|
|
@ -579,16 +647,21 @@ classDiagram
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
|
PagePool --> Allocator : composes
|
||||||
|
PagePool --> PrefixCache : composes
|
||||||
|
KVCache --> PagePool : composes
|
||||||
|
KVCache --> Storage : composes
|
||||||
|
KVCache --> TaskTable : composes
|
||||||
|
KvcacheView --> Storage : wraps
|
||||||
InferenceEngine --> InferenceScheduler : uses
|
InferenceEngine --> InferenceScheduler : uses
|
||||||
InferenceEngine --> GenerationRequest : uses
|
InferenceEngine --> GenerationRequest : uses
|
||||||
GenerationRequest --> GenerationParams : contains
|
InferenceEngine --> GenerateResult : creates
|
||||||
InferenceScheduler --> Task : manages
|
InferenceScheduler --> Task : manages
|
||||||
Task --> TaskStatus : uses
|
|
||||||
InferenceScheduler --> TaskStatus : uses
|
InferenceScheduler --> TaskStatus : uses
|
||||||
InferenceScheduler --> PagedCache : uses
|
InferenceScheduler --> KVCache : uses
|
||||||
InferenceScheduler --> Transformer : uses
|
InferenceScheduler --> Transformer : uses
|
||||||
|
Task --> TaskStatus : uses
|
||||||
InferenceEngine --> Transformer : uses
|
InferenceEngine --> Transformer : uses
|
||||||
InferenceEngine --> GenerateResult : uses
|
|
||||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||||
BaseSamplingStrategy <|-- TopKStrategy
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
BaseSamplingStrategy <|-- TopPStrategy
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
|
|
@ -598,8 +671,10 @@ classDiagram
|
||||||
BaseDataset <|-- DPODataset
|
BaseDataset <|-- DPODataset
|
||||||
BaseDataset <|-- GRPODataset
|
BaseDataset <|-- GRPODataset
|
||||||
DatasetFactory ..> BaseDataset : creates
|
DatasetFactory ..> BaseDataset : creates
|
||||||
|
BaseStorage <|-- H5Storage
|
||||||
|
BaseStorage <|-- JSONStorage
|
||||||
|
BaseDataset --> BaseStorage : uses
|
||||||
MultiSegmentFetcher --> BaseSegmentFetcher : uses
|
MultiSegmentFetcher --> BaseSegmentFetcher : uses
|
||||||
BaseDataset --> MultiSegmentFetcher : uses
|
|
||||||
AutoModel <|-- Transformer
|
AutoModel <|-- Transformer
|
||||||
AutoModel --> ModelConfig : contains
|
AutoModel --> ModelConfig : contains
|
||||||
Transformer --> DecoderBlock : uses
|
Transformer --> DecoderBlock : uses
|
||||||
|
|
@ -613,10 +688,7 @@ classDiagram
|
||||||
ParallelModel <|-- RowParallelLinear
|
ParallelModel <|-- RowParallelLinear
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
ParallelModel <|-- ColumnParallelLinear
|
||||||
AutoTokenizer --> ChatTemplate : uses
|
AutoTokenizer --> ChatTemplate : uses
|
||||||
TrainConfig --> DatasetFactory : selects
|
BaseFactory <|-- AutoModel
|
||||||
TrainConfig --> SchedulerFactory : selects
|
|
||||||
TrainConfig --> CallbackFactory : selects
|
|
||||||
AutoModel ..> AutoTokenizer : loads with
|
|
||||||
BaseFactory <|-- DatasetFactory
|
BaseFactory <|-- DatasetFactory
|
||||||
BaseFactory <|-- StrategyFactory
|
BaseFactory <|-- StrategyFactory
|
||||||
BaseFactory <|-- SchedulerFactory
|
BaseFactory <|-- SchedulerFactory
|
||||||
|
|
@ -628,13 +700,13 @@ classDiagram
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
|
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
|
||||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
||||||
| **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||||
|
|
||||||
### Design Patterns
|
### Design Patterns
|
||||||
|
|
@ -647,7 +719,7 @@ classDiagram
|
||||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||||
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
|
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||||
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
|
||||||
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||||
|
|
@ -656,10 +728,10 @@ classDiagram
|
||||||
|
|
||||||
### Core Relationships
|
### Core Relationships
|
||||||
|
|
||||||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||||
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
|
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||||
|
|
|
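The Object Pool row above mentions alloc/free via a bitmask. A minimal sketch of that idea, assuming bit `i` of the mask is 1 when page `i` is free, might look like this. The class name `BitmaskAllocator` and its methods are hypothetical, and the reference counting and LRU eviction that the class diagram attributes to `Allocator` are deliberately left out.

```python
class BitmaskAllocator:
    """Hypothetical page allocator: bit i of the mask is 1 when page i is free."""

    def __init__(self, n_pages: int) -> None:
        self.n_pages = n_pages
        self._free_mask = (1 << n_pages) - 1  # all pages start free

    def alloc(self) -> int:
        """Return the lowest free page index, or -1 when no page is free."""
        if self._free_mask == 0:
            return -1
        lowest = self._free_mask & -self._free_mask  # isolate the lowest set bit
        self._free_mask ^= lowest  # mark that page as used
        return lowest.bit_length() - 1

    def free(self, idx: int) -> None:
        """Mark page idx as free again."""
        if not 0 <= idx < self.n_pages:
            raise IndexError(f"page index out of range: {idx}")
        self._free_mask |= 1 << idx

    def n_free(self) -> int:
        """Count the pages that are currently free."""
        return bin(self._free_mask).count("1")
```

Both operations are a handful of integer instructions when the pool fits in a machine word; Python integers are arbitrary precision, so for very large pools the cost grows with the mask width rather than staying strictly constant.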
||||||
|
|
@ -180,7 +180,7 @@ request = GenerationRequest(
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
max_len=1024,
|
max_tokens=None,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ python scripts/tools/train.py \
|
||||||
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
||||||
| `top_p` | Nucleus sampling threshold | 1.0 |
|
| `top_p` | Nucleus sampling threshold | 1.0 |
|
||||||
| `top_k` | Top-k sampling count | 50 |
|
| `top_k` | Top-k sampling count | 50 |
|
||||||
| `max_len` | Maximum generation length | 1024 |
|
| `max_tokens` | Maximum generation length | None (unlimited) |
|
||||||
| `stream` | Whether to stream output | False |
|
| `stream` | Whether to stream output | False |
|
||||||
|
|
||||||
### Usage Example
|
### Usage Example
|
||||||
|
|
@ -130,7 +130,7 @@ request = GenerationRequest(
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
max_len=1024,
|
max_tokens=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate (streaming)
|
# Generate (streaming)
|
||||||
|
|
|
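The parameter table above gives `max_tokens` a default of `None` (unlimited). One plausible reading of that default is sketched below. The helper `reached_limit` is hypothetical, and the assumption that the scheduler's `max_seq_len` still caps the total sequence length is mine, not something the documentation confirms.

```python
from typing import Optional


def reached_limit(
    n_prompt: int,
    n_generated: int,
    max_tokens: Optional[int],
    max_seq_len: int,
) -> bool:
    """Hypothetical stop check for a request's token budget."""
    # A per-request cap applies only when max_tokens is set.
    if max_tokens is not None and n_generated >= max_tokens:
        return True
    # Assumed hard ceiling: the total sequence may not exceed max_seq_len.
    return n_prompt + n_generated >= max_seq_len
```

Under this reading, "unlimited" means "no per-request cap", so generation would end on a stop token or at the assumed sequence-length ceiling.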
||||||