diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 6f21d0b..6ab7220 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -12,7 +12,7 @@ AstrAI adopts a modular design with the following main components: - **Config Module** (`astrai/config/`): ModelConfig, TrainConfig - **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration - **Parallel Module** (`astrai/parallel/`): Distributed training support -- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management +- **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors ## Data Flow Diagram @@ -59,7 +59,7 @@ flowchart LR ## Detailed Module Descriptions -### 1. Serialization (`astrai/serialization.py`) +### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`) - **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors - **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory @@ -234,4 +234,4 @@ Background thread runs continuously: - **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format. - **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data. -> Document Update Time: 2026-05-09 +> Document Update Time: 2026-05-14 diff --git a/assets/docs/design.md b/assets/docs/design.md index e8283e7..0a15c84 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -138,7 +138,7 @@ classDiagram +ModuleList layers +RMSNorm norm +Linear lm_head - +forward(input_ids, input_mask, paged_cache, start_pos) Dict + +forward(input_ids, input_mask, paged_cache, position_ids) Tensor +load_state_dict(state_dict) +state_dict() } @@ -148,7 +148,7 @@ classDiagram +RMSNorm input_norm +MLP mlp +RMSNorm post_attention_norm - +forward(x, rotary_emb, attention_mask, paged_cache, start_pos) Tensor + +forward(x, rotary_emb, attention_mask, position_ids, paged_cache) Tensor } class GQA { @@ -157,7 +157,7 @@ classDiagram +int head_dim +Linear q_proj, k_proj, v_proj, o_proj +RMSNorm q_norm, k_norm - +forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor + +forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor } class MLA { @@ -170,7 +170,7 @@ classDiagram +Linear q_proj, kv_a_proj, kv_b_proj +Linear o_proj +RMSNorm kv_norm - +forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor + +forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor } class MLP { @@ -194,7 +194,7 @@ classDiagram +int dim +int max_len +float base - +forward(x, start_pos) Tuple[Tensor, Tensor] + +forward(x, position_ids=None) Tuple[Tensor, Tensor] } class Embedding { @@ -417,23 +417,20 @@ classDiagram class PagedCache { +int page_size - +int _free_mask - +List[int] _refs +Tensor k_cache +Tensor v_cache - +alloc() int +alloc_n(n) List[int] +free(idx) +bind(page_table, total_len) CacheView - +write(layer_id, page_table, start_pos, k, v) - +gather(layer_id, page_table) Tuple[Tensor, Tensor] + +write(layer_id, page_table, position_ids, k, v) + +gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor] } class CacheView { +PagedCache _cache +Tensor _page_table +int _total_len - +write(layer_id, start_pos, k, v) + +write(layer_id, position_ids, k, v) +gather(layer_id) Tuple[Tensor, Tensor] } @@ -505,7 +502,7 @@ classDiagram +sample(logits, filter_value) Tensor } - class _Result { + class GenerateResult { +List[str] tokens +List[str] results +List[bool] _done @@ -513,6 +510,7 @@ classDiagram +get_results() List[str] +pop_all() List[str] +wait(timeout) bool + +wait_completion() } class ChatMessage { @@ -590,7 +588,7 @@ classDiagram InferenceScheduler --> PagedCache : uses InferenceScheduler --> Transformer : uses InferenceEngine --> Transformer : uses - InferenceEngine --> _Result : uses + InferenceEngine --> GenerateResult : uses BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopPStrategy @@ -630,8 +628,8 @@ classDiagram | Module | Components | Description | |--------|------------|-------------| | **astrai.config** | ModelConfig, TrainConfig | Configuration management | -| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | -| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management | +| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management | +| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | @@ -654,7 +652,7 @@ classDiagram | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | -| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation | +| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation | ### Core Relationships @@ -716,4 +714,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$ Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence. -> Document Update Time: 2026-04-09 +> Document Update Time: 2026-05-14 diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md index 674080e..e788d30 100644 --- a/assets/docs/introduction.md +++ b/assets/docs/introduction.md @@ -2,7 +2,7 @@ ### 1. Model Architecture -This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token. +This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token. The model now uses the **AutoModel** base class for flexible loading and saving: @@ -24,7 +24,7 @@ flowchart TB direction TB A[Input Embedding] --> B[Transformer Block\nLayer 1] B --> C[Transformer Block\nLayer ...] - C --> D[Transformer Block\nLayer 32] + C --> D[Transformer Block\nLayer ...] D --> E[RMSNorm] E --> F[Linear] F --> G[SoftMax] @@ -331,4 +331,4 @@ curl http://localhost:8000/stats # {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0} ``` -> Document Update Time: 2026-04-09 \ No newline at end of file +> Document Update Time: 2026-05-14 \ No newline at end of file diff --git a/assets/docs/params.md b/assets/docs/params.md index 62fadd7..5a9edfd 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -155,4 +155,4 @@ result = engine.generate( | `stream=True` | Streaming output, yields token by token | | `stream=False` | Non-streaming output, returns complete result | -> Document Update Time: 2026-04-09 \ No newline at end of file +> Document Update Time: 2026-05-14 \ No newline at end of file diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py index bc3cff9..ccb7129 100644 --- a/scripts/tools/generate.py +++ b/scripts/tools/generate.py @@ -26,7 +26,9 @@ def processor( model.to(device="cuda", dtype=torch.bfloat16) # Create inference engine - engine = InferenceEngine(model=model, tokenizer=tokenizer, max_batch_size=batch_size) + engine = InferenceEngine( + model=model, tokenizer=tokenizer, max_batch_size=batch_size + ) with open(input_json_file, "r", encoding="utf-8") as f: input_data = [json.loads(line) for line in f]