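docs: fix class names, method signatures, and module attribution that did not match the source code

- Transformer/DecoderBlock/GQA/RotaryEmbedding forward signature: start_pos → position_ids

- _Result → GenerateResult

- Move save_h5/load_h5 from the serialization module to the dataset module

- Remove internal PagePool attributes from the PagedCache UML

- Fix the inconsistent layer count (24 vs 32) and the description of decode position grouping

- Update the document date to 2026-05-14

ViperEkura 2026-05-14 15:04:53 +08:00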
parent 6269bacfc3
commit a8e2a1ba45
5 changed files with 25 additions and 25 deletions
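
The commit message does not say why `start_pos` gave way to `position_ids`. One plausible reading, given the mention of decode position grouping, is that a single scalar offset cannot describe a batch whose sequences sit at different lengths. The sketch below illustrates that difference in plain Python; both helper names and the `cached_lens` input are hypothetical and are not taken from the AstrAI source.

```python
from typing import List


def positions_from_start_pos(start_pos: int, seq_len: int) -> List[int]:
    """Scalar-offset addressing: one start position shared by the whole batch."""
    return list(range(start_pos, start_pos + seq_len))


def positions_for_batch(cached_lens: List[int], new_tokens: int) -> List[List[int]]:
    """Explicit addressing: one row of positions per sequence.

    cached_lens[i] is the number of tokens sequence i already holds in its
    cache (a hypothetical input used only for this sketch).
    """
    return [positions_from_start_pos(n, new_tokens) for n in cached_lens]


# Prefill of a single 4-token prompt: both schemes describe the same positions.
prefill = positions_for_batch([0], 4)

# One decode step for three sequences of different lengths: no single
# start_pos fits all rows, but explicit per-sequence positions do.
decode = positions_for_batch([5, 9, 2], 1)
```

With explicit positions, each row can be fed to the rotary embedding and the cache write independently of the others.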

@ -12,7 +12,7 @@ AstrAI adopts a modular design with the following main components:
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
- **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors
## Data Flow Diagram
@ -59,7 +59,7 @@ flowchart LR
## Detailed Module Descriptions
### 1. Serialization (`astrai/serialization.py`)
### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`)
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
@ -234,4 +234,4 @@ Background thread runs continuously:
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
> Document Update Time: 2026-05-09
> Document Update Time: 2026-05-14

@ -138,7 +138,7 @@ classDiagram
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, start_pos) Dict
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor
+load_state_dict(state_dict)
+state_dict()
}
@ -148,7 +148,7 @@ classDiagram
+RMSNorm input_norm
+MLP mlp
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache, start_pos) Tensor
+forward(x, rotary_emb, attention_mask, position_ids, paged_cache) Tensor
}
class GQA {
@ -157,7 +157,7 @@ classDiagram
+int head_dim
+Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
+forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
}
class MLA {
@ -170,7 +170,7 @@ classDiagram
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+RMSNorm kv_norm
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
+forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor
}
class MLP {
@ -194,7 +194,7 @@ classDiagram
+int dim
+int max_len
+float base
+forward(x, start_pos) Tuple[Tensor, Tensor]
+forward(x, position_ids=None) Tuple[Tensor, Tensor]
}
class Embedding {
@ -417,23 +417,20 @@ classDiagram
class PagedCache {
+int page_size
+int _free_mask
+List[int] _refs
+Tensor k_cache
+Tensor v_cache
+alloc() int
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
+write(layer_id, page_table, position_ids, k, v)
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
}
class CacheView {
+PagedCache _cache
+Tensor _page_table
+int _total_len
+write(layer_id, start_pos, k, v)
+write(layer_id, position_ids, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}
@ -505,7 +502,7 @@ classDiagram
+sample(logits, filter_value) Tensor
}
class _Result {
class GenerateResult {
+List[str] tokens
+List[str] results
+List[bool] _done
@ -513,6 +510,7 @@ classDiagram
+get_results() List[str]
+pop_all() List[str]
+wait(timeout) bool
+wait_completion()
}
class ChatMessage {
@ -590,7 +588,7 @@ classDiagram
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> _Result : uses
InferenceEngine --> GenerateResult : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
@ -630,8 +628,8 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
@ -654,7 +652,7 @@ classDiagram
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
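
The event-based notification described in the table rows above can be sketched with `threading.Event`. This is a minimal illustration, not the `GenerateResult` implementation: the class `ToyResult`, its `append`, `finish`, and `drain` methods, and the `stream` helper are all hypothetical names chosen for the example.

```python
import threading
from typing import Iterator, List, Optional, Tuple


class ToyResult:
    """Hypothetical event-notified token buffer (not the AstrAI class)."""

    def __init__(self) -> None:
        self._tokens: List[str] = []
        self._done = False
        self._lock = threading.Lock()
        self._event = threading.Event()  # set when tokens arrive or generation ends

    def append(self, token: str) -> None:
        with self._lock:
            self._tokens.append(token)
        self._event.set()

    def finish(self) -> None:
        with self._lock:
            self._done = True
        self._event.set()

    def wait(self, timeout: Optional[float] = None) -> bool:
        return self._event.wait(timeout)

    def drain(self) -> Tuple[List[str], bool]:
        """Take all buffered tokens and report whether generation has ended."""
        with self._lock:
            out, self._tokens = self._tokens, []
            finished = self._done
            if not finished:
                self._event.clear()  # block the next wait() until new data arrives
            return out, finished


def stream(result: ToyResult) -> Iterator[str]:
    """Yield tokens as they arrive, sleeping on the event in between."""
    while True:
        result.wait()
        tokens, finished = result.drain()
        yield from tokens
        if finished:
            return


result = ToyResult()


def producer() -> None:
    for tok in ["Hello", ",", " world"]:
        result.append(tok)
    result.finish()


worker = threading.Thread(target=producer)
worker.start()
text = "".join(stream(result))
worker.join()
```

Reading the buffer and the finished flag under one lock means a token appended just before `finish()` cannot be missed by the consumer.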
### Core Relationships
@ -716,4 +714,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-04-09
> Document Update Time: 2026-05-14
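
The updated `PagedCache.write(layer_id, page_table, position_ids, k, v)` and `gather(layer_id, page_table, total_len)` signatures in the class diagram earlier in this file suggest position-addressed storage over fixed-size pages. The toy below shows that general idea for a single layer, using Python lists instead of tensors. `ToyPagedCache`, `PAGE_SIZE`, and the example page numbers are hypothetical; the real class may differ in layout and behavior.

```python
from typing import List, Tuple

PAGE_SIZE = 4  # hypothetical page size for the sketch


class ToyPagedCache:
    """Toy single-layer paged store; a real cache would hold tensors."""

    def __init__(self, num_pages: int) -> None:
        self.k = [[None] * PAGE_SIZE for _ in range(num_pages)]
        self.v = [[None] * PAGE_SIZE for _ in range(num_pages)]

    def write(self, page_table: List[int], position_ids: List[int],
              k: list, v: list) -> None:
        for pos, k_val, v_val in zip(position_ids, k, v):
            page = page_table[pos // PAGE_SIZE]  # logical page -> physical page
            slot = pos % PAGE_SIZE               # offset inside that page
            self.k[page][slot] = k_val
            self.v[page][slot] = v_val

    def gather(self, page_table: List[int], total_len: int) -> Tuple[list, list]:
        ks, vs = [], []
        for pos in range(total_len):
            page = page_table[pos // PAGE_SIZE]
            slot = pos % PAGE_SIZE
            ks.append(self.k[page][slot])
            vs.append(self.v[page][slot])
        return ks, vs


cache = ToyPagedCache(num_pages=8)
page_table = [5, 2]  # physical pages need not be contiguous or ordered

# Prefill writes positions 0..4, then one decode step writes position 5.
cache.write(page_table, [0, 1, 2, 3, 4], list("abcde"), list("ABCDE"))
cache.write(page_table, [5], ["f"], ["F"])

keys, values = cache.gather(page_table, total_len=6)
```

Passing `total_len` to `gather` lets the caller read back only the filled prefix of the last page rather than the whole page.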

@ -2,7 +2,7 @@
### 1. Model Architecture
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
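
The KV cache saving claimed for GQA can be checked with simple arithmetic. In the sketch below only `q_head=24` and `kv_head=4` come from the text; the layer count, head dimension, sequence length, and bytes per value are placeholder assumptions, and they cancel out of the ratio.

```python
def kv_cache_bytes(num_layers: int, kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_value: int = 2) -> int:
    """Cache size for one sequence; the factor 2 covers the K and V tensors."""
    return 2 * num_layers * kv_heads * head_dim * seq_len * bytes_per_value


# Placeholder values (assumptions, not taken from the document).
NUM_LAYERS, HEAD_DIM, SEQ_LEN = 16, 64, 2048

mha = kv_cache_bytes(NUM_LAYERS, 24, HEAD_DIM, SEQ_LEN)  # MHA: one KV head per query head
gqa = kv_cache_bytes(NUM_LAYERS, 4, HEAD_DIM, SEQ_LEN)   # GQA: 4 shared KV heads
ratio = mha // gqa
```

Whatever the placeholder values, the cache shrinks by the factor `q_head / kv_head`, which is 6 for the head counts given in the text.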
The model now uses the **AutoModel** base class for flexible loading and saving:
@ -24,7 +24,7 @@ flowchart TB
direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer 32]
C --> D[Transformer Block\nLayer ...]
D --> E[RMSNorm]
E --> F[Linear]
F --> G[SoftMax]
@ -331,4 +331,4 @@ curl http://localhost:8000/stats
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
```
> Document Update Time: 2026-04-09
> Document Update Time: 2026-05-14

@ -155,4 +155,4 @@ result = engine.generate(
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-04-09
> Document Update Time: 2026-05-14

@ -26,7 +26,9 @@ def processor(
model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine
engine = InferenceEngine(model=model, tokenizer=tokenizer, max_batch_size=batch_size)
engine = InferenceEngine(
model=model, tokenizer=tokenizer, max_batch_size=batch_size
)
with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]