diff --git a/README.md b/README.md index b19041d..b049505 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,6 @@ ## 📖 Table of Contents -
-English - - [Features](#features) - [Quick Start](#quick-start) - [Documentation](#documentation) @@ -37,8 +34,6 @@ - [Community](#community) - [License](#license) -
- --- @@ -75,7 +70,14 @@ pip install -e ".[dev]" python scripts/tools/train.py \ --train_type=seq \ --data_root_path=/path/to/dataset \ - --param_path=/path/to/param_path + --param_path=/path/to/model \ + --n_epoch=3 \ + --batch_size=4 \ + --accumulation_steps=8 \ + --max_lr=3e-4 \ + --warmup_steps=2000 \ + --ckpt_interval=5000 \ + --ckpt_dir=./checkpoints ``` #### Generate Text @@ -84,6 +86,25 @@ python scripts/tools/train.py \ python scripts/tools/generate.py --param_path=/path/to/param_path ``` +#### Training Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required | +| `--data_root_path` | Dataset root directory | required | +| `--param_path` | Model / checkpoint path | required | +| `--n_epoch` | Training epochs | 1 | +| `--batch_size` | Batch size | 1 | +| `--accumulation_steps` | Gradient accumulation steps | 1 | +| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 | +| `--warmup_steps` | LR warmup steps | 1000 | +| `--ckpt_interval` | Checkpoint interval (iters) | 5000 | +| `--ckpt_dir` | Checkpoint directory | checkpoint | +| `--num_workers` | DataLoader workers | 4 | +| `--nprocs` | Number of GPUs | 1 | + +Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters). + #### Docker Build and run with Docker (recommended for GPU environments): diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index 72e3d08..ff4dc99 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -76,7 +76,14 @@ pip install -e ".[dev]" python scripts/tools/train.py \ --train_type=seq \ --data_root_path=/path/to/dataset \ - --param_path=/path/to/param_path + --param_path=/path/to/model \ + --n_epoch=3 \ + --batch_size=4 \ + --accumulation_steps=8 \ + --max_lr=3e-4 \ + --warmup_steps=2000 \ + --ckpt_interval=5000 \ + --ckpt_dir=./checkpoints ``` #### 文本生成 @@ -85,6 +92,25 @@ python scripts/tools/train.py \ python scripts/tools/generate.py --param_path=/path/to/param_path ``` +#### 训练参数 + +| 参数 | 说明 | 默认值 | +|------|------|--------| +| `--train_type` | 训练类型(`seq`, `sft`, `dpo`) | 必填 | +| `--data_root_path` | 数据集根目录 | 必填 | +| `--param_path` | 模型参数或断点路径 | 必填 | +| `--n_epoch` | 训练轮数 | 1 | +| `--batch_size` | 批次大小 | 1 | +| `--accumulation_steps` | 梯度累积步数 | 1 | +| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 | +| `--warmup_steps` | 预热步数 | 1000 | +| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 | +| `--ckpt_dir` | 检查点保存目录 | checkpoint | +| `--num_workers` | 数据加载线程数 | 4 | +| `--nprocs` | GPU 数量 | 1 | + +完整参数列表见[参数说明](./params.md#training-parameters)。 + #### Docker 使用 Docker 构建和运行(推荐用于 GPU 环境): diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index f26b970..0937d96 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -176,7 +176,7 @@ flowchart LR - **`TaskStatus`**: Task state enumeration - **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits - **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse -- **`RadixNode`**: Tree node structure for prefix caching +- **`_RadixNode`**: Tree node structure for prefix caching - Continuous batching: new requests can join at any time, completed requests are released immediately #### 6.3 Server (`server.py`) diff --git a/assets/docs/design.md b/assets/docs/design.md index 7817e8b..67e9f2f 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -85,8 +85,8 @@ classDiagram } class BaseSegmentFetcher { - +List~Tensor~ segments - +List~int~ cum_lengths + +List[Tensor] segments + +List[int] cum_lengths +int total_length +fetch_data(begin_idx, end_idx) Tensor } @@ -191,7 +191,7 @@ classDiagram +int dim +int max_len +float base - +forward(x, start_pos) Tuple~Tensor, Tensor~ + +forward(x, start_pos) Tuple[Tensor, Tensor] } class Embedding { @@ -202,14 +202,14 @@ classDiagram namespace tokenize { class AutoTokenizer { - +List~str~ stop_ids + +List[str] stop_ids +int bos_id +int eos_id +int pad_id +vocab_size int - +encode(tokens, out_ids, add_special_tokens) List~int~ + +encode(tokens, out_ids, add_special_tokens) List[int] +decode(tokens, skip_special_tokens) str - +apply_chat_template(messages, tokenize) Union~str, List[int]~ + +apply_chat_template(messages, tokenize) Union[str, List[int]] +set_chat_template(template) +load(path) +from_pretrained(path) AutoTokenizer @@ -228,7 +228,7 @@ classDiagram +Dict _entries +register(name, component_cls, category, priority) +get(name) Type - +list_names() List~str~ + +list_names() List[str] } class BaseFactory { @@ -242,10 +242,10 @@ classDiagram namespace trainer { class Trainer { +TrainConfig train_config - +List~TrainCallback~ callbacks + +List[TrainCallback] callbacks +train(checkpoint) +_build_context(checkpoint) TrainContext - +_get_default_callbacks() List~TrainCallback~ + +_get_default_callbacks() List[TrainCallback] } class TrainContext { @@ -308,7 +308,7 @@ classDiagram } class BaseScheduler { - +get_lr() List~float~ + +get_lr() List[float] +step() } @@ -390,10 +390,8 @@ classDiagram +InferenceScheduler scheduler +int max_batch_size +Optional int max_seq_len - +int max_prefix_len + +int max_prompt_len +int cache_capacity - +Tensor kv_cache - +Tensor seq_mask +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]] +get_stats() Dict @@ -403,10 +401,10 @@ classDiagram class InferenceScheduler { +nn.Module model +AutoTokenizer tokenizer - +ModelConfig config +Tuple kv_cache +Tensor seq_mask +PrefixCacheManager prefix_cache + +SlotAllocator slot_allocator +List waiting_queue +List active_tasks +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str @@ -417,21 +415,22 @@ classDiagram } class PrefixCacheManager { - +RadixNode root + +_RadixNode root +int max_capacity - +List lru - +insert(token_ids, slot) - +find_longest_prefix(token_ids) Tuple[int, int] + +OrderedDict _lru + +insert(token_ids, slot, slot_ver) + +find(token_ids) Tuple[int, int, int] + +pin(token_ids) +release(token_ids) + +copy_kv(token_ids, target_slot, kv_cache, n_layers) } - class RadixNode { + class _RadixNode { +Dict children - +int hash +int slot + +int slot_ver +int ref_count +float last_access - +List token_sequence } class Task { @@ -446,15 +445,69 @@ classDiagram +int input_tokens +int output_tokens +int slot + +int prefix_len + +float arrival_time + +float finish_time +Callable stream_callback +is_finished(stop_ids) bool } class TaskStatus { - +str PENDING - +str RUNNING - +str FINISHED - +str ABORTED + <> + PENDING + RUNNING + FINISHED + ABORTED + } + + class GenerationRequest { + +List[Dict] messages + +GenerationParams params + +bool stream + } + + class GenerationParams { + <> + +int top_k + +float top_p + +float temperature + +int max_tokens + } + + class SlotAllocator { + +int _max_slots + +int _free_mask + +List _versions + +alloc() int + +free(idx) + +occupy(idx) + +is_free(idx) bool + +version(idx) int + } + + class BaseSamplingStrategy { + <> + +apply(logits, filter_value) Tensor + } + + class TemperatureStrategy { + +float temperature + +apply(logits, filter_value) Tensor + } + + class TopKStrategy { + +int top_k + +apply(logits, filter_value) Tensor + } + + class TopPStrategy { + +float top_p + +apply(logits, filter_value) Tensor + } + + class SamplingPipeline { + +List strategies + +apply(logits, filter_value) Tensor } class Server { @@ -462,21 +515,12 @@ classDiagram +predict(request) } - class GenerationRequest { - +int top_k - +float top_p - +float temperature - +int max_len - +List~Dict~ messages - +stream bool - } - class _Result { - +List~str~ tokens - +List~str~ results - +List~bool~ done_flags + +List[str] tokens + +List[str] results + +List[bool] done_flags +append(token, idx) - +get_results() List~str~ + +get_results() List[str] } class ChatMessage { @@ -485,13 +529,13 @@ classDiagram } class ChatCompletionRequest { - +List~ChatMessage~ messages + +List[ChatMessage] messages +float temperature +float top_p +int top_k +int max_tokens +bool stream - +Optional~str~ system_prompt + +Optional[str] system_prompt } class CompletionResponse { @@ -499,7 +543,7 @@ classDiagram +str object +int created +str model - +List~Dict~ choices + +List[Dict] choices } } @@ -542,7 +586,6 @@ classDiagram TrainContext --> Checkpoint : manages TrainContext --> BaseStrategy : uses TrainContext --> BaseScheduler : uses - AutoModel --> ModelConfig : contains SchedulerFactory ..> BaseScheduler : creates BaseScheduler <|-- CosineScheduler BaseScheduler <|-- SGDRScheduler @@ -553,11 +596,19 @@ classDiagram TrainCallback <|-- ProgressBarCallback TrainCallback <|-- MetricLoggerCallback InferenceEngine --> InferenceScheduler : uses + InferenceEngine --> GenerationRequest : uses + GenerationRequest --> GenerationParams : contains InferenceScheduler --> Task : manages + Task --> TaskStatus : uses InferenceScheduler --> TaskStatus : uses + InferenceScheduler --> SlotAllocator : uses InferenceScheduler --> Transformer : uses InferenceEngine --> Transformer : uses - InferenceEngine --> GenerationRequest : uses + InferenceEngine --> _Result : uses + BaseSamplingStrategy <|-- TemperatureStrategy + BaseSamplingStrategy <|-- TopKStrategy + BaseSamplingStrategy <|-- TopPStrategy + SamplingPipeline --> BaseSamplingStrategy : composes Server --> InferenceEngine : uses Server --> ChatMessage : uses Server --> ChatCompletionRequest : uses @@ -585,7 +636,7 @@ classDiagram ParallelModel <|-- ColumnParallelLinear AutoTokenizer --> ChatTemplate : uses InferenceScheduler --> PrefixCacheManager : uses - InferenceScheduler --> RadixNode : uses + PrefixCacheManager --> _RadixNode : composes Checkpoint ..> Checkpoint : saves/loads TrainConfig --> DatasetFactory : selects TrainConfig --> SchedulerFactory : selects @@ -606,7 +657,7 @@ classDiagram | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | -| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching | +| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching | | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel | | **astrai.factory** | Registry, BaseFactory | Generic component registration | @@ -620,6 +671,8 @@ classDiagram | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Singleton** | `TrainContext` | Training process global state management | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | +| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask | +| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | @@ -630,7 +683,7 @@ classDiagram 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` -4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming +4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer` 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors diff --git a/assets/docs/params.md b/assets/docs/params.md index 0ad17c8..26e772a 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -4,70 +4,83 @@ ### Basic Parameters -| Parameter | Description | Default Value | -|-----------|-------------|---------------| -| `--train_type` | Training type (seq, sft, dpo, grpo) | required | -| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer | +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required | | `--data_root_path` | Dataset root directory | required | | `--param_path` | Model parameters or checkpoint path | required | | `--n_epoch` | Total training epochs | 1 | -| `--batch_size` | Batch size | 4 | -| `--accumulation_steps` | Gradient accumulation steps | 1 | +| `--batch_size` | Batch size | 1 | +| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 | ### Learning Rate Scheduling -| Parameter | Description | Default Value | -|-----------|-------------|---------------| +| Parameter | Description | Default | +|-----------|-------------|---------| | `--warmup_steps` | Warmup steps | 1000 | -| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 | -| `--max_grad_norm` | Maximum gradient norm | 1.0 | +| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 | +| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 | -### Checkpoint +### Optimizer (AdamW) -| Parameter | Description | Default Value | -|-----------|-------------|---------------| -| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 | -| `--ckpt_dir` | Checkpoint save directory | checkpoint | -| `--resume_dir` | Resume training from specified path | - | - -### Optimizer Parameters - -| Parameter | Description | Default Value | -|-----------|-------------|---------------| +| Parameter | Description | Default | +|-----------|-------------|---------| | `--adamw_beta1` | AdamW beta1 | 0.9 | | `--adamw_beta2` | AdamW beta2 | 0.95 | | `--adamw_weight_decay` | AdamW weight decay | 0.01 | ### Data Loading -| Parameter | Description | Default Value | -|-----------|-------------|---------------| -| `--random_seed` | Random seed | 3407 | -| `--num_workers` | DataLoader workers | 0 | -| `--prefetch_factor` | Prefetch factor for dataloader | None | -| `--pin_memory` | Enable pin_memory | False | -| `--no_pin_memory` | Disable pin_memory | - | +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--window_size` | Max input sequence length | model config `max_len` | +| `--stride` | Stride for sliding window over sequences | None | +| `--random_seed` | Random seed for reproducibility | 3407 | +| `--num_workers` | DataLoader worker processes | 4 | +| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) | + +### Checkpoint & Resume + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--ckpt_interval` | Iterations between checkpoints | 5000 | +| `--ckpt_dir` | Checkpoint save directory | checkpoint | +| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 | +| `--start_batch` | Resume from batch iteration | 0 | ### Distributed Training -| Parameter | Description | Default Value | -|-----------|-------------|---------------| -| `--nprocs` | Number of GPUs | 1 | -| `--device_type` | Device type (cuda/cpu) | cuda | +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--nprocs` | Number of GPUs / processes | 1 | +| `--device_type` | Device type | cuda | -### Other Parameters +### Strategy-specific -| Parameter | Description | Default Value | -|-----------|-------------|---------------| -| `--window_size` | Maximum input sequence length | model config max_len | -| `--stride` | Input sequence stride | - | -| `--dpo_beta` | DPO beta value | 0.1 | -| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 | -| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 | -| `--grpo_group_size` | GRPO group size | 4 | -| `--label_smoothing` | Label smoothing parameter | 0.1 | -| `--start_epoch` | Starting epoch | 0 | -| `--start_batch` | Starting batch | 0 | +| Parameter | Description | Default | Used by | +|-----------|-------------|---------|---------| +| `--dpo_beta` | DPO beta value | 0.1 | `dpo` | +| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` | + +### Usage Example + +```bash +python scripts/tools/train.py \ + --train_type seq \ + --data_root_path /path/to/dataset \ + --param_path /path/to/model \ + --n_epoch 3 \ + --batch_size 4 \ + --accumulation_steps 8 \ + --max_lr 3e-4 \ + --warmup_steps 2000 \ + --max_grad_norm 1.0 \ + --ckpt_interval 5000 \ + --ckpt_dir ./checkpoints \ + --num_workers 4 \ + --nprocs 1 \ + --device_type cuda +``` --- @@ -89,14 +102,14 @@ ```python import torch from astrai.model import AutoModel -from astrai.tokenize import Tokenizer +from astrai.tokenize import AutoTokenizer from astrai.inference import InferenceEngine, GenerationRequest # Load model using AutoModel model = AutoModel.from_pretrained("your_model_dir") # Load tokenizer -tokenizer = Tokenizer("your_model_dir") +tokenizer = AutoTokenizer.from_pretrained("your_model_dir") # Create engine with separate model and tokenizer engine = InferenceEngine(