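docs: fix class names, method signatures, and module ownership that did not match the source code
- CONTRIBUTING.md: run the ruff/pytest commands through conda
- params.md: max_len → max_tokens
- introduction.md: max_len=1024 → max_tokens=None
- dataflow.md: PagedCache/CacheView → KVCache/KvcacheView
- design.md: thorough class-diagram corrections (PagedCache → Allocator and the other new classes, six in total; removed the erroneous position_ids parameter; fixed the BaseDataset fields and 25+ relationship lines; updated the Module Overview)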
parent 205b40bd28
commit d8da2cf17c
@@ -36,10 +36,10 @@ If you encounter a bug or have a feature request, please open an issue on GitHub

AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.

- Run Ruff to format and lint:
- Run Ruff to format and lint (requires conda environment `nlp`):
```bash
ruff format .
ruff check --fix .
conda run -n nlp ruff format .
conda run -n nlp ruff check --fix .
```
- The project uses **double quotes** for strings and **4‑space indentation** (as configured in `pyproject.toml`).

@@ -49,7 +49,7 @@ If you add or modify functionality, please include appropriate tests.

- Run the test suite with:
```bash
pytest
conda run -n nlp python -u -m pytest
```
- Ensure all tests pass before submitting your PR.

@@ -156,9 +156,9 @@ Background thread runs continuously:
4. Decode → Pick largest same-position group, run single-token forward
```

- **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED)
- **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write
- **`CacheView`**: Batch view bundling cache + page table for attention layers
- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED)
- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache
- **`KvcacheView`**: Batch view bundling cache + page table for attention layers
- **`sample()`**: Temperature → top-k → top-p → multinomial

#### 5.3 Server (`server.py`)
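Note on the hunk above: the `sample()` bullet names a four-stage order (temperature, top-k, top-p, multinomial). A minimal standard-library sketch of that order follows. The signature, the `rng` argument, and the plain-list logits are assumptions made for illustration; this is not AstrAI's implementation.

```python
import math
import random


def sample(logits, temperature=1.0, top_k=50, top_p=1.0, rng=random):
    """Pick one token ID: temperature -> top-k -> top-p -> multinomial."""
    # Temperature: scale logits (values below 1 sharpen, above 1 flatten).
    scaled = [x / max(temperature, 1e-8) for x in logits]

    # Top-k: keep the k highest-scoring token IDs, best first.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Softmax over the survivors (subtract the max for numerical stability).
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for token_id, p in zip(order, probs):
        kept.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:
            break

    # Multinomial: draw one token ID in proportion to the remaining mass.
    ids = [token_id for token_id, _ in kept]
    mass = [p for _, p in kept]
    return rng.choices(ids, weights=mass, k=1)[0]
```

With `top_k=1` the draw is deterministic (greedy decoding), which makes the stage order easy to check by hand.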
@@ -216,13 +216,13 @@ Background thread runs continuously:

3. **Continuous Batching Loop**
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
- **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages
- **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
- **Decode**: Pick position group with most tasks, single-token forward:
- Model forward → `logits` → `sample()` → next token ID
- Append to `output_ids`, update `output_tokens`
- `_maybe_alloc_page()` grows page table as needed
- `stream_callback(token)` for streaming clients
- `PagePool.task_alloc()` allocates pages as needed
- `stream_callback(token)` for streaming clients

4. **Output**
- `tokenizer.decode(output_ids)` → text
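Note on the hunk above: the cleanup / refill / prefill / decode cycle can be sketched as a toy loop. Everything here (`ToyTask`, `toy_forward`, `run`, the fake "next token = current position" model) is hypothetical and stands in for the real scheduler, model forward pass, and KV page handling, none of which are shown.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class ToyTask:
    prompt_ids: list
    max_new: int
    output_ids: list = field(default_factory=list)
    prefilled: bool = False

    @property
    def pos(self):
        # Current sequence position: prompt length plus generated tokens.
        return len(self.prompt_ids) + len(self.output_ids)


def toy_forward(task):
    # Stand-in for model forward + sample(): next token = current position.
    return task.pos


def run(waiting, max_batch_size=2):
    active, finished = [], []
    while waiting or active:
        # Cleanup: retire finished tasks (a real scheduler would free KV pages here).
        for t in [t for t in active if len(t.output_ids) >= t.max_new]:
            active.remove(t)
            finished.append(t)
        # Refill: admit waiting tasks while batch slots are free.
        while waiting and len(active) < max_batch_size:
            active.append(waiting.popleft())
        # Prefill: newly admitted tasks process their prompt once (only a flag here).
        for t in active:
            if not t.prefilled:
                t.prefilled = True
        # Decode: pick the largest same-position group and advance it one token.
        groups = {}
        for t in active:
            if len(t.output_ids) < t.max_new:
                groups.setdefault(t.pos, []).append(t)
        if groups:
            for t in max(groups.values(), key=len):
                t.output_ids.append(toy_forward(t))
    return finished


done = run(deque([ToyTask([1, 2], 2), ToyTask([3, 4], 2), ToyTask([5], 1)]))
```

The third task is admitted only after the first two finish and release their batch slots, which is the behaviour the Refill step describes.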
@@ -61,8 +61,8 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+MultiSegmentFetcher fetcher
+load(load_path)
+BaseStorage storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
}
@@ -90,6 +90,26 @@ classDiagram
+fetch_data(begin_idx, end_idx) Tensor
}

class BaseStorage {
+Dict segments
+List keys
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
}

class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}

class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}

class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
@@ -148,7 +168,7 @@ classDiagram
+RMSNorm input_norm
+MLP mlp
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, position_ids, paged_cache) Tensor
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
}

class GQA {
@@ -157,7 +177,7 @@ classDiagram
+int head_dim
+Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm
+forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}

class MLA {
@@ -170,7 +190,7 @@ classDiagram
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+RMSNorm kv_norm
+forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}

class MLP {
@@ -401,7 +421,7 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+PagedCache page_cache
+KVCache page_cache
+int max_batch_size
+int max_seq_len
+int max_prompt_len
@@ -415,25 +435,77 @@ classDiagram
+get_stats() Dict
}

class PagedCache {
class Allocator {
+int _free_mask
+int refs_count
+LRU _lru
+alloc() int
+free(idx, keep_cached)
+inc_ref(idx)
+touch(idx)
+ref_count(idx) int
}

class PrefixCache {
+int _page_size
+evict(page_idx)
+has_page(idx) bool
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}

class PagePool {
-Allocator _alloc
-PrefixCache _prefix
+alloc() int
+free(idx)
+inc_ref(idx)
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}

class Storage {
+int n_layers
+int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache
+Tensor v_cache
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, position_ids, k, v)
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
}

class CacheView {
+PagedCache _cache
class KVCache {
-PagePool _pool
-Storage _storage
-TaskTable _table
+int page_size
+task_alloc(task_id, prompt_ids) bool
+task_free(task_id)
+task_extend(task_id, pos) bool
+task_cached(task_id) int
+task_record_hashes(task_id, prompt_ids, start_logical_page)
+make_table_tensor(task_ids, device) Tensor
+bind(page_table, total_len) KvcacheView
}

class KvcacheView {
-Storage _storage
+Tensor _page_table
+int _total_len
+write(layer_id, position_ids, k, v)
+write(layer_id, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}

class TaskTable {
+set(task_id, page_table, cached)
+get(task_id) List[int]
+get_cached(task_id) int
+get_ref(task_id) List[int]
+pop(task_id) Tuple[List[int], int]
+table_tensor(task_ids, device) Tensor
}

class Task {
+str task_id
+List prompt_ids
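Note on the hunk above: `Storage.write(layer_id, page_table, start_pos, k, v)` and `gather(layer_id, page_table, total_len)` both take a page table, so reads and writes go through a logical-to-physical page translation. The toy below sketches that translation for a single layer with plain Python lists. `ToyStorage` and its simplified signatures are assumptions for illustration and do not reproduce the tensor-based class in the diagram.

```python
class ToyStorage:
    """Toy single-layer paged store; each page holds page_size slots."""

    def __init__(self, n_pages, page_size):
        self.page_size = page_size
        self.pages = [[None] * page_size for _ in range(n_pages)]

    def write(self, page_table, start_pos, values):
        # Logical position -> (physical page via page_table, offset within page).
        for i, value in enumerate(values):
            pos = start_pos + i
            page = page_table[pos // self.page_size]
            self.pages[page][pos % self.page_size] = value

    def gather(self, page_table, total_len):
        # Read the first total_len logical positions back in order.
        return [
            self.pages[page_table[pos // self.page_size]][pos % self.page_size]
            for pos in range(total_len)
        ]


store = ToyStorage(n_pages=4, page_size=2)
table = [3, 0, 2]  # logical pages 0, 1, 2 live in physical pages 3, 0, 2
store.write(table, 0, ["a", "b", "c"])  # prefill-style write from position 0
store.write(table, 3, ["d", "e"])       # later writes continue at start_pos=3
```

Because the page table is the only link between a sequence and its pages, physical pages do not need to be contiguous or ordered.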
@@ -445,8 +517,6 @@ classDiagram
+List output_ids
+int input_tokens
+int output_tokens
+List[int] page_table
+int n_pages
+float arrival_time
+float finish_time
+Callable stream_callback
@@ -464,16 +534,11 @@ classDiagram

class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}

class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
+Optional[int] max_tokens
+bool stream
}

class BaseSamplingStrategy {
@@ -531,9 +596,12 @@ classDiagram
}

namespace parallel {
class ParallelFunctions {
class Functions {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
}

class ParallelModel {
@@ -552,9 +620,8 @@ classDiagram
}

%% Relationships
TrainConfig --> ModelConfig : uses
TrainConfig --> BaseDataset : uses
TrainConfig --> StrategyFactory : selects
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
@@ -562,11 +629,12 @@ classDiagram
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : builds
Trainer --> TrainConfig : uses
Trainer --> TrainContextBuilder : uses
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContextBuilder --> StrategyFactory : uses
Checkpoint ..> Checkpoint : serializes
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
@@ -579,16 +647,21 @@ classDiagram
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
PagePool --> Allocator : composes
PagePool --> PrefixCache : composes
KVCache --> PagePool : composes
KVCache --> Storage : composes
KVCache --> TaskTable : composes
KvcacheView --> Storage : wraps
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceEngine --> GenerateResult : creates
InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> KVCache : uses
InferenceScheduler --> Transformer : uses
Task --> TaskStatus : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> GenerateResult : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
@@ -598,8 +671,10 @@ classDiagram
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseDataset --> BaseStorage : uses
MultiSegmentFetcher --> BaseSegmentFetcher : uses
BaseDataset --> MultiSegmentFetcher : uses
AutoModel <|-- Transformer
AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses
@@ -613,10 +688,7 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
AutoModel ..> AutoTokenizer : loads with
BaseFactory <|-- AutoModel
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
@@ -628,13 +700,13 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |

### Design Patterns
@@ -647,7 +719,7 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
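Note on the hunk above: the Object Pool row describes alloc/free via a bitmask. The sketch below shows one common way to do that with an integer mask and per-page reference counts. `BitmaskAllocator` is a hypothetical stand-in; the LRU eviction mentioned in the row and the `keep_cached` flag from the class diagram are deliberately left out.

```python
class BitmaskAllocator:
    """Toy page allocator: bit i of _free_mask is 1 while page i is free."""

    def __init__(self, n_pages):
        self._free_mask = (1 << n_pages) - 1
        self._refs = [0] * n_pages

    def alloc(self):
        if self._free_mask == 0:
            return -1  # pool exhausted; a fuller allocator might evict here
        lowest = self._free_mask & -self._free_mask  # isolate the lowest set bit
        idx = lowest.bit_length() - 1
        self._free_mask &= ~lowest  # mark the page as in use
        self._refs[idx] = 1
        return idx

    def inc_ref(self, idx):
        # A second owner (for example a shared prefix) keeps the page alive.
        self._refs[idx] += 1

    def free(self, idx):
        # Drop one reference; the page returns to the pool at zero.
        self._refs[idx] -= 1
        if self._refs[idx] == 0:
            self._free_mask |= 1 << idx


pool = BitmaskAllocator(3)
first, second = pool.alloc(), pool.alloc()
pool.inc_ref(first)   # page is now shared by two owners
pool.free(first)      # one owner left, page stays allocated
third = pool.alloc()
pool.free(first)      # last owner gone, page is free again
reused = pool.alloc()
```

Reference counting is what lets a page be shared between tasks without being returned to the pool while any of them still uses it.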
@@ -656,10 +728,10 @@ classDiagram

### Core Relationships

1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
@@ -180,7 +180,7 @@ request = GenerationRequest(
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    max_len=1024,
    max_tokens=None,
    stream=True,
)

@@ -98,7 +98,7 @@ python scripts/tools/train.py \
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_len` | Maximum generation length | 1024 |
| `max_tokens` | Maximum generation length | None (unlimited) |
| `stream` | Whether to stream output | False |

### Usage Example
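Note on the hunk above: the table now gives `max_tokens` a default of `None`, meaning no explicit cap. The sketch below mirrors the table's defaults in a small dataclass and shows how a caller might treat `None`. `ParamsSketch` and `reached_limit` are hypothetical names for illustration and are not part of AstrAI; a real engine may still stop earlier for other reasons, such as a maximum sequence length.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ParamsSketch:
    # Defaults copied from the parameter table in this hunk.
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 50
    max_tokens: Optional[int] = None  # None = no explicit generation cap
    stream: bool = False


def reached_limit(n_generated: int, params: ParamsSketch) -> bool:
    """Return True once the explicit token cap, if any, has been reached."""
    if params.max_tokens is None:
        return False
    return n_generated >= params.max_tokens
```

Using `Optional[int]` rather than a large sentinel keeps "no cap" distinguishable from any real limit.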
@@ -130,7 +130,7 @@ request = GenerationRequest(
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    max_len=1024,
    max_tokens=None,
)

# Generate (streaming)