docs: 修正文档中与源码不符的类名、方法签名和模块归属

- Transformer/DecoderBlock/GQA/RotaryEmbedding forward 签名 start_pos → position_ids

- _Result → GenerateResult

- save_h5/load_h5 从 serialization 移至 dataset 模块

- PagedCache UML 移除内部 PagePool 属性

- 修正 Layer 数不一致(24 vs 32)及 decode 位置分组描述

- 更新文档时间为 2026-05-14
This commit is contained in:
ViperEkura 2026-05-14 15:04:53 +08:00
parent 6269bacfc3
commit a8e2a1ba45
5 changed files with 25 additions and 25 deletions

View File

@ -12,7 +12,7 @@ AstrAI adopts a modular design with the following main components:
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig - **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration - **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support - **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management - **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors
## Data Flow Diagram ## Data Flow Diagram
@ -59,7 +59,7 @@ flowchart LR
## Detailed Module Descriptions ## Detailed Module Descriptions
### 1. Serialization (`astrai/serialization.py`) ### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`)
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors - **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory - **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
@ -234,4 +234,4 @@ Background thread runs continuously:
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format. - **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data. - **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
> Document Update Time: 2026-05-09 > Document Update Time: 2026-05-14

View File

@ -138,7 +138,7 @@ classDiagram
+ModuleList layers +ModuleList layers
+RMSNorm norm +RMSNorm norm
+Linear lm_head +Linear lm_head
+forward(input_ids, input_mask, paged_cache, start_pos) Dict +forward(input_ids, input_mask, paged_cache, position_ids) Tensor
+load_state_dict(state_dict) +load_state_dict(state_dict)
+state_dict() +state_dict()
} }
@ -148,7 +148,7 @@ classDiagram
+RMSNorm input_norm +RMSNorm input_norm
+MLP mlp +MLP mlp
+RMSNorm post_attention_norm +RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache, start_pos) Tensor +forward(x, rotary_emb, attention_mask, position_ids, paged_cache) Tensor
} }
class GQA { class GQA {
@ -157,7 +157,7 @@ classDiagram
+int head_dim +int head_dim
+Linear q_proj, k_proj, v_proj, o_proj +Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm +RMSNorm q_norm, k_norm
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor +forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
} }
class MLA { class MLA {
@ -170,7 +170,7 @@ classDiagram
+Linear q_proj, kv_a_proj, kv_b_proj +Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj +Linear o_proj
+RMSNorm kv_norm +RMSNorm kv_norm
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor +forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
} }
class MLP { class MLP {
@ -194,7 +194,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+forward(x, start_pos) Tuple[Tensor, Tensor] +forward(x, position_ids=None) Tuple[Tensor, Tensor]
} }
class Embedding { class Embedding {
@ -417,23 +417,20 @@ classDiagram
class PagedCache { class PagedCache {
+int page_size +int page_size
+int _free_mask
+List[int] _refs
+Tensor k_cache +Tensor k_cache
+Tensor v_cache +Tensor v_cache
+alloc() int
+alloc_n(n) List[int] +alloc_n(n) List[int]
+free(idx) +free(idx)
+bind(page_table, total_len) CacheView +bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v) +write(layer_id, page_table, position_ids, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor] +gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
} }
class CacheView { class CacheView {
+PagedCache _cache +PagedCache _cache
+Tensor _page_table +Tensor _page_table
+int _total_len +int _total_len
+write(layer_id, start_pos, k, v) +write(layer_id, position_ids, k, v)
+gather(layer_id) Tuple[Tensor, Tensor] +gather(layer_id) Tuple[Tensor, Tensor]
} }
@ -505,7 +502,7 @@ classDiagram
+sample(logits, filter_value) Tensor +sample(logits, filter_value) Tensor
} }
class _Result { class GenerateResult {
+List[str] tokens +List[str] tokens
+List[str] results +List[str] results
+List[bool] _done +List[bool] _done
@ -513,6 +510,7 @@ classDiagram
+get_results() List[str] +get_results() List[str]
+pop_all() List[str] +pop_all() List[str]
+wait(timeout) bool +wait(timeout) bool
+wait_completion()
} }
class ChatMessage { class ChatMessage {
@ -590,7 +588,7 @@ classDiagram
InferenceScheduler --> PagedCache : uses InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses InferenceEngine --> Transformer : uses
InferenceEngine --> _Result : uses InferenceEngine --> GenerateResult : uses
BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy BaseSamplingStrategy <|-- TopPStrategy
@ -630,8 +628,8 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management | | **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management | | **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
@ -654,7 +652,7 @@ classDiagram
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation | | **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
### Core Relationships ### Core Relationships
@ -716,4 +714,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence. Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-04-09 > Document Update Time: 2026-05-14

View File

@ -2,7 +2,7 @@
### 1. Model Architecture ### 1. Model Architecture
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token. This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
The model now uses the **AutoModel** base class for flexible loading and saving: The model now uses the **AutoModel** base class for flexible loading and saving:
@ -24,7 +24,7 @@ flowchart TB
direction TB direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1] A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...] B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer 32] C --> D[Transformer Block\nLayer ...]
D --> E[RMSNorm] D --> E[RMSNorm]
E --> F[Linear] E --> F[Linear]
F --> G[SoftMax] F --> G[SoftMax]
@ -331,4 +331,4 @@ curl http://localhost:8000/stats
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0} # {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
``` ```
> Document Update Time: 2026-04-09 > Document Update Time: 2026-05-14

View File

@ -155,4 +155,4 @@ result = engine.generate(
| `stream=True` | Streaming output, yields token by token | | `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result | | `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-04-09 > Document Update Time: 2026-05-14

View File

@ -26,7 +26,9 @@ def processor(
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine # Create inference engine
engine = InferenceEngine(model=model, tokenizer=tokenizer, max_batch_size=batch_size) engine = InferenceEngine(
model=model, tokenizer=tokenizer, max_batch_size=batch_size
)
with open(input_json_file, "r", encoding="utf-8") as f: with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f] input_data = [json.loads(line) for line in f]