docs: 修正文档错误并补充训练参数说明
- README: 补充训练参数速查表,完善训练命令示例 - design.md: 同步 inference 类图(SlotAllocator、GenerationParams、采样策略等 新增类),修正参数名和类型错误,统一泛型符号 - params.md: 修正默认值(batch_size=1、num_workers=4),移除不存在参数 (grpo_*、model_type、resume_dir),补充完整示例 - dataflow.md: _RadixNode 命名修正
This commit is contained in:
parent
44d7a4e959
commit
78dc2bd41c
33
README.md
33
README.md
|
|
@ -27,9 +27,6 @@
|
|||
|
||||
## 📖 Table of Contents
|
||||
|
||||
<details open>
|
||||
<summary><b>English</b></summary>
|
||||
|
||||
- [Features](#features)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Documentation](#documentation)
|
||||
|
|
@ -37,8 +34,6 @@
|
|||
- [Community](#community)
|
||||
- [License](#license)
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
<a id="english"></a>
|
||||
|
|
@ -75,7 +70,14 @@ pip install -e ".[dev]"
|
|||
python scripts/tools/train.py \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/param_path
|
||||
--param_path=/path/to/model \
|
||||
--n_epoch=3 \
|
||||
--batch_size=4 \
|
||||
--accumulation_steps=8 \
|
||||
--max_lr=3e-4 \
|
||||
--warmup_steps=2000 \
|
||||
--ckpt_interval=5000 \
|
||||
--ckpt_dir=./checkpoints
|
||||
```
|
||||
|
||||
#### Generate Text
|
||||
|
|
@ -84,6 +86,25 @@ python scripts/tools/train.py \
|
|||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
||||
```
|
||||
|
||||
#### Training Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
|
||||
| `--data_root_path` | Dataset root directory | required |
|
||||
| `--param_path` | Model / checkpoint path | required |
|
||||
| `--n_epoch` | Training epochs | 1 |
|
||||
| `--batch_size` | Batch size | 1 |
|
||||
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
||||
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
|
||||
| `--warmup_steps` | LR warmup steps | 1000 |
|
||||
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
|
||||
| `--ckpt_dir` | Checkpoint directory | checkpoint |
|
||||
| `--num_workers` | DataLoader workers | 4 |
|
||||
| `--nprocs` | Number of GPUs | 1 |
|
||||
|
||||
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
|
||||
|
||||
#### Docker
|
||||
|
||||
Build and run with Docker (recommended for GPU environments):
|
||||
|
|
|
|||
|
|
@ -76,7 +76,14 @@ pip install -e ".[dev]"
|
|||
python scripts/tools/train.py \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/param_path
|
||||
--param_path=/path/to/model \
|
||||
--n_epoch=3 \
|
||||
--batch_size=4 \
|
||||
--accumulation_steps=8 \
|
||||
--max_lr=3e-4 \
|
||||
--warmup_steps=2000 \
|
||||
--ckpt_interval=5000 \
|
||||
--ckpt_dir=./checkpoints
|
||||
```
|
||||
|
||||
#### 文本生成
|
||||
|
|
@ -85,6 +92,25 @@ python scripts/tools/train.py \
|
|||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
||||
```
|
||||
|
||||
#### 训练参数
|
||||
|
||||
| 参数 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`) | 必填 |
|
||||
| `--data_root_path` | 数据集根目录 | 必填 |
|
||||
| `--param_path` | 模型参数或断点路径 | 必填 |
|
||||
| `--n_epoch` | 训练轮数 | 1 |
|
||||
| `--batch_size` | 批次大小 | 1 |
|
||||
| `--accumulation_steps` | 梯度累积步数 | 1 |
|
||||
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
|
||||
| `--warmup_steps` | 预热步数 | 1000 |
|
||||
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
|
||||
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
|
||||
| `--num_workers` | 数据加载线程数 | 4 |
|
||||
| `--nprocs` | GPU 数量 | 1 |
|
||||
|
||||
完整参数列表见[参数说明](./params.md#training-parameters)。
|
||||
|
||||
#### Docker
|
||||
|
||||
使用 Docker 构建和运行(推荐用于 GPU 环境):
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ flowchart LR
|
|||
- **`TaskStatus`**: Task state enumeration
|
||||
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
|
||||
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
|
||||
- **`RadixNode`**: Tree node structure for prefix caching
|
||||
- **`_RadixNode`**: Tree node structure for prefix caching
|
||||
- Continuous batching: new requests can join at any time, completed requests are released immediately
|
||||
|
||||
#### 6.3 Server (`server.py`)
|
||||
|
|
|
|||
|
|
@ -85,8 +85,8 @@ classDiagram
|
|||
}
|
||||
|
||||
class BaseSegmentFetcher {
|
||||
+List~Tensor~ segments
|
||||
+List~int~ cum_lengths
|
||||
+List[Tensor] segments
|
||||
+List[int] cum_lengths
|
||||
+int total_length
|
||||
+fetch_data(begin_idx, end_idx) Tensor
|
||||
}
|
||||
|
|
@ -191,7 +191,7 @@ classDiagram
|
|||
+int dim
|
||||
+int max_len
|
||||
+float base
|
||||
+forward(x, start_pos) Tuple~Tensor, Tensor~
|
||||
+forward(x, start_pos) Tuple[Tensor, Tensor]
|
||||
}
|
||||
|
||||
class Embedding {
|
||||
|
|
@ -202,14 +202,14 @@ classDiagram
|
|||
|
||||
namespace tokenize {
|
||||
class AutoTokenizer {
|
||||
+List~str~ stop_ids
|
||||
+List[str] stop_ids
|
||||
+int bos_id
|
||||
+int eos_id
|
||||
+int pad_id
|
||||
+vocab_size int
|
||||
+encode(tokens, out_ids, add_special_tokens) List~int~
|
||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||
+decode(tokens, skip_special_tokens) str
|
||||
+apply_chat_template(messages, tokenize) Union~str, List[int]~
|
||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||
+set_chat_template(template)
|
||||
+load(path)
|
||||
+from_pretrained(path) AutoTokenizer
|
||||
|
|
@ -228,7 +228,7 @@ classDiagram
|
|||
+Dict _entries
|
||||
+register(name, component_cls, category, priority)
|
||||
+get(name) Type
|
||||
+list_names() List~str~
|
||||
+list_names() List[str]
|
||||
}
|
||||
|
||||
class BaseFactory {
|
||||
|
|
@ -242,10 +242,10 @@ classDiagram
|
|||
namespace trainer {
|
||||
class Trainer {
|
||||
+TrainConfig train_config
|
||||
+List~TrainCallback~ callbacks
|
||||
+List[TrainCallback] callbacks
|
||||
+train(checkpoint)
|
||||
+_build_context(checkpoint) TrainContext
|
||||
+_get_default_callbacks() List~TrainCallback~
|
||||
+_get_default_callbacks() List[TrainCallback]
|
||||
}
|
||||
|
||||
class TrainContext {
|
||||
|
|
@ -308,7 +308,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class BaseScheduler {
|
||||
+get_lr() List~float~
|
||||
+get_lr() List[float]
|
||||
+step()
|
||||
}
|
||||
|
||||
|
|
@ -390,10 +390,8 @@ classDiagram
|
|||
+InferenceScheduler scheduler
|
||||
+int max_batch_size
|
||||
+Optional int max_seq_len
|
||||
+int max_prefix_len
|
||||
+int max_prompt_len
|
||||
+int cache_capacity
|
||||
+Tensor kv_cache
|
||||
+Tensor seq_mask
|
||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||
+get_stats() Dict
|
||||
|
|
@ -403,10 +401,10 @@ classDiagram
|
|||
class InferenceScheduler {
|
||||
+nn.Module model
|
||||
+AutoTokenizer tokenizer
|
||||
+ModelConfig config
|
||||
+Tuple kv_cache
|
||||
+Tensor seq_mask
|
||||
+PrefixCacheManager prefix_cache
|
||||
+SlotAllocator slot_allocator
|
||||
+List waiting_queue
|
||||
+List active_tasks
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
|
|
@ -417,21 +415,22 @@ classDiagram
|
|||
}
|
||||
|
||||
class PrefixCacheManager {
|
||||
+RadixNode root
|
||||
+_RadixNode root
|
||||
+int max_capacity
|
||||
+List lru
|
||||
+insert(token_ids, slot)
|
||||
+find_longest_prefix(token_ids) Tuple[int, int]
|
||||
+OrderedDict _lru
|
||||
+insert(token_ids, slot, slot_ver)
|
||||
+find(token_ids) Tuple[int, int, int]
|
||||
+pin(token_ids)
|
||||
+release(token_ids)
|
||||
+copy_kv(token_ids, target_slot, kv_cache, n_layers)
|
||||
}
|
||||
|
||||
class RadixNode {
|
||||
class _RadixNode {
|
||||
+Dict children
|
||||
+int hash
|
||||
+int slot
|
||||
+int slot_ver
|
||||
+int ref_count
|
||||
+float last_access
|
||||
+List token_sequence
|
||||
}
|
||||
|
||||
class Task {
|
||||
|
|
@ -446,15 +445,69 @@ classDiagram
|
|||
+int input_tokens
|
||||
+int output_tokens
|
||||
+int slot
|
||||
+int prefix_len
|
||||
+float arrival_time
|
||||
+float finish_time
|
||||
+Callable stream_callback
|
||||
+is_finished(stop_ids) bool
|
||||
}
|
||||
|
||||
class TaskStatus {
|
||||
+str PENDING
|
||||
+str RUNNING
|
||||
+str FINISHED
|
||||
+str ABORTED
|
||||
<<enumeration>>
|
||||
PENDING
|
||||
RUNNING
|
||||
FINISHED
|
||||
ABORTED
|
||||
}
|
||||
|
||||
class GenerationRequest {
|
||||
+List[Dict] messages
|
||||
+GenerationParams params
|
||||
+bool stream
|
||||
}
|
||||
|
||||
class GenerationParams {
|
||||
<<value object>>
|
||||
+int top_k
|
||||
+float top_p
|
||||
+float temperature
|
||||
+int max_tokens
|
||||
}
|
||||
|
||||
class SlotAllocator {
|
||||
+int _max_slots
|
||||
+int _free_mask
|
||||
+List _versions
|
||||
+alloc() int
|
||||
+free(idx)
|
||||
+occupy(idx)
|
||||
+is_free(idx) bool
|
||||
+version(idx) int
|
||||
}
|
||||
|
||||
class BaseSamplingStrategy {
|
||||
<<abstract>>
|
||||
+apply(logits, filter_value) Tensor
|
||||
}
|
||||
|
||||
class TemperatureStrategy {
|
||||
+float temperature
|
||||
+apply(logits, filter_value) Tensor
|
||||
}
|
||||
|
||||
class TopKStrategy {
|
||||
+int top_k
|
||||
+apply(logits, filter_value) Tensor
|
||||
}
|
||||
|
||||
class TopPStrategy {
|
||||
+float top_p
|
||||
+apply(logits, filter_value) Tensor
|
||||
}
|
||||
|
||||
class SamplingPipeline {
|
||||
+List strategies
|
||||
+apply(logits, filter_value) Tensor
|
||||
}
|
||||
|
||||
class Server {
|
||||
|
|
@ -462,21 +515,12 @@ classDiagram
|
|||
+predict(request)
|
||||
}
|
||||
|
||||
class GenerationRequest {
|
||||
+int top_k
|
||||
+float top_p
|
||||
+float temperature
|
||||
+int max_len
|
||||
+List~Dict~ messages
|
||||
+stream bool
|
||||
}
|
||||
|
||||
class _Result {
|
||||
+List~str~ tokens
|
||||
+List~str~ results
|
||||
+List~bool~ done_flags
|
||||
+List[str] tokens
|
||||
+List[str] results
|
||||
+List[bool] done_flags
|
||||
+append(token, idx)
|
||||
+get_results() List~str~
|
||||
+get_results() List[str]
|
||||
}
|
||||
|
||||
class ChatMessage {
|
||||
|
|
@ -485,13 +529,13 @@ classDiagram
|
|||
}
|
||||
|
||||
class ChatCompletionRequest {
|
||||
+List~ChatMessage~ messages
|
||||
+List[ChatMessage] messages
|
||||
+float temperature
|
||||
+float top_p
|
||||
+int top_k
|
||||
+int max_tokens
|
||||
+bool stream
|
||||
+Optional~str~ system_prompt
|
||||
+Optional[str] system_prompt
|
||||
}
|
||||
|
||||
class CompletionResponse {
|
||||
|
|
@ -499,7 +543,7 @@ classDiagram
|
|||
+str object
|
||||
+int created
|
||||
+str model
|
||||
+List~Dict~ choices
|
||||
+List[Dict] choices
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -542,7 +586,6 @@ classDiagram
|
|||
TrainContext --> Checkpoint : manages
|
||||
TrainContext --> BaseStrategy : uses
|
||||
TrainContext --> BaseScheduler : uses
|
||||
AutoModel --> ModelConfig : contains
|
||||
SchedulerFactory ..> BaseScheduler : creates
|
||||
BaseScheduler <|-- CosineScheduler
|
||||
BaseScheduler <|-- SGDRScheduler
|
||||
|
|
@ -553,11 +596,19 @@ classDiagram
|
|||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
InferenceEngine --> InferenceScheduler : uses
|
||||
InferenceEngine --> GenerationRequest : uses
|
||||
GenerationRequest --> GenerationParams : contains
|
||||
InferenceScheduler --> Task : manages
|
||||
Task --> TaskStatus : uses
|
||||
InferenceScheduler --> TaskStatus : uses
|
||||
InferenceScheduler --> SlotAllocator : uses
|
||||
InferenceScheduler --> Transformer : uses
|
||||
InferenceEngine --> Transformer : uses
|
||||
InferenceEngine --> GenerationRequest : uses
|
||||
InferenceEngine --> _Result : uses
|
||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||
BaseSamplingStrategy <|-- TopKStrategy
|
||||
BaseSamplingStrategy <|-- TopPStrategy
|
||||
SamplingPipeline --> BaseSamplingStrategy : composes
|
||||
Server --> InferenceEngine : uses
|
||||
Server --> ChatMessage : uses
|
||||
Server --> ChatCompletionRequest : uses
|
||||
|
|
@ -585,7 +636,7 @@ classDiagram
|
|||
ParallelModel <|-- ColumnParallelLinear
|
||||
AutoTokenizer --> ChatTemplate : uses
|
||||
InferenceScheduler --> PrefixCacheManager : uses
|
||||
InferenceScheduler --> RadixNode : uses
|
||||
PrefixCacheManager --> _RadixNode : composes
|
||||
Checkpoint ..> Checkpoint : saves/loads
|
||||
TrainConfig --> DatasetFactory : selects
|
||||
TrainConfig --> SchedulerFactory : selects
|
||||
|
|
@ -606,7 +657,7 @@ classDiagram
|
|||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
|
||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||
|
||||
|
|
@ -620,6 +671,8 @@ classDiagram
|
|||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||
| **Singleton** | `TrainContext` | Training process global state management |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||
| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
|
||||
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
||||
|
|
@ -630,7 +683,7 @@ classDiagram
|
|||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming
|
||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||
|
|
|
|||
|
|
@ -4,70 +4,83 @@
|
|||
|
||||
### Basic Parameters
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
|
||||
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
|
||||
| `--data_root_path` | Dataset root directory | required |
|
||||
| `--param_path` | Model parameters or checkpoint path | required |
|
||||
| `--n_epoch` | Total training epochs | 1 |
|
||||
| `--batch_size` | Batch size | 4 |
|
||||
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
||||
| `--batch_size` | Batch size | 1 |
|
||||
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||
|
||||
### Learning Rate Scheduling
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--warmup_steps` | Warmup steps | 1000 |
|
||||
| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
|
||||
| `--max_grad_norm` | Maximum gradient norm | 1.0 |
|
||||
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||
|
||||
### Checkpoint
|
||||
### Optimizer (AdamW)
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
|
||||
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
||||
| `--resume_dir` | Resume training from specified path | - |
|
||||
|
||||
### Optimizer Parameters
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||
|
||||
### Data Loading
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| `--random_seed` | Random seed | 3407 |
|
||||
| `--num_workers` | DataLoader workers | 0 |
|
||||
| `--prefetch_factor` | Prefetch factor for dataloader | None |
|
||||
| `--pin_memory` | Enable pin_memory | False |
|
||||
| `--no_pin_memory` | Disable pin_memory | - |
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--window_size` | Max input sequence length | model config `max_len` |
|
||||
| `--stride` | Stride for sliding window over sequences | None |
|
||||
| `--random_seed` | Random seed for reproducibility | 3407 |
|
||||
| `--num_workers` | DataLoader worker processes | 4 |
|
||||
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
|
||||
|
||||
### Checkpoint & Resume
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
|
||||
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
||||
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
|
||||
| `--start_batch` | Resume from batch iteration | 0 |
|
||||
|
||||
### Distributed Training
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| `--nprocs` | Number of GPUs | 1 |
|
||||
| `--device_type` | Device type (cuda/cpu) | cuda |
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||
| `--device_type` | Device type | cuda |
|
||||
|
||||
### Other Parameters
|
||||
### Strategy-specific
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
|-----------|-------------|---------------|
|
||||
| `--window_size` | Maximum input sequence length | model config max_len |
|
||||
| `--stride` | Input sequence stride | - |
|
||||
| `--dpo_beta` | DPO beta value | 0.1 |
|
||||
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
|
||||
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
|
||||
| `--grpo_group_size` | GRPO group size | 4 |
|
||||
| `--label_smoothing` | Label smoothing parameter | 0.1 |
|
||||
| `--start_epoch` | Starting epoch | 0 |
|
||||
| `--start_batch` | Starting batch | 0 |
|
||||
| Parameter | Description | Default | Used by |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
|
||||
|
||||
### Usage Example
|
||||
|
||||
```bash
|
||||
python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--n_epoch 3 \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 2000 \
|
||||
--max_grad_norm 1.0 \
|
||||
--ckpt_interval 5000 \
|
||||
--ckpt_dir ./checkpoints \
|
||||
--num_workers 4 \
|
||||
--nprocs 1 \
|
||||
--device_type cuda
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -89,14 +102,14 @@
|
|||
```python
|
||||
import torch
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import Tokenizer
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.inference import InferenceEngine, GenerationRequest
|
||||
|
||||
# Load model using AutoModel
|
||||
model = AutoModel.from_pretrained("your_model_dir")
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = Tokenizer("your_model_dir")
|
||||
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
|
||||
|
||||
# Create engine with separate model and tokenizer
|
||||
engine = InferenceEngine(
|
||||
|
|
|
|||
Loading…
Reference in New Issue