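docs: fix documentation errors and add training parameter descriptions

- README: add a training parameter quick-reference table and expand the training command example
- design.md: sync the inference class diagram (new classes such as SlotAllocator,
  GenerationParams, and the sampling strategies), fix parameter name and type errors, and unify the generic type notation
- params.md: correct default values (batch_size=1, num_workers=4), remove nonexistent parameters
  (grpo_*, model_type, resume_dir), and add a complete example
- dataflow.md: fix the _RadixNode name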
ViperEkura 2026-05-08 18:07:57 +08:00
parent 44d7a4e959
commit 78dc2bd41c
5 changed files with 213 additions and 100 deletions


@ -27,9 +27,6 @@
## 📖 Table of Contents ## 📖 Table of Contents
<details open>
<summary><b>English</b></summary>
- [Features](#features) - [Features](#features)
- [Quick Start](#quick-start) - [Quick Start](#quick-start)
- [Documentation](#documentation) - [Documentation](#documentation)
@ -37,8 +34,6 @@
- [Community](#community) - [Community](#community)
- [License](#license) - [License](#license)
</details>
--- ---
<a id="english"></a> <a id="english"></a>
@ -75,7 +70,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \ python scripts/tools/train.py \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/param_path --param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
``` ```
#### Generate Text #### Generate Text
@ -84,6 +86,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path python scripts/tools/generate.py --param_path=/path/to/param_path
``` ```
#### Training Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
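
The `--max_lr` and `--warmup_steps` rows describe a peak learning rate reached after warmup and followed by cosine decay. A minimal sketch of that shape is given below. It assumes a linear warmup and a decay toward `min_lr`; `lr_at_step`, `total_steps`, and `min_lr` are hypothetical names for illustration, and the project's own scheduler may differ in detail.

```python
import math


def lr_at_step(
    step: int,
    max_lr: float = 3e-4,
    warmup_steps: int = 1000,
    total_steps: int = 10_000,  # hypothetical: total optimizer steps in the run
    min_lr: float = 0.0,        # hypothetical: floor of the cosine decay
) -> float:
    """Linear warmup to max_lr, then cosine decay to min_lr (assumed shape)."""
    if step < warmup_steps:
        # Ramp linearly so the last warmup step reaches exactly max_lr.
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

With the defaults above, the rate climbs during the first 1000 steps, peaks at `3e-4`, and falls to half of the peak midway through the decay phase.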
#### Docker #### Docker
Build and run with Docker (recommended for GPU environments): Build and run with Docker (recommended for GPU environments):


@ -76,7 +76,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \ python scripts/tools/train.py \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/param_path --param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
``` ```
#### Generate Text #### Generate Text
@ -85,6 +92,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path python scripts/tools/generate.py --param_path=/path/to/param_path
``` ```
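#### Training Parameters
| Parameter | Description | Default |
|------|------|--------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | Warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--num_workers` | Number of data-loading workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
For the full parameter list, see the [Parameter Guide](./params.md#training-parameters).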
#### Docker #### Docker
Build and run with Docker (recommended for GPU environments): Build and run with Docker (recommended for GPU environments):


@ -176,7 +176,7 @@ flowchart LR
- **`TaskStatus`**: Task state enumeration - **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits - **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse - **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- **`RadixNode`**: Tree node structure for prefix caching - **`_RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately - Continuous batching: new requests can join at any time, completed requests are released immediately
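
The prefix-cache bullets above describe longest-prefix lookup over token IDs with LRU eviction. The sketch below illustrates that idea under simplifying assumptions: it uses a per-token trie rather than a compressed radix tree, and it omits reference counting and slot versions. `ToyPrefixCache` and `_Node` are hypothetical names, not the project's `PrefixCacheManager` or `_RadixNode`.

```python
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple


class _Node:
    def __init__(self) -> None:
        self.children: Dict[int, "_Node"] = {}
        self.slot: Optional[int] = None  # KV-cache slot holding this prefix, if any


class ToyPrefixCache:
    """Per-token trie with LRU eviction of cached prefixes (illustrative only)."""

    def __init__(self, max_capacity: int) -> None:
        self.root = _Node()
        self.max_capacity = max_capacity
        self._lru: "OrderedDict[Tuple[int, ...], int]" = OrderedDict()

    def insert(self, token_ids: List[int], slot: int) -> None:
        key = tuple(token_ids)
        node = self.root
        for tok in token_ids:
            node = node.children.setdefault(tok, _Node())
        node.slot = slot
        self._lru[key] = slot
        self._lru.move_to_end(key)  # mark as most recently used
        while len(self._lru) > self.max_capacity:
            old_key, _ = self._lru.popitem(last=False)  # least recently used
            self._remove(old_key)

    def find(self, token_ids: List[int]) -> Tuple[int, Optional[int]]:
        """Return (matched_length, slot) for the longest cached prefix."""
        node, best_len, best_slot = self.root, 0, None
        for i, tok in enumerate(token_ids):
            node = node.children.get(tok)
            if node is None:
                break
            if node.slot is not None:
                best_len, best_slot = i + 1, node.slot
        if best_slot is not None:
            self._lru.move_to_end(tuple(token_ids[:best_len]))
        return best_len, best_slot

    def _remove(self, key: Tuple[int, ...]) -> None:
        # Clear the evicted prefix, then prune nodes that no longer lead anywhere.
        path = [self.root]
        for tok in key:
            path.append(path[-1].children[tok])
        path[-1].slot = None
        for depth in range(len(key), 0, -1):
            node = path[depth]
            if node.children or node.slot is not None:
                break
            del path[depth - 1].children[key[depth - 1]]
```

A request whose prompt shares a cached prefix can reuse that slot's KV entries and only compute the remaining tokens, which is the reuse the bullet on `PrefixCacheManager` refers to.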
#### 6.3 Server (`server.py`) #### 6.3 Server (`server.py`)


@ -85,8 +85,8 @@ classDiagram
} }
class BaseSegmentFetcher { class BaseSegmentFetcher {
+List~Tensor~ segments +List[Tensor] segments
+List~int~ cum_lengths +List[int] cum_lengths
+int total_length +int total_length
+fetch_data(begin_idx, end_idx) Tensor +fetch_data(begin_idx, end_idx) Tensor
} }
@ -191,7 +191,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+forward(x, start_pos) Tuple~Tensor, Tensor~ +forward(x, start_pos) Tuple[Tensor, Tensor]
} }
class Embedding { class Embedding {
@ -202,14 +202,14 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+List~str~ stop_ids +List[str] stop_ids
+int bos_id +int bos_id
+int eos_id +int eos_id
+int pad_id +int pad_id
+vocab_size int +vocab_size int
+encode(tokens, out_ids, add_special_tokens) List~int~ +encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union~str, List[int]~ +apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
+from_pretrained(path) AutoTokenizer +from_pretrained(path) AutoTokenizer
@ -228,7 +228,7 @@ classDiagram
+Dict _entries +Dict _entries
+register(name, component_cls, category, priority) +register(name, component_cls, category, priority)
+get(name) Type +get(name) Type
+list_names() List~str~ +list_names() List[str]
} }
class BaseFactory { class BaseFactory {
@ -242,10 +242,10 @@ classDiagram
namespace trainer { namespace trainer {
class Trainer { class Trainer {
+TrainConfig train_config +TrainConfig train_config
+List~TrainCallback~ callbacks +List[TrainCallback] callbacks
+train(checkpoint) +train(checkpoint)
+_build_context(checkpoint) TrainContext +_build_context(checkpoint) TrainContext
+_get_default_callbacks() List~TrainCallback~ +_get_default_callbacks() List[TrainCallback]
} }
class TrainContext { class TrainContext {
@ -308,7 +308,7 @@ classDiagram
} }
class BaseScheduler { class BaseScheduler {
+get_lr() List~float~ +get_lr() List[float]
+step() +step()
} }
@ -390,10 +390,8 @@ classDiagram
+InferenceScheduler scheduler +InferenceScheduler scheduler
+int max_batch_size +int max_batch_size
+Optional int max_seq_len +Optional int max_seq_len
+int max_prefix_len +int max_prompt_len
+int cache_capacity +int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]]
+get_stats() Dict +get_stats() Dict
@ -403,10 +401,10 @@ classDiagram
class InferenceScheduler { class InferenceScheduler {
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+ModelConfig config
+Tuple kv_cache +Tuple kv_cache
+Tensor seq_mask +Tensor seq_mask
+PrefixCacheManager prefix_cache +PrefixCacheManager prefix_cache
+SlotAllocator slot_allocator
+List waiting_queue +List waiting_queue
+List active_tasks +List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -417,21 +415,22 @@ classDiagram
} }
class PrefixCacheManager { class PrefixCacheManager {
+RadixNode root +_RadixNode root
+int max_capacity +int max_capacity
+List lru +OrderedDict _lru
+insert(token_ids, slot) +insert(token_ids, slot, slot_ver)
+find_longest_prefix(token_ids) Tuple[int, int] +find(token_ids) Tuple[int, int, int]
+pin(token_ids)
+release(token_ids) +release(token_ids)
+copy_kv(token_ids, target_slot, kv_cache, n_layers)
} }
class RadixNode { class _RadixNode {
+Dict children +Dict children
+int hash
+int slot +int slot
+int slot_ver
+int ref_count +int ref_count
+float last_access +float last_access
+List token_sequence
} }
class Task { class Task {
@ -446,15 +445,69 @@ classDiagram
+int input_tokens +int input_tokens
+int output_tokens +int output_tokens
+int slot +int slot
+int prefix_len
+float arrival_time
+float finish_time
+Callable stream_callback +Callable stream_callback
+is_finished(stop_ids) bool +is_finished(stop_ids) bool
} }
class TaskStatus { class TaskStatus {
+str PENDING <<enumeration>>
+str RUNNING PENDING
+str FINISHED RUNNING
+str ABORTED FINISHED
ABORTED
}
class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
}
class SlotAllocator {
+int _max_slots
+int _free_mask
+List _versions
+alloc() int
+free(idx)
+occupy(idx)
+is_free(idx) bool
+version(idx) int
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
} }
class Server { class Server {
@ -462,21 +515,12 @@ classDiagram
+predict(request) +predict(request)
} }
class GenerationRequest {
+int top_k
+float top_p
+float temperature
+int max_len
+List~Dict~ messages
+stream bool
}
class _Result { class _Result {
+List~str~ tokens +List[str] tokens
+List~str~ results +List[str] results
+List~bool~ done_flags +List[bool] done_flags
+append(token, idx) +append(token, idx)
+get_results() List~str~ +get_results() List[str]
} }
class ChatMessage { class ChatMessage {
@ -485,13 +529,13 @@ classDiagram
} }
class ChatCompletionRequest { class ChatCompletionRequest {
+List~ChatMessage~ messages +List[ChatMessage] messages
+float temperature +float temperature
+float top_p +float top_p
+int top_k +int top_k
+int max_tokens +int max_tokens
+bool stream +bool stream
+Optional~str~ system_prompt +Optional[str] system_prompt
} }
class CompletionResponse { class CompletionResponse {
@ -499,7 +543,7 @@ classDiagram
+str object +str object
+int created +int created
+str model +str model
+List~Dict~ choices +List[Dict] choices
} }
} }
@ -542,7 +586,6 @@ classDiagram
TrainContext --> Checkpoint : manages TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses TrainContext --> BaseScheduler : uses
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler BaseScheduler <|-- SGDRScheduler
@ -553,11 +596,19 @@ classDiagram
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceScheduler --> Task : manages InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> SlotAllocator : uses
InferenceScheduler --> Transformer : uses InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses InferenceEngine --> Transformer : uses
InferenceEngine --> GenerationRequest : uses InferenceEngine --> _Result : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
Server --> InferenceEngine : uses Server --> InferenceEngine : uses
Server --> ChatMessage : uses Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses Server --> ChatCompletionRequest : uses
@ -585,7 +636,7 @@ classDiagram
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses PrefixCacheManager --> _RadixNode : composes
Checkpoint ..> Checkpoint : saves/loads Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects TrainConfig --> SchedulerFactory : selects
@ -606,7 +657,7 @@ classDiagram
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching | | **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel | | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration | | **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -620,6 +671,8 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management | | **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
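
The Object Pool row describes slot allocation and deallocation via a bitmask. A minimal sketch of that technique follows, using the method names from the class diagram (`alloc`, `free`, `is_free`, `version`). Returning `-1` when no slot is free and bumping the version on `free` are assumptions made for illustration; `BitmaskSlotAllocator` is a hypothetical name, not the project's class.

```python
from typing import List


class BitmaskSlotAllocator:
    """Tracks free slots in one integer: bit i set means slot i is free."""

    def __init__(self, max_slots: int) -> None:
        self._max_slots = max_slots
        self._free_mask = (1 << max_slots) - 1
        self._versions: List[int] = [0] * max_slots

    def alloc(self) -> int:
        if self._free_mask == 0:
            return -1  # assumption: signal exhaustion with -1
        lowest = self._free_mask & -self._free_mask  # isolate the lowest set bit
        self._free_mask ^= lowest                    # mark that slot as occupied
        return lowest.bit_length() - 1

    def free(self, idx: int) -> None:
        self._free_mask |= 1 << idx
        self._versions[idx] += 1  # assumption: version changes when a slot is recycled

    def is_free(self, idx: int) -> bool:
        return bool((self._free_mask >> idx) & 1)

    def version(self, idx: int) -> int:
        return self._versions[idx]
```

A per-slot version lets a cache entry detect that the slot it points to has been recycled since the entry was written, which is one plausible reason for the `slot_ver` fields in the class diagram.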
@ -630,7 +683,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming 4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer` 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
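
The inference flow above applies temperature, top-k, and top-p as composable logit transformations. The sketch below shows one common way to chain them. It works on plain Python lists rather than tensors, and the function names are hypothetical; the project's `SamplingPipeline` and strategy classes may order or implement the steps differently.

```python
import math
from typing import List

NEG_INF = float("-inf")


def apply_temperature(logits: List[float], temperature: float) -> List[float]:
    # Values below 1 sharpen the distribution, values above 1 flatten it.
    return [x / temperature for x in logits]


def apply_top_k(logits: List[float], top_k: int, filter_value: float = NEG_INF) -> List[float]:
    if top_k <= 0 or top_k >= len(logits):
        return list(logits)  # top-k disabled
    threshold = sorted(logits, reverse=True)[top_k - 1]
    # Ties at the threshold are all kept, so more than top_k entries may survive.
    return [x if x >= threshold else filter_value for x in logits]


def apply_top_p(logits: List[float], top_p: float, filter_value: float = NEG_INF) -> List[float]:
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # numerically stable softmax
    total = sum(exps)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)  # always keep the token that crosses the threshold
        cumulative += exps[i] / total
        if cumulative >= top_p:
            break
    return [x if i in keep else filter_value for i, x in enumerate(logits)]


def run_pipeline(logits: List[float], temperature: float = 1.0,
                 top_k: int = 0, top_p: float = 1.0) -> List[float]:
    out = apply_temperature(logits, temperature)
    out = apply_top_k(out, top_k)
    return apply_top_p(out, top_p)
```

Filtered positions are set to negative infinity so that a subsequent softmax assigns them zero probability before a token is drawn.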


@ -4,70 +4,83 @@
### Basic Parameters ### Basic Parameters
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--train_type` | Training type (seq, sft, dpo, grpo) | required | | `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
| `--data_root_path` | Dataset root directory | required | | `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required | | `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 | | `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 4 | | `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 | | `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling ### Learning Rate Scheduling
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--warmup_steps` | Warmup steps | 1000 | | `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 | | `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm | 1.0 | | `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
### Checkpoint ### Optimizer (AdamW)
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--resume_dir` | Resume training from specified path | - |
### Optimizer Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--adamw_beta1` | AdamW beta1 | 0.9 | | `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 | | `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 | | `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading ### Data Loading
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--random_seed` | Random seed | 3407 | | `--window_size` | Max input sequence length | model config `max_len` |
| `--num_workers` | DataLoader workers | 0 | | `--stride` | Stride for sliding window over sequences | None |
| `--prefetch_factor` | Prefetch factor for dataloader | None | | `--random_seed` | Random seed for reproducibility | 3407 |
| `--pin_memory` | Enable pin_memory | False | | `--num_workers` | DataLoader worker processes | 4 |
| `--no_pin_memory` | Disable pin_memory | - | | `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
### Checkpoint & Resume
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
### Distributed Training ### Distributed Training
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--nprocs` | Number of GPUs | 1 | | `--nprocs` | Number of GPUs / processes | 1 |
| `--device_type` | Device type (cuda/cpu) | cuda | | `--device_type` | Device type | cuda |
### Other Parameters ### Strategy-specific
| Parameter | Description | Default Value | | Parameter | Description | Default | Used by |
|-----------|-------------|---------------| |-----------|-------------|---------|---------|
| `--window_size` | Maximum input sequence length | model config max_len | | `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--stride` | Input sequence stride | - | | `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
| `--dpo_beta` | DPO beta value | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 | ### Usage Example
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
| `--grpo_group_size` | GRPO group size | 4 | ```bash
| `--label_smoothing` | Label smoothing parameter | 0.1 | python scripts/tools/train.py \
| `--start_epoch` | Starting epoch | 0 | --train_type seq \
| `--start_batch` | Starting batch | 0 | --data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
```
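
As a rough check on the example above, the number of samples contributing to each optimizer step can be worked out from three of the flags. The sketch assumes that `--batch_size` is per process and that `--nprocs` runs data-parallel replicas; neither is confirmed by this diff, and `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(batch_size: int, accumulation_steps: int, nprocs: int = 1) -> int:
    """Samples per optimizer step, assuming per-process batches and data parallelism."""
    return batch_size * accumulation_steps * nprocs


# Values taken from the usage example: --batch_size 4 --accumulation_steps 8 --nprocs 1
print(effective_batch_size(4, 8, 1))  # 32
```

Under those assumptions, doubling `--nprocs` while halving `--accumulation_steps` keeps the effective batch size unchanged.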
--- ---
@ -89,14 +102,14 @@
```python ```python
import torch import torch
from astrai.model import AutoModel from astrai.model import AutoModel
from astrai.tokenize import Tokenizer from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel # Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir") model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer # Load tokenizer
tokenizer = Tokenizer("your_model_dir") tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer # Create engine with separate model and tokenizer
engine = InferenceEngine( engine = InferenceEngine(