diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 7d51dd5..63d5b35 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -2,7 +2,7 @@ name: Bug report about: Create a report to help us improve title: "[BUG]" -labels: enhancement +labels: bug assignees: '' --- diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 9b1bbcf..1de6562 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -16,9 +16,9 @@ Please delete options that are not relevant. Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. ## Checklist: -- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check --fix .`) +- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`) - [ ] I have performed a self-review of my own code -- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] Code is self-documenting (no unnecessary comments) - [ ] I have made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ea86b83..30d5738 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,68 +1,100 @@ # Contributing to AstrAI -Thank you for your interest in contributing to AstrAI! This document provides guidelines and steps for contributing. +Thank you for your interest in contributing! This document provides step-by-step guidelines. -## How to Contribute +## Quick Start -### Reporting Issues -If you encounter a bug or have a feature request, please open an issue on GitHub. Include as much detail as possible: -- A clear description of the problem or request. -- Steps to reproduce (for bugs). -- Your environment (Python version, OS, etc.). +```bash +git clone https://github.com/your-username/AstrAI.git +cd AstrAI +pip install -e ".[dev]" # install with dev dependencies (pytest, ruff) +``` -### Submitting Changes -1. **Fork** the repository. -2. **Clone** your fork: - ```bash - git clone https://github.com/your-username/AstrAI.git - cd AstrAI - ``` -3. **Create a feature branch**: - ```bash - git checkout -b feature/your-feature-name - ``` -4. **Make your changes**. Follow the code style guidelines below. -5. **Commit your changes** with a descriptive commit message: - ```bash - git commit -m "Add: brief description of the change" - ``` -6. **Push** to your fork: - ```bash - git push origin feature/your-feature-name - ``` -7. **Open a Pull Request** (PR) against the `main` branch of the upstream repository. +## Before You Commit -## Code Style +Run the following checks **in order** — CI will reject if any fail. -AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting. +### 1. Format -- Run Ruff to format and lint (requires conda environment `nlp`): - ```bash - conda run -n nlp ruff format . - conda run -n nlp ruff check --fix . - ``` -- The project uses **double quotes** for strings and **4‑space indentation** (as configured in `pyproject.toml`). +```bash +ruff format . +``` -## Testing +> **Note**: `ruff format` may rename parameters (e.g. `mask` → `attn_mask`). +> Always review the diff after formatting. -If you add or modify functionality, please include appropriate tests. +### 2. Import sorting -- Run the test suite with: - ```bash - conda run -n nlp python -u -m pytest - ``` -- Ensure all tests pass before submitting your PR. +```bash +ruff check . --select I +``` + +If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI): + +```bash +ruff check . --select I --fix . +ruff format . # re-format after fix +``` + +### 3. Run tests + +```bash +python -u -m pytest tests/ -v +``` + +> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed. + +### 4. (Optional) Full pre-commit check + +If you have Git Bash available: + +```bash +bash scripts/pre_commit.sh +``` + +This runs format check, import sort check, and tests in one go. + +## Commit Style + +``` +fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars) + +- bullet point body (each ~60 chars) +``` + +- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`. +- **Subject line** ends with no period. +- **Body** uses bullet points starting with `-`. +- No `(scope)` parentheses. + +## Common Issues + +| Problem | Cause | Fix | +|---------|-------|-----| +| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` | +| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging | +| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` | +| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually | + +## Submitting Changes + +1. Fork the repo. +2. Create a feature branch: `git checkout -b feat/my-feature` +3. Make changes following the steps above. +4. Commit with the commit style above. +5. Push: `git push origin feat/my-feature` +6. Open a Pull Request against `main`. ## Code Review -All submissions will be reviewed. We may request changes or discuss alternatives. Please be responsive to feedback. +- All PRs are reviewed. We may request changes. +- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI). +- Ensure all tests pass. ## License -By contributing, you agree that your contributions will be licensed under the same [GPL-3.0 License](LICENSE) that covers the project. +By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE). --- -If you have any questions, feel free to ask in the [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue. - -Happy contributing! \ No newline at end of file +Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue. diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 313e4d2..b5b2f18 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe ## Overview AstrAI adopts a modular design with the following main components: -- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools +- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, storage backends - **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities - **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation @@ -71,7 +71,8 @@ flowchart LR - **`BaseDataset`**: Abstract base class for windowed sequence sampling - **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range - **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`) -- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen_mask"/"rejected_mask"` (DPO), `"masks"` (GRPO) +- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen"/"rejected"` (DPO), `"prompts"/"responses"/"masks"/"rewards"` (GRPO) +- Storage backends: HDF5 (`.h5`) or JSON (`.json`/`.jsonl`), auto-detected by `detect_format()` #### 2.2 Sampler (`sampler.py`) - **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last @@ -85,8 +86,9 @@ flowchart LR - RoPE position encoding, optional weight tying #### 3.2 Submodules (`module.py`) -- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm -- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention) +- **`DecoderBlock`**: Pre-LN (norm→attention→residual, norm→MLP→residual), uses `AttnFactory` / `FFNFactory` +- **`GQA`**: Grouped Query Attention (q_heads ÷ kv_heads = n_rep) +- **`MLA`**: Multi-head Latent Attention with KV compression (kv_lora_rank) - **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection - **`RotaryEmbedding`**: RoPE complex cache (freqs_cis) - **`RMSNorm`**: Layer normalization @@ -107,10 +109,12 @@ on_train_begin for each accumulation window of batches: ← step phase on_step_begin for each batch in window: ← batch phase - on_batch_begin → strategy(batch) → loss → backward → on_batch_end + on_batch_begin → strategy(batch) → loss + (loss / window_size).backward() → on_batch_end iteration += 1 on_step_end optimizer.step() → zero_grad + scheduler.step() ← per step, not per batch on_epoch_end on_train_end @@ -120,7 +124,7 @@ Key points: - `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook - `on_batch_*` fires every batch, wrapping loss computation - `GradientClippingCallback` fires on `on_step_end` -- LR scheduler steps inline (no `SchedulerCallback` class) +- LR scheduler steps inline (no `SchedulerCallback` class), once per optimizer step #### 4.3 Strategy (`strategy.py`) - **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing @@ -167,7 +171,7 @@ Background thread runs continuously: ### 6. Tokenizer Module -- **`AutoTokenizer`**: Wraps HuggingFace tokenizers (BBPE); `encode`/`decode`/`apply_chat_template` +- **`AutoTokenizer`**: Wraps HuggingFace `tokenizers.Tokenizer` (not `transformers`); `encode`/`decode`/`apply_chat_template` - **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat ### 7. Factory & Parallel @@ -195,14 +199,14 @@ Background thread runs continuously: - Computes task-specific loss (cross-entropy, DPO, GRPO) 5. **Backward & Accumulation** - - `loss = raw_loss / accumulation_steps` - - `loss.backward()` accumulates gradients + - `stand_loss = loss / step_batch_nums` (divide by actual batch count in this window) + - `stand_loss.backward()` accumulates gradients - Every `accumulation_steps` batches: `optimizer.step()` → `zero_grad()` - - Every batch: `scheduler.step()` updates learning rate + - Every optimizer step: `scheduler.step()` updates learning rate 6. **Checkpoint** - `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations - - Does NOT save optimizer/scheduler state (resume resets those) + - Does NOT save optimizer/scheduler state by default; `Checkpoint.extra` or `save_extra_fn` can store arbitrary additional data ## Inference Data Flow — Detailed Steps @@ -230,8 +234,8 @@ Background thread runs continuously: ## Checkpoint & Serialization -- **Training Checkpoint**: safetensors weights + epoch/iteration metadata. Optimizer/scheduler state is NOT persisted. +- **Training Checkpoint**: safetensors weights + epoch/iteration metadata + optional extras. Optimizer/scheduler state is NOT persisted by default but can be stored via `extra`. - **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format. - **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data. -> Document Update Time: 2026-05-14 +> Document Update Time: 2026-05-15 diff --git a/assets/docs/design.md b/assets/docs/design.md index 550587a..b9fc068 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -22,6 +22,12 @@ classDiagram +int n_kv_heads +bool use_qk_norm +bool use_gated_attention + +str attn_type + +str ffn_type + +int n_routed_experts + +int n_shared_experts + +int n_activated_experts + +str moe_topk_method +load(config_path) ModelConfig +save(config_path) } @@ -42,7 +48,7 @@ classDiagram +int ckpt_interval +int random_seed +int num_workers - +int prefetch_factor + +Optional[int] prefetch_factor +bool pin_memory +int nprocs +str backend @@ -118,8 +124,8 @@ classDiagram } class ResumableDistributedSampler { - +int epoch - +int iter + +int start_epoch + +int start_iter } class DatasetFactory { @@ -135,6 +141,7 @@ classDiagram +dict state_dict +int epoch +int iteration + +dict extra +save(save_dir) +load(save_dir) Checkpoint } @@ -158,15 +165,15 @@ classDiagram +ModuleList layers +RMSNorm norm +Linear lm_head - +forward(input_ids, input_mask, paged_cache, position_ids) Tensor + +forward(input_ids, input_mask, paged_cache, position_ids) Dict +load_state_dict(state_dict) +state_dict() } class DecoderBlock { - +GQA attention + +nn.Module attention # GQA or MLA via AttnFactory +RMSNorm input_norm - +MLP mlp + +nn.Module mlp # MLP or DeepSeekMoE via FFNFactory +RMSNorm post_attention_norm +forward(x, rotary_emb, attention_mask, paged_cache) Tensor } @@ -175,8 +182,12 @@ classDiagram +int n_heads +int n_kv_heads +int head_dim + +int n_rep + +bool use_qk_norm + +bool use_gated_attention +Linear q_proj, k_proj, v_proj, o_proj - +RMSNorm q_norm, k_norm + +Linear gate # only if use_gated_attention + +RMSNorm q_norm, k_norm # only if use_qk_norm +forward(x, rotary_emb, attn_mask, paged_cache) Tensor } @@ -187,8 +198,11 @@ classDiagram +int kv_lora_rank +int qk_nope_head_dim +int qk_rope_head_dim + +int n_rep + +bool use_gated_attention +Linear q_proj, kv_a_proj, kv_b_proj +Linear o_proj + +Linear gate # only if use_gated_attention +RMSNorm kv_norm +forward(x, rotary_emb, attn_mask, paged_cache) Tensor } @@ -198,6 +212,25 @@ classDiagram +forward(x) Tensor } + class DeepSeekMoE { + +int n_routed_experts + +int n_shared_experts + +int n_activated_experts + +str topk_method + +Linear router + +ModuleList shared_experts + +ModuleList routed_experts + +forward(x) Tensor + } + + class AttnFactory { + +create(attn_type, **kwargs) nn.Module + } + + class FFNFactory { + +create(ffn_type, dim, dim_ffn, **kwargs) nn.Module + } + class RMSNorm { +Parameter weight +float norm_eps @@ -206,7 +239,7 @@ classDiagram class Linear { +Parameter weight - +Parameter bias + +Optional[Parameter] bias # only if bias=True +forward(x) Tensor } @@ -365,7 +398,7 @@ classDiagram class GradientClippingCallback { +float max_grad_norm - +on_step_begin(context) + +on_step_end(context) } class CheckpointCallback { @@ -410,15 +443,24 @@ classDiagram +shutdown() } - class InferenceScheduler { - +nn.Module model + class Executor { + +AutoModel model +AutoTokenizer tokenizer + +KVCache page_cache + +execute_prefill(tasks, prompt_len, start_pos) + +execute_decode(tasks) List[int] + } + + class InferenceScheduler { +KVCache _page_cache + +Executor _executor + +TaskManager _task_mgr + +bool _running + +Thread _loop_thread +int max_batch_size +int max_seq_len +int max_prompt_len +int page_size - +TaskManager _task_mgr +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +remove_task(task_id) +start() @@ -428,8 +470,8 @@ classDiagram class Allocator { +int _free_mask - +int refs_count - +LRU _lru + +List[int] _refs + +OrderedDict _lru +alloc() int +free(idx, keep_cached) +inc_ref(idx) @@ -564,9 +606,9 @@ classDiagram +List[bool] _done +append(token, idx) +get_results() List[str] - +pop_all() List[str] + +pop_all() List[Tuple[int, str]] +wait(timeout) bool - +wait_completion() + +wait_completion(timeout) } class ChatMessage { @@ -584,6 +626,65 @@ classDiagram +Optional[str] stop +Optional[int] n } + + class AnthropicMessage { + +str role + +Union[str, List[Dict]] content + } + + class MessagesRequest { + +List[AnthropicMessage] messages + +Optional[str] system + +float temperature + +float top_p + +int top_k + +int max_tokens + +bool stream + +Optional[List[str]] stop_sequences + } + + class ProtocolHandler { + <> + +build_prompt() str + +create_response_id() str + +format_stream_start(ctx) List[str] + +format_stream_token(ctx, token) str + +format_stream_end(ctx) List[str] + +format_non_stream_response(ctx, content) Dict + +handle() Union[StreamingResponse, Dict] + } + + class OpenAIHandler { + +build_prompt() str + +create_response_id() str + } + + class AnthropicHandler { + +List[str] stop_sequences + +build_prompt() str + +create_response_id() str + +on_token(ctx, token, stop_checker) Optional[str] + } + + class StopChecker { + +check(text) Optional[str] + +trim(text, matched) str + } + + class StreamContext { + +str resp_id + +int created + +str model + +int prompt_tokens + +int completion_tokens + +str accumulated + +Optional[str] stop_matched + } + + class app { + <> + +FastAPI app + } } namespace parallel { @@ -610,79 +711,104 @@ classDiagram } } - %% Relationships - TrainConfig --> BaseDataset : uses - TrainConfig ..> BaseStrategy : selects - StrategyFactory ..> BaseStrategy : creates + %% Relationships — UML notation: <|-- generalization, *-- composition, o-- aggregation, --> association, ..> dependency + + %% --- Generalization (inheritance) --- BaseStrategy <|-- SEQStrategy BaseStrategy <|-- SFTStrategy BaseStrategy <|-- DPOStrategy BaseStrategy <|-- GRPOStrategy - DPOStrategy --> Transformer : uses - GRPOStrategy --> Transformer : uses - Trainer --> TrainConfig : uses - Trainer --> TrainContextBuilder : uses - Trainer --> TrainCallback : manages - TrainContextBuilder --> TrainContext : creates - TrainContextBuilder --> StrategyFactory : uses - Checkpoint ..> Checkpoint : serializes - TrainContext --> Checkpoint : manages - TrainContext --> BaseStrategy : uses - TrainContext --> BaseScheduler : uses - SchedulerFactory ..> BaseScheduler : creates BaseScheduler <|-- CosineScheduler BaseScheduler <|-- SGDRScheduler - CallbackFactory ..> TrainCallback : creates TrainCallback <|-- GradientClippingCallback TrainCallback <|-- CheckpointCallback TrainCallback <|-- ProgressBarCallback TrainCallback <|-- MetricLoggerCallback - PagePool --> Allocator : composes - PagePool --> PrefixCache : composes - KVCache --> PagePool : composes - KVCache --> Storage : composes - KVCache --> TaskTable : composes - KvcacheView --> Storage : wraps - InferenceEngine --> InferenceScheduler : uses - InferenceEngine --> GenerationRequest : uses - InferenceEngine --> GenerateResult : creates - InferenceScheduler --> Task : manages - InferenceScheduler --> TaskStatus : uses - InferenceScheduler --> KVCache : uses - InferenceScheduler --> Transformer : uses - Task --> TaskStatus : uses - InferenceEngine --> Transformer : uses - BaseSamplingStrategy <|-- TemperatureStrategy - BaseSamplingStrategy <|-- TopKStrategy - BaseSamplingStrategy <|-- TopPStrategy - SamplingPipeline --> BaseSamplingStrategy : composes BaseDataset <|-- SEQDataset BaseDataset <|-- SFTDataset BaseDataset <|-- DPODataset BaseDataset <|-- GRPODataset - DatasetFactory ..> BaseDataset : creates BaseStorage <|-- H5Storage BaseStorage <|-- JSONStorage - BaseDataset --> BaseStorage : uses - MultiSegmentFetcher --> BaseSegmentFetcher : uses - AutoModel <|-- Transformer - AutoModel --> ModelConfig : contains - Transformer --> DecoderBlock : uses - Transformer --> RotaryEmbedding : uses - Transformer --> Embedding : uses - DecoderBlock --> GQA : uses - DecoderBlock --> MLP : uses - DecoderBlock --> RMSNorm : uses - TrainContextBuilder --> ResumableDistributedSampler : creates - ResumableDistributedSampler --> BaseDataset : samples + BaseSamplingStrategy <|-- TemperatureStrategy + BaseSamplingStrategy <|-- TopKStrategy + BaseSamplingStrategy <|-- TopPStrategy ParallelModel <|-- RowParallelLinear ParallelModel <|-- ColumnParallelLinear - AutoTokenizer --> ChatTemplate : uses + AutoModel <|-- Transformer BaseFactory <|-- AutoModel + BaseFactory <|-- AttnFactory + BaseFactory <|-- FFNFactory BaseFactory <|-- DatasetFactory BaseFactory <|-- StrategyFactory BaseFactory <|-- SchedulerFactory BaseFactory <|-- CallbackFactory + ProtocolHandler <|-- OpenAIHandler + ProtocolHandler <|-- AnthropicHandler + + %% --- Composition (strong ownership, part destroyed with whole) --- + KVCache *-- PagePool + KVCache *-- Storage + KVCache *-- TaskTable + KVCache *-- Allocator + KVCache *-- PrefixCache + InferenceEngine *-- InferenceScheduler + InferenceScheduler *-- KVCache + InferenceScheduler *-- Executor + InferenceScheduler *-- TaskManager + SamplingPipeline *-- BaseSamplingStrategy + TrainContextBuilder *-- TrainContext + Transformer *-- DecoderBlock + Transformer *-- RotaryEmbedding + Transformer *-- Embedding + DecoderBlock *-- RMSNorm + BaseDataset *-- BaseStorage + + %% --- Aggregation (weak ownership) --- + AutoModel o-- ModelConfig + Trainer o-- TrainCallback + TrainContext o-- BaseStrategy + TrainContext o-- BaseScheduler + TrainContext o-- Checkpoint + AutoTokenizer o-- ChatTemplate + KvcacheView o-- Storage + BaseFactory o-- Registry + + %% --- Dependency (uses temporarily) --- + TrainConfig ..> BaseStrategy : selects + StrategyFactory ..> BaseStrategy : creates + SchedulerFactory ..> BaseScheduler : creates + DatasetFactory ..> BaseDataset : creates + CallbackFactory ..> TrainCallback : creates + AttnFactory ..> GQA : creates + AttnFactory ..> MLA : creates + FFNFactory ..> MLP : creates + FFNFactory ..> DeepSeekMoE : creates + DecoderBlock ..> AttnFactory : uses + DecoderBlock ..> FFNFactory : uses + Trainer ..> TrainContextBuilder : uses + Trainer ..> Functions : spawns + TrainContextBuilder ..> StrategyFactory : uses + TrainContextBuilder ..> ResumableDistributedSampler : creates + Checkpoint ..> Checkpoint : serializes + CheckpointCallback ..> Checkpoint : creates + KVCache ..> KvcacheView : binds + InferenceEngine ..> GenerationRequest : uses + InferenceEngine ..> GenerateResult : creates + + %% --- Association (general usage) --- + Trainer --> TrainConfig + DPOStrategy --> Transformer + GRPOStrategy --> Transformer + InferenceScheduler --> Task + InferenceScheduler --> TaskStatus + Task --> TaskStatus + InferenceEngine --> Transformer + Executor --> Transformer + Executor --> AutoTokenizer + TaskManager --> AutoTokenizer + MultiSegmentFetcher --> BaseSegmentFetcher + ResumableDistributedSampler --> BaseDataset ``` ### Module Overview @@ -690,14 +816,14 @@ classDiagram | Module | Components | Description | |--------|------------|-------------| | **astrai.config** | ModelConfig, TrainConfig | Configuration management | -| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management | +| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5, save_json, load_json, create_storage, detect_format | Dataset loading and management | | **astrai.serialization** | Checkpoint | Model serialization and checkpoint management | -| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | +| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | -| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache | -| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel | -| **astrai.factory** | Registry, BaseFactory | Generic component registration | +| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample, ChatMessage, ChatCompletionRequest, AnthropicMessage, MessagesRequest, OpenAIHandler, AnthropicHandler, ProtocolHandler, StreamContext, StopChecker, app, run_server | Inference service with continuous batching and paged KV cache | +| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, only_on_rank, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel | +| **astrai.factory** | Registry, BaseFactory[T] | Generic component registration with decorator pattern | ### Design Patterns @@ -706,7 +832,7 @@ classDiagram | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO | | **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components | | **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks | -| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | +| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, gradient clipping, metrics) | | **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction | @@ -715,6 +841,8 @@ classDiagram | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | | **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation | +| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | `handle()` template with stream/non-stream branches, protocol-specific format hooks | +| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage`, `_STORAGE_REGISTRY` | Format-agnostic data access with registry-dispatch (HDF5 / JSON) | ### Core Relationships @@ -723,7 +851,7 @@ classDiagram 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` 4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming 5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer` -6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` +6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 and JSON loading via `BaseStorage` (`H5Storage` / `JSONStorage`) with `BaseSegmentFetcher` and `MultiSegmentFetcher` 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors 8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler) 9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration @@ -776,4 +904,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$ Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence. -> Document Update Time: 2026-05-14 +> Document Update Time: 2026-05-15 diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md index fde784f..82a570a 100644 --- a/assets/docs/introduction.md +++ b/assets/docs/introduction.md @@ -224,7 +224,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \ | `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) | | `top_p` | float | 1.0 | Nucleus sampling threshold | | `top_k` | int | 50 | Top-k sampling parameter | -| `max_tokens` | int | 1024 | Maximum tokens to generate | +| `max_tokens` | int | 2048 | Maximum tokens to generate | | `stream` | bool | false | Enable streaming response | **Response (non-streaming):** @@ -331,4 +331,4 @@ curl http://localhost:8000/stats # {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0} ``` -> Document Update Time: 2026-05-14 \ No newline at end of file +> Document Update Time: 2026-05-15 \ No newline at end of file diff --git a/assets/docs/params.md b/assets/docs/params.md index 7643f60..2b09de5 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -60,7 +60,7 @@ | Parameter | Description | Default | Used by | |-----------|-------------|---------|---------| | `--dpo_beta` | DPO beta value | 0.1 | `dpo` | -| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` | +| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 (CLI) / 0.0 (strategy default) | `seq`, `sft` | | `--group_size` | GRPO group size | 4 | `grpo` | | `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` | | `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` | @@ -98,7 +98,7 @@ python scripts/tools/train.py \ | `temperature` | Sampling temperature (higher = more random) | 1.0 | | `top_p` | Nucleus sampling threshold | 1.0 | | `top_k` | Top-k sampling count | 50 | -| `max_tokens` | Maximum generation length | None (unlimited) | +| `max_tokens` | Maximum generation length | None (defaults to max_seq_len - prompt_len) | | `stream` | Whether to stream output | False | ### Usage Example @@ -155,4 +155,4 @@ result = engine.generate( | `stream=True` | Streaming output, yields token by token | | `stream=False` | Non-streaming output, returns complete result | -> Document Update Time: 2026-05-14 \ No newline at end of file +> Document Update Time: 2026-05-15 \ No newline at end of file