docs: 修正文档错误并补充训练参数说明
- README: 补充训练参数速查表,完善训练命令示例 - design.md: 同步 inference 类图(SlotAllocator、GenerationParams、采样策略等 新增类),修正参数名和类型错误,统一泛型符号 - params.md: 修正默认值(batch_size=1、num_workers=4),移除不存在参数 (grpo_*、model_type、resume_dir),补充完整示例 - dataflow.md: _RadixNode 命名修正
This commit is contained in:
parent
44d7a4e959
commit
78dc2bd41c
33
README.md
33
README.md
|
|
@ -27,9 +27,6 @@
|
||||||
|
|
||||||
## 📖 Table of Contents
|
## 📖 Table of Contents
|
||||||
|
|
||||||
<details open>
|
|
||||||
<summary><b>English</b></summary>
|
|
||||||
|
|
||||||
- [Features](#features)
|
- [Features](#features)
|
||||||
- [Quick Start](#quick-start)
|
- [Quick Start](#quick-start)
|
||||||
- [Documentation](#documentation)
|
- [Documentation](#documentation)
|
||||||
|
|
@ -37,8 +34,6 @@
|
||||||
- [Community](#community)
|
- [Community](#community)
|
||||||
- [License](#license)
|
- [License](#license)
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
<a id="english"></a>
|
<a id="english"></a>
|
||||||
|
|
@ -75,7 +70,14 @@ pip install -e ".[dev]"
|
||||||
python scripts/tools/train.py \
|
python scripts/tools/train.py \
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/param_path
|
--param_path=/path/to/model \
|
||||||
|
--n_epoch=3 \
|
||||||
|
--batch_size=4 \
|
||||||
|
--accumulation_steps=8 \
|
||||||
|
--max_lr=3e-4 \
|
||||||
|
--warmup_steps=2000 \
|
||||||
|
--ckpt_interval=5000 \
|
||||||
|
--ckpt_dir=./checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Generate Text
|
#### Generate Text
|
||||||
|
|
@ -84,6 +86,25 @@ python scripts/tools/train.py \
|
||||||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
python scripts/tools/generate.py --param_path=/path/to/param_path
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Training Parameters
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
|
||||||
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
|
| `--param_path` | Model / checkpoint path | required |
|
||||||
|
| `--n_epoch` | Training epochs | 1 |
|
||||||
|
| `--batch_size` | Batch size | 1 |
|
||||||
|
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
||||||
|
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
|
||||||
|
| `--warmup_steps` | LR warmup steps | 1000 |
|
||||||
|
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
|
||||||
|
| `--ckpt_dir` | Checkpoint directory | checkpoint |
|
||||||
|
| `--num_workers` | DataLoader workers | 4 |
|
||||||
|
| `--nprocs` | Number of GPUs | 1 |
|
||||||
|
|
||||||
|
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
||||||
Build and run with Docker (recommended for GPU environments):
|
Build and run with Docker (recommended for GPU environments):
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,14 @@ pip install -e ".[dev]"
|
||||||
python scripts/tools/train.py \
|
python scripts/tools/train.py \
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/param_path
|
--param_path=/path/to/model \
|
||||||
|
--n_epoch=3 \
|
||||||
|
--batch_size=4 \
|
||||||
|
--accumulation_steps=8 \
|
||||||
|
--max_lr=3e-4 \
|
||||||
|
--warmup_steps=2000 \
|
||||||
|
--ckpt_interval=5000 \
|
||||||
|
--ckpt_dir=./checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 文本生成
|
#### 文本生成
|
||||||
|
|
@ -85,6 +92,25 @@ python scripts/tools/train.py \
|
||||||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
python scripts/tools/generate.py --param_path=/path/to/param_path
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 训练参数
|
||||||
|
|
||||||
|
| 参数 | 说明 | 默认值 |
|
||||||
|
|------|------|--------|
|
||||||
|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`) | 必填 |
|
||||||
|
| `--data_root_path` | 数据集根目录 | 必填 |
|
||||||
|
| `--param_path` | 模型参数或断点路径 | 必填 |
|
||||||
|
| `--n_epoch` | 训练轮数 | 1 |
|
||||||
|
| `--batch_size` | 批次大小 | 1 |
|
||||||
|
| `--accumulation_steps` | 梯度累积步数 | 1 |
|
||||||
|
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
|
||||||
|
| `--warmup_steps` | 预热步数 | 1000 |
|
||||||
|
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
|
||||||
|
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
|
||||||
|
| `--num_workers` | 数据加载线程数 | 4 |
|
||||||
|
| `--nprocs` | GPU 数量 | 1 |
|
||||||
|
|
||||||
|
完整参数列表见[参数说明](./params.md#training-parameters)。
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
||||||
使用 Docker 构建和运行(推荐用于 GPU 环境):
|
使用 Docker 构建和运行(推荐用于 GPU 环境):
|
||||||
|
|
|
||||||
|
|
@ -176,7 +176,7 @@ flowchart LR
|
||||||
- **`TaskStatus`**: Task state enumeration
|
- **`TaskStatus`**: Task state enumeration
|
||||||
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
|
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
|
||||||
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
|
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
|
||||||
- **`RadixNode`**: Tree node structure for prefix caching
|
- **`_RadixNode`**: Tree node structure for prefix caching
|
||||||
- Continuous batching: new requests can join at any time, completed requests are released immediately
|
- Continuous batching: new requests can join at any time, completed requests are released immediately
|
||||||
|
|
||||||
#### 6.3 Server (`server.py`)
|
#### 6.3 Server (`server.py`)
|
||||||
|
|
|
||||||
|
|
@ -85,8 +85,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseSegmentFetcher {
|
class BaseSegmentFetcher {
|
||||||
+List~Tensor~ segments
|
+List[Tensor] segments
|
||||||
+List~int~ cum_lengths
|
+List[int] cum_lengths
|
||||||
+int total_length
|
+int total_length
|
||||||
+fetch_data(begin_idx, end_idx) Tensor
|
+fetch_data(begin_idx, end_idx) Tensor
|
||||||
}
|
}
|
||||||
|
|
@ -191,7 +191,7 @@ classDiagram
|
||||||
+int dim
|
+int dim
|
||||||
+int max_len
|
+int max_len
|
||||||
+float base
|
+float base
|
||||||
+forward(x, start_pos) Tuple~Tensor, Tensor~
|
+forward(x, start_pos) Tuple[Tensor, Tensor]
|
||||||
}
|
}
|
||||||
|
|
||||||
class Embedding {
|
class Embedding {
|
||||||
|
|
@ -202,14 +202,14 @@ classDiagram
|
||||||
|
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+List~str~ stop_ids
|
+List[str] stop_ids
|
||||||
+int bos_id
|
+int bos_id
|
||||||
+int eos_id
|
+int eos_id
|
||||||
+int pad_id
|
+int pad_id
|
||||||
+vocab_size int
|
+vocab_size int
|
||||||
+encode(tokens, out_ids, add_special_tokens) List~int~
|
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||||
+decode(tokens, skip_special_tokens) str
|
+decode(tokens, skip_special_tokens) str
|
||||||
+apply_chat_template(messages, tokenize) Union~str, List[int]~
|
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||||
+set_chat_template(template)
|
+set_chat_template(template)
|
||||||
+load(path)
|
+load(path)
|
||||||
+from_pretrained(path) AutoTokenizer
|
+from_pretrained(path) AutoTokenizer
|
||||||
|
|
@ -228,7 +228,7 @@ classDiagram
|
||||||
+Dict _entries
|
+Dict _entries
|
||||||
+register(name, component_cls, category, priority)
|
+register(name, component_cls, category, priority)
|
||||||
+get(name) Type
|
+get(name) Type
|
||||||
+list_names() List~str~
|
+list_names() List[str]
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseFactory {
|
class BaseFactory {
|
||||||
|
|
@ -242,10 +242,10 @@ classDiagram
|
||||||
namespace trainer {
|
namespace trainer {
|
||||||
class Trainer {
|
class Trainer {
|
||||||
+TrainConfig train_config
|
+TrainConfig train_config
|
||||||
+List~TrainCallback~ callbacks
|
+List[TrainCallback] callbacks
|
||||||
+train(checkpoint)
|
+train(checkpoint)
|
||||||
+_build_context(checkpoint) TrainContext
|
+_build_context(checkpoint) TrainContext
|
||||||
+_get_default_callbacks() List~TrainCallback~
|
+_get_default_callbacks() List[TrainCallback]
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainContext {
|
class TrainContext {
|
||||||
|
|
@ -308,7 +308,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseScheduler {
|
class BaseScheduler {
|
||||||
+get_lr() List~float~
|
+get_lr() List[float]
|
||||||
+step()
|
+step()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -390,10 +390,8 @@ classDiagram
|
||||||
+InferenceScheduler scheduler
|
+InferenceScheduler scheduler
|
||||||
+int max_batch_size
|
+int max_batch_size
|
||||||
+Optional int max_seq_len
|
+Optional int max_seq_len
|
||||||
+int max_prefix_len
|
+int max_prompt_len
|
||||||
+int cache_capacity
|
+int cache_capacity
|
||||||
+Tensor kv_cache
|
|
||||||
+Tensor seq_mask
|
|
||||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||||
+get_stats() Dict
|
+get_stats() Dict
|
||||||
|
|
@ -403,10 +401,10 @@ classDiagram
|
||||||
class InferenceScheduler {
|
class InferenceScheduler {
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+ModelConfig config
|
|
||||||
+Tuple kv_cache
|
+Tuple kv_cache
|
||||||
+Tensor seq_mask
|
+Tensor seq_mask
|
||||||
+PrefixCacheManager prefix_cache
|
+PrefixCacheManager prefix_cache
|
||||||
|
+SlotAllocator slot_allocator
|
||||||
+List waiting_queue
|
+List waiting_queue
|
||||||
+List active_tasks
|
+List active_tasks
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
|
|
@ -417,21 +415,22 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class PrefixCacheManager {
|
class PrefixCacheManager {
|
||||||
+RadixNode root
|
+_RadixNode root
|
||||||
+int max_capacity
|
+int max_capacity
|
||||||
+List lru
|
+OrderedDict _lru
|
||||||
+insert(token_ids, slot)
|
+insert(token_ids, slot, slot_ver)
|
||||||
+find_longest_prefix(token_ids) Tuple[int, int]
|
+find(token_ids) Tuple[int, int, int]
|
||||||
|
+pin(token_ids)
|
||||||
+release(token_ids)
|
+release(token_ids)
|
||||||
|
+copy_kv(token_ids, target_slot, kv_cache, n_layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
class RadixNode {
|
class _RadixNode {
|
||||||
+Dict children
|
+Dict children
|
||||||
+int hash
|
|
||||||
+int slot
|
+int slot
|
||||||
|
+int slot_ver
|
||||||
+int ref_count
|
+int ref_count
|
||||||
+float last_access
|
+float last_access
|
||||||
+List token_sequence
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class Task {
|
class Task {
|
||||||
|
|
@ -446,15 +445,69 @@ classDiagram
|
||||||
+int input_tokens
|
+int input_tokens
|
||||||
+int output_tokens
|
+int output_tokens
|
||||||
+int slot
|
+int slot
|
||||||
|
+int prefix_len
|
||||||
|
+float arrival_time
|
||||||
|
+float finish_time
|
||||||
+Callable stream_callback
|
+Callable stream_callback
|
||||||
+is_finished(stop_ids) bool
|
+is_finished(stop_ids) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
class TaskStatus {
|
class TaskStatus {
|
||||||
+str PENDING
|
<<enumeration>>
|
||||||
+str RUNNING
|
PENDING
|
||||||
+str FINISHED
|
RUNNING
|
||||||
+str ABORTED
|
FINISHED
|
||||||
|
ABORTED
|
||||||
|
}
|
||||||
|
|
||||||
|
class GenerationRequest {
|
||||||
|
+List[Dict] messages
|
||||||
|
+GenerationParams params
|
||||||
|
+bool stream
|
||||||
|
}
|
||||||
|
|
||||||
|
class GenerationParams {
|
||||||
|
<<value object>>
|
||||||
|
+int top_k
|
||||||
|
+float top_p
|
||||||
|
+float temperature
|
||||||
|
+int max_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
class SlotAllocator {
|
||||||
|
+int _max_slots
|
||||||
|
+int _free_mask
|
||||||
|
+List _versions
|
||||||
|
+alloc() int
|
||||||
|
+free(idx)
|
||||||
|
+occupy(idx)
|
||||||
|
+is_free(idx) bool
|
||||||
|
+version(idx) int
|
||||||
|
}
|
||||||
|
|
||||||
|
class BaseSamplingStrategy {
|
||||||
|
<<abstract>>
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class TemperatureStrategy {
|
||||||
|
+float temperature
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class TopKStrategy {
|
||||||
|
+int top_k
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class TopPStrategy {
|
||||||
|
+float top_p
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class SamplingPipeline {
|
||||||
|
+List strategies
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class Server {
|
class Server {
|
||||||
|
|
@ -462,21 +515,12 @@ classDiagram
|
||||||
+predict(request)
|
+predict(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
class GenerationRequest {
|
|
||||||
+int top_k
|
|
||||||
+float top_p
|
|
||||||
+float temperature
|
|
||||||
+int max_len
|
|
||||||
+List~Dict~ messages
|
|
||||||
+stream bool
|
|
||||||
}
|
|
||||||
|
|
||||||
class _Result {
|
class _Result {
|
||||||
+List~str~ tokens
|
+List[str] tokens
|
||||||
+List~str~ results
|
+List[str] results
|
||||||
+List~bool~ done_flags
|
+List[bool] done_flags
|
||||||
+append(token, idx)
|
+append(token, idx)
|
||||||
+get_results() List~str~
|
+get_results() List[str]
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatMessage {
|
class ChatMessage {
|
||||||
|
|
@ -485,13 +529,13 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatCompletionRequest {
|
class ChatCompletionRequest {
|
||||||
+List~ChatMessage~ messages
|
+List[ChatMessage] messages
|
||||||
+float temperature
|
+float temperature
|
||||||
+float top_p
|
+float top_p
|
||||||
+int top_k
|
+int top_k
|
||||||
+int max_tokens
|
+int max_tokens
|
||||||
+bool stream
|
+bool stream
|
||||||
+Optional~str~ system_prompt
|
+Optional[str] system_prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
class CompletionResponse {
|
class CompletionResponse {
|
||||||
|
|
@ -499,7 +543,7 @@ classDiagram
|
||||||
+str object
|
+str object
|
||||||
+int created
|
+int created
|
||||||
+str model
|
+str model
|
||||||
+List~Dict~ choices
|
+List[Dict] choices
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -542,7 +586,6 @@ classDiagram
|
||||||
TrainContext --> Checkpoint : manages
|
TrainContext --> Checkpoint : manages
|
||||||
TrainContext --> BaseStrategy : uses
|
TrainContext --> BaseStrategy : uses
|
||||||
TrainContext --> BaseScheduler : uses
|
TrainContext --> BaseScheduler : uses
|
||||||
AutoModel --> ModelConfig : contains
|
|
||||||
SchedulerFactory ..> BaseScheduler : creates
|
SchedulerFactory ..> BaseScheduler : creates
|
||||||
BaseScheduler <|-- CosineScheduler
|
BaseScheduler <|-- CosineScheduler
|
||||||
BaseScheduler <|-- SGDRScheduler
|
BaseScheduler <|-- SGDRScheduler
|
||||||
|
|
@ -553,11 +596,19 @@ classDiagram
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
InferenceEngine --> InferenceScheduler : uses
|
InferenceEngine --> InferenceScheduler : uses
|
||||||
|
InferenceEngine --> GenerationRequest : uses
|
||||||
|
GenerationRequest --> GenerationParams : contains
|
||||||
InferenceScheduler --> Task : manages
|
InferenceScheduler --> Task : manages
|
||||||
|
Task --> TaskStatus : uses
|
||||||
InferenceScheduler --> TaskStatus : uses
|
InferenceScheduler --> TaskStatus : uses
|
||||||
|
InferenceScheduler --> SlotAllocator : uses
|
||||||
InferenceScheduler --> Transformer : uses
|
InferenceScheduler --> Transformer : uses
|
||||||
InferenceEngine --> Transformer : uses
|
InferenceEngine --> Transformer : uses
|
||||||
InferenceEngine --> GenerationRequest : uses
|
InferenceEngine --> _Result : uses
|
||||||
|
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||||
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
|
SamplingPipeline --> BaseSamplingStrategy : composes
|
||||||
Server --> InferenceEngine : uses
|
Server --> InferenceEngine : uses
|
||||||
Server --> ChatMessage : uses
|
Server --> ChatMessage : uses
|
||||||
Server --> ChatCompletionRequest : uses
|
Server --> ChatCompletionRequest : uses
|
||||||
|
|
@ -585,7 +636,7 @@ classDiagram
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
ParallelModel <|-- ColumnParallelLinear
|
||||||
AutoTokenizer --> ChatTemplate : uses
|
AutoTokenizer --> ChatTemplate : uses
|
||||||
InferenceScheduler --> PrefixCacheManager : uses
|
InferenceScheduler --> PrefixCacheManager : uses
|
||||||
InferenceScheduler --> RadixNode : uses
|
PrefixCacheManager --> _RadixNode : composes
|
||||||
Checkpoint ..> Checkpoint : saves/loads
|
Checkpoint ..> Checkpoint : saves/loads
|
||||||
TrainConfig --> DatasetFactory : selects
|
TrainConfig --> DatasetFactory : selects
|
||||||
TrainConfig --> SchedulerFactory : selects
|
TrainConfig --> SchedulerFactory : selects
|
||||||
|
|
@ -606,7 +657,7 @@ classDiagram
|
||||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
|
||||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||||
|
|
||||||
|
|
@ -620,6 +671,8 @@ classDiagram
|
||||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||||
| **Singleton** | `TrainContext` | Training process global state management |
|
| **Singleton** | `TrainContext` | Training process global state management |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||||
|
| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
|
||||||
|
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
||||||
|
|
@ -630,7 +683,7 @@ classDiagram
|
||||||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming
|
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||||
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||||
|
|
|
||||||
|
|
@ -4,70 +4,83 @@
|
||||||
|
|
||||||
### Basic Parameters
|
### Basic Parameters
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
|
||||||
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
|
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
| `--param_path` | Model parameters or checkpoint path | required |
|
| `--param_path` | Model parameters or checkpoint path | required |
|
||||||
| `--n_epoch` | Total training epochs | 1 |
|
| `--n_epoch` | Total training epochs | 1 |
|
||||||
| `--batch_size` | Batch size | 4 |
|
| `--batch_size` | Batch size | 1 |
|
||||||
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||||
|
|
||||||
### Learning Rate Scheduling
|
### Learning Rate Scheduling
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--warmup_steps` | Warmup steps | 1000 |
|
| `--warmup_steps` | Warmup steps | 1000 |
|
||||||
| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
|
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||||
| `--max_grad_norm` | Maximum gradient norm | 1.0 |
|
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||||
|
|
||||||
### Checkpoint
|
### Optimizer (AdamW)
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
|
|
||||||
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
|
||||||
| `--resume_dir` | Resume training from specified path | - |
|
|
||||||
|
|
||||||
### Optimizer Parameters
|
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
|
||||||
|-----------|-------------|---------------|
|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||||
|
|
||||||
### Data Loading
|
### Data Loading
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--random_seed` | Random seed | 3407 |
|
| `--window_size` | Max input sequence length | model config `max_len` |
|
||||||
| `--num_workers` | DataLoader workers | 0 |
|
| `--stride` | Stride for sliding window over sequences | None |
|
||||||
| `--prefetch_factor` | Prefetch factor for dataloader | None |
|
| `--random_seed` | Random seed for reproducibility | 3407 |
|
||||||
| `--pin_memory` | Enable pin_memory | False |
|
| `--num_workers` | DataLoader worker processes | 4 |
|
||||||
| `--no_pin_memory` | Disable pin_memory | - |
|
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
|
||||||
|
|
||||||
|
### Checkpoint & Resume
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
|
||||||
|
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
||||||
|
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
|
||||||
|
| `--start_batch` | Resume from batch iteration | 0 |
|
||||||
|
|
||||||
### Distributed Training
|
### Distributed Training
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--nprocs` | Number of GPUs | 1 |
|
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||||
| `--device_type` | Device type (cuda/cpu) | cuda |
|
| `--device_type` | Device type | cuda |
|
||||||
|
|
||||||
### Other Parameters
|
### Strategy-specific
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default | Used by |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|---------|
|
||||||
| `--window_size` | Maximum input sequence length | model config max_len |
|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||||
| `--stride` | Input sequence stride | - |
|
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
|
||||||
| `--dpo_beta` | DPO beta value | 0.1 |
|
|
||||||
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
|
### Usage Example
|
||||||
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
|
|
||||||
| `--grpo_group_size` | GRPO group size | 4 |
|
```bash
|
||||||
| `--label_smoothing` | Label smoothing parameter | 0.1 |
|
python scripts/tools/train.py \
|
||||||
| `--start_epoch` | Starting epoch | 0 |
|
--train_type seq \
|
||||||
| `--start_batch` | Starting batch | 0 |
|
--data_root_path /path/to/dataset \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--n_epoch 3 \
|
||||||
|
--batch_size 4 \
|
||||||
|
--accumulation_steps 8 \
|
||||||
|
--max_lr 3e-4 \
|
||||||
|
--warmup_steps 2000 \
|
||||||
|
--max_grad_norm 1.0 \
|
||||||
|
--ckpt_interval 5000 \
|
||||||
|
--ckpt_dir ./checkpoints \
|
||||||
|
--num_workers 4 \
|
||||||
|
--nprocs 1 \
|
||||||
|
--device_type cuda
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
@ -89,14 +102,14 @@
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
from astrai.tokenize import Tokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
from astrai.inference import InferenceEngine, GenerationRequest
|
from astrai.inference import InferenceEngine, GenerationRequest
|
||||||
|
|
||||||
# Load model using AutoModel
|
# Load model using AutoModel
|
||||||
model = AutoModel.from_pretrained("your_model_dir")
|
model = AutoModel.from_pretrained("your_model_dir")
|
||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
tokenizer = Tokenizer("your_model_dir")
|
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
|
||||||
|
|
||||||
# Create engine with separate model and tokenizer
|
# Create engine with separate model and tokenizer
|
||||||
engine = InferenceEngine(
|
engine = InferenceEngine(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue