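docs: fix class names, method signatures, and module attribution that did not match the source code

- CONTRIBUTING.md: switch the ruff/pytest commands to run via conda
- params.md: max_len → max_tokens
- introduction.md: max_len=1024 → max_tokens=None
- dataflow.md: PagedCache/CacheView → KVCache/KvcacheView
- design.md: overhaul the class diagram (PagedCache → Allocator and five other new classes, remove the erroneous position_ids parameter, correct the BaseDataset fields and 25+ relationship lines, update the Module Overview)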
ViperEkura 2026-05-14 20:26:02 +08:00
parent 205b40bd28
commit d8da2cf17c
5 changed files with 129 additions and 57 deletions


@ -36,10 +36,10 @@ If you encounter a bug or have a feature request, please open an issue on GitHub
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
- Run Ruff to format and lint:
- Run Ruff to format and lint (requires conda environment `nlp`):
```bash
ruff format .
ruff check --fix .
conda run -n nlp ruff format .
conda run -n nlp ruff check --fix .
```
- The project uses **double quotes** for strings and **4-space indentation** (as configured in `pyproject.toml`).
@ -49,7 +49,7 @@ If you add or modify functionality, please include appropriate tests.
- Run the test suite with:
```bash
pytest
conda run -n nlp python -u -m pytest
```
- Ensure all tests pass before submitting your PR.


@ -156,9 +156,9 @@ Background thread runs continuously:
4. Decode → Pick largest same-position group, run single-token forward
```
- **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED)
- **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write
- **`CacheView`**: Batch view bundling cache + page table for attention layers
- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED)
- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache
- **`KvcacheView`**: Batch view bundling cache + page table for attention layers
- **`sample()`**: Temperature → top-k → top-p → multinomial
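
The `sample()` order listed above (temperature → top-k → top-p → multinomial) can be illustrated with a minimal pure-Python sketch. It is not AstrAI's implementation: the signature, the list-based logits, and the `rng` parameter are assumptions made so the example runs without any dependencies.

```python
import math
import random
from typing import List, Optional


def sample(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Return a token index drawn after temperature, top-k and top-p filtering."""
    rng = rng or random.Random()
    # 1. Temperature: scale the logits (guard against division by zero).
    scaled = [x / max(temperature, 1e-8) for x in logits]
    # 2. Top-k: keep the k highest-scoring candidates (0 disables the filter).
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    # Softmax over the surviving candidates, shifted by the max for stability.
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    # 3. Top-p: keep the shortest prefix whose cumulative probability reaches top_p.
    kept, cumulative = 0, 0.0
    for p in probs:
        kept += 1
        cumulative += p
        if cumulative >= top_p:
            break
    # 4. Multinomial draw; random.choices normalizes the weights itself.
    return rng.choices(order[:kept], weights=probs[:kept], k=1)[0]
```

Applying top-k before top-p means the nucleus threshold is measured against the already truncated distribution, which is one common convention; the real pipeline may differ in detail.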
#### 5.3 Server (`server.py`)
@ -216,13 +216,13 @@ Background thread runs continuously:
3. **Continuous Batching Loop**
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
- **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages
- **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
- **Decode**: Pick position group with most tasks, single-token forward:
- Model forward → `logits` → `sample()` → next token ID
- Append to `output_ids`, update `output_tokens`
- `_maybe_alloc_page()` grows page table as needed
- `stream_callback(token)` for streaming clients
- `PagePool.task_alloc()` allocates pages as needed
- `stream_callback(token)` for streaming clients
4. **Output**
- `tokenizer.decode(output_ids)` → text
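
The decode step above ("pick position group with most tasks") can be sketched as follows. The `Task` class here is a hypothetical stand-in carrying only the fields the example needs, and `pick_decode_group` is an illustrative name rather than a function from the scheduler.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Task:
    """Hypothetical stand-in holding only the fields this sketch needs."""

    task_id: str
    prompt_ids: List[int]
    output_ids: List[int] = field(default_factory=list)

    @property
    def position(self) -> int:
        # Number of tokens already processed for this task.
        return len(self.prompt_ids) + len(self.output_ids)


def pick_decode_group(active: List[Task]) -> List[Task]:
    """Return the largest set of active tasks that share one position."""
    groups: Dict[int, List[Task]] = defaultdict(list)
    for task in active:
        groups[task.position].append(task)
    if not groups:
        return []
    # max() returns the first largest group, so ties go to the position seen first.
    return max(groups.values(), key=len)
```

Grouping by position lets a single-token forward pass run over tasks with the same sequence length; tasks at other positions simply wait for a later iteration of the loop.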


@ -61,8 +61,8 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+MultiSegmentFetcher fetcher
+load(load_path)
+BaseStorage storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
}
@ -90,6 +90,26 @@ classDiagram
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+Dict segments
+List keys
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
}
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
@ -148,7 +168,7 @@ classDiagram
+RMSNorm input_norm
+MLP mlp
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, position_ids, paged_cache) Tensor
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
}
class GQA {
@ -157,7 +177,7 @@ classDiagram
+int head_dim
+Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm
+forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLA {
@ -170,7 +190,7 @@ classDiagram
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+RMSNorm kv_norm
+forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLP {
@ -401,7 +421,7 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+PagedCache page_cache
+KVCache page_cache
+int max_batch_size
+int max_seq_len
+int max_prompt_len
@ -415,25 +435,77 @@ classDiagram
+get_stats() Dict
}
class PagedCache {
class Allocator {
+int _free_mask
+int refs_count
+LRU _lru
+alloc() int
+free(idx, keep_cached)
+inc_ref(idx)
+touch(idx)
+ref_count(idx) int
}
class PrefixCache {
+int _page_size
+evict(page_idx)
+has_page(idx) bool
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class PagePool {
-Allocator _alloc
-PrefixCache _prefix
+alloc() int
+free(idx)
+inc_ref(idx)
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class Storage {
+int n_layers
+int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache
+Tensor v_cache
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, position_ids, k, v)
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
}
class CacheView {
+PagedCache _cache
class KVCache {
-PagePool _pool
-Storage _storage
-TaskTable _table
+int page_size
+task_alloc(task_id, prompt_ids) bool
+task_free(task_id)
+task_extend(task_id, pos) bool
+task_cached(task_id) int
+task_record_hashes(task_id, prompt_ids, start_logical_page)
+make_table_tensor(task_ids, device) Tensor
+bind(page_table, total_len) KvcacheView
}
class KvcacheView {
-Storage _storage
+Tensor _page_table
+int _total_len
+write(layer_id, position_ids, k, v)
+write(layer_id, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}
class TaskTable {
+set(task_id, page_table, cached)
+get(task_id) List[int]
+get_cached(task_id) int
+get_ref(task_id) List[int]
+pop(task_id) Tuple[List[int], int]
+table_tensor(task_ids, device) Tensor
}
class Task {
+str task_id
+List prompt_ids
@ -445,8 +517,6 @@ classDiagram
+List output_ids
+int input_tokens
+int output_tokens
+List[int] page_table
+int n_pages
+float arrival_time
+float finish_time
+Callable stream_callback
@ -464,16 +534,11 @@ classDiagram
class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
+Optional[int] max_tokens
+bool stream
}
class BaseSamplingStrategy {
@ -531,9 +596,12 @@ classDiagram
}
namespace parallel {
class ParallelFunctions {
class Functions {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
}
class ParallelModel {
@ -552,9 +620,8 @@ classDiagram
}
%% Relationships
TrainConfig --> ModelConfig : uses
TrainConfig --> BaseDataset : uses
TrainConfig --> StrategyFactory : selects
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
@ -562,11 +629,12 @@ classDiagram
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : builds
Trainer --> TrainConfig : uses
Trainer --> TrainContextBuilder : uses
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContextBuilder --> StrategyFactory : uses
Checkpoint ..> Checkpoint : serializes
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
@ -579,16 +647,21 @@ classDiagram
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
PagePool --> Allocator : composes
PagePool --> PrefixCache : composes
KVCache --> PagePool : composes
KVCache --> Storage : composes
KVCache --> TaskTable : composes
KvcacheView --> Storage : wraps
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceEngine --> GenerateResult : creates
InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> KVCache : uses
InferenceScheduler --> Transformer : uses
Task --> TaskStatus : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> GenerateResult : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
@ -598,8 +671,10 @@ classDiagram
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseDataset --> BaseStorage : uses
MultiSegmentFetcher --> BaseSegmentFetcher : uses
BaseDataset --> MultiSegmentFetcher : uses
AutoModel <|-- Transformer
AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses
@ -613,10 +688,7 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
AutoModel ..> AutoTokenizer : loads with
BaseFactory <|-- AutoModel
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
@ -628,13 +700,13 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
### Design Patterns
@ -647,7 +719,7 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
@ -656,10 +728,10 @@ classDiagram
### Core Relationships
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
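
The bitmask idea behind the page allocator (an integer `_free_mask` with alloc/free as bit operations) can be sketched as below. `BitmaskAllocator` is a hypothetical class for illustration only: it leaves out the reference counting and LRU eviction that the documented `Allocator` lists, and its method signatures are assumptions.

```python
from typing import Optional


class BitmaskAllocator:
    """Hypothetical page allocator: bit i of _free_mask is 1 while page i is free."""

    def __init__(self, n_pages: int) -> None:
        self.n_pages = n_pages
        self._free_mask = (1 << n_pages) - 1  # all pages start free

    def alloc(self) -> Optional[int]:
        """Claim the lowest-numbered free page, or return None when full."""
        if self._free_mask == 0:
            return None
        lowest = self._free_mask & -self._free_mask  # isolate the lowest set bit
        self._free_mask ^= lowest  # mark that page as in use
        return lowest.bit_length() - 1

    def free(self, idx: int) -> None:
        """Return page idx to the pool."""
        if not 0 <= idx < self.n_pages:
            raise IndexError(f"page index {idx} out of range")
        self._free_mask |= 1 << idx
```

Both operations are a handful of integer instructions with no scan over the pages, which is what makes a bitmask attractive for a fixed-size page pool.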


@ -180,7 +180,7 @@ request = GenerationRequest(
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=1024,
max_tokens=None,
stream=True,
)


@ -98,7 +98,7 @@ python scripts/tools/train.py \
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_len` | Maximum generation length | 1024 |
| `max_tokens` | Maximum generation length | None (unlimited) |
| `stream` | Whether to stream output | False |
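
One way to read the `max_tokens` default of `None (unlimited)` is as a stop condition that only applies when a limit is set. The helper below is a hypothetical sketch, not code from the inference engine; `should_stop` and `hit_eos` are names invented for the example, and any other cap the scheduler enforces (such as a maximum sequence length) is ignored here.

```python
from typing import Optional


def should_stop(n_generated: int, max_tokens: Optional[int], hit_eos: bool = False) -> bool:
    """Decide whether generation ends after n_generated output tokens."""
    if hit_eos:
        # An end-of-sequence token always ends generation.
        return True
    # None means no explicit token budget, so only a set limit can stop here.
    return max_tokens is not None and n_generated >= max_tokens
```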
### Usage Example
@ -130,7 +130,7 @@ request = GenerationRequest(
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=1024,
max_tokens=None,
)
# Generate (streaming)