docs : 三轮深度验证修复文档与代码不一致
- architecture.md: 修正 unwrap_model 返回类型、Config Optional 标注、方法签名错误、类名错误 - training.md: 补充 on_error 回调、修正训练循环顺序、补全策略参数、model.safetensors - inference.md: 修正 GenerationRequest 参数顺序、async 语法、KVCache 描述、temperature 约束 - dataflow.md: 补充 Store.load/fetch 流程、修正可选参数默认值 - README/params: 多 GPU 示例补全 --parallel_mode、文档表补充 preprocessing.md - preprocessing.md: Chat 模式算法补全 BOS token 步骤
This commit is contained in:
parent
31ae2deeba
commit
1c2ff05a6d
|
|
@ -82,6 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
|
--parallel_mode=ddp \
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
|
|
@ -108,8 +109,8 @@ Full reference at [Parameter Guide](assets/docs/params.md).
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/generate.py \
|
python scripts/tools/generate.py \
|
||||||
--param_path /path/to/model \
|
--param_path /path/to/model \
|
||||||
--input_json_file /path/to/input.json \
|
--input_json_file /path/to/input.jsonl \
|
||||||
--output_json_file /path/to/output.json
|
--output_json_file /path/to/output.jsonl
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
@ -224,6 +225,7 @@ Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6y
|
||||||
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
||||||
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
||||||
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
||||||
|
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
|
||||||
|
|
||||||
### Contributing
|
### Contributing
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
|
--parallel_mode=ddp \
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
|
|
@ -114,8 +115,8 @@ nohup python scripts/tools/train.py \
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/generate.py \
|
python scripts/tools/generate.py \
|
||||||
--param_path /path/to/model \
|
--param_path /path/to/model \
|
||||||
--input_json_file /path/to/input.json \
|
--input_json_file /path/to/input.jsonl \
|
||||||
--output_json_file /path/to/output.json
|
--output_json_file /path/to/output.jsonl
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
@ -230,6 +231,7 @@ python scripts/demo/generate_ar.py
|
||||||
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
||||||
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
||||||
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
||||||
|
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
|
||||||
|
|
||||||
### 贡献
|
### 贡献
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,8 @@ classDiagram
|
||||||
class BaseConfig {
|
class BaseConfig {
|
||||||
+to_dict() Dict
|
+to_dict() Dict
|
||||||
+from_dict(d) Self
|
+from_dict(d) Self
|
||||||
|
+from_json(path) Self
|
||||||
|
+to_json(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseModelConfig {
|
class BaseModelConfig {
|
||||||
|
|
@ -17,42 +19,42 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class AutoRegressiveLMConfig {
|
class AutoRegressiveLMConfig {
|
||||||
+int vocab_size
|
+Optional[int] vocab_size
|
||||||
+int dim
|
+Optional[int] dim
|
||||||
+int n_layers
|
+Optional[int] n_layers
|
||||||
+float norm_eps
|
+Optional[float] norm_eps
|
||||||
+int dim_ffn
|
+Optional[int] dim_ffn
|
||||||
+Optional[bool] tie_weight
|
+Optional[bool] tie_weight
|
||||||
+Optional[dict] rope_scaling
|
+Optional[dict] rope_scaling
|
||||||
+int max_len
|
+Optional[int] max_len
|
||||||
+float rope_theta
|
+Optional[float] rope_theta
|
||||||
+str attn_type
|
+str attn_type
|
||||||
+int n_heads
|
+Optional[int] n_heads
|
||||||
+int n_kv_heads
|
+Optional[int] n_kv_heads
|
||||||
+bool use_qk_norm
|
+Optional[bool] use_qk_norm
|
||||||
+bool use_gated_attention
|
+Optional[bool] use_gated_attention
|
||||||
+Optional[int] kv_lora_rank
|
+Optional[int] kv_lora_rank
|
||||||
+Optional[int] qk_nope_head_dim
|
+Optional[int] qk_nope_head_dim
|
||||||
+Optional[int] qk_rope_head_dim
|
+Optional[int] qk_rope_head_dim
|
||||||
+str ffn_type
|
+str ffn_type
|
||||||
+int n_routed_experts
|
+Optional[int] n_routed_experts
|
||||||
+int n_shared_experts
|
+Optional[int] n_shared_experts
|
||||||
+int n_activated_experts
|
+Optional[int] n_activated_experts
|
||||||
+Optional[str] topk_method
|
+Optional[str] topk_method
|
||||||
}
|
}
|
||||||
|
|
||||||
class EncoderConfig {
|
class EncoderConfig {
|
||||||
+int vocab_size
|
+Optional[int] vocab_size
|
||||||
+int dim
|
+Optional[int] dim
|
||||||
+int n_layers
|
+Optional[int] n_layers
|
||||||
+float norm_eps
|
+Optional[float] norm_eps
|
||||||
+int dim_ffn
|
+Optional[int] dim_ffn
|
||||||
+int max_len
|
+Optional[int] max_len
|
||||||
+float rope_theta
|
+Optional[float] rope_theta
|
||||||
+int n_heads
|
+Optional[int] n_heads
|
||||||
+int n_kv_heads
|
+Optional[int] n_kv_heads
|
||||||
+bool use_qk_norm
|
+Optional[bool] use_qk_norm
|
||||||
+bool use_gated_attention
|
+Optional[bool] use_gated_attention
|
||||||
+Optional[dict] rope_scaling
|
+Optional[dict] rope_scaling
|
||||||
+Optional[str] pooling_type
|
+Optional[str] pooling_type
|
||||||
+Optional[bool] normalize_embeddings
|
+Optional[bool] normalize_embeddings
|
||||||
|
|
@ -64,6 +66,38 @@ classDiagram
|
||||||
+load(raw) BaseConfig
|
+load(raw) BaseConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class InputConfig {
|
||||||
|
+str type
|
||||||
|
+str messages_key
|
||||||
|
+str prompt_key
|
||||||
|
+str response_key
|
||||||
|
+str text_key
|
||||||
|
}
|
||||||
|
|
||||||
|
class ProcessingConfig {
|
||||||
|
+int max_seq_len
|
||||||
|
+int min_chars
|
||||||
|
+int max_chars
|
||||||
|
+bool deduplicate
|
||||||
|
+Optional[int] max_items
|
||||||
|
}
|
||||||
|
|
||||||
|
class OutputConfig {
|
||||||
|
+Optional[str] domain_key
|
||||||
|
+str storage_format
|
||||||
|
+int max_tokens_per_shard
|
||||||
|
}
|
||||||
|
|
||||||
|
class PipelineConfig {
|
||||||
|
+int version
|
||||||
|
+InputConfig input
|
||||||
|
+dict mask
|
||||||
|
+str mask_default
|
||||||
|
+ProcessingConfig preprocessing
|
||||||
|
+OutputConfig output
|
||||||
|
+from_dict(d) Self
|
||||||
|
}
|
||||||
|
|
||||||
class TrainConfig {
|
class TrainConfig {
|
||||||
+Callable[[], nn.Module] model_fn
|
+Callable[[], nn.Module] model_fn
|
||||||
+str strategy
|
+str strategy
|
||||||
|
|
@ -312,10 +346,29 @@ classDiagram
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace preprocessing {
|
||||||
|
class BaseMaskBuilder {
|
||||||
|
<<abstract>>
|
||||||
|
+build(item, config, tokenizer) Optional[dict]
|
||||||
|
}
|
||||||
|
|
||||||
|
class ChatMaskBuilder {
|
||||||
|
+build(item, config, tokenizer) Optional[dict]
|
||||||
|
}
|
||||||
|
|
||||||
|
class InstructionMaskBuilder {
|
||||||
|
+build(item, config, tokenizer) Optional[dict]
|
||||||
|
}
|
||||||
|
|
||||||
|
class TextMaskBuilder {
|
||||||
|
+build(item, config, tokenizer) Optional[dict]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+vocab_size int
|
+vocab_size int
|
||||||
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int]
|
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List
|
||||||
+decode(tokens, skip_special_tokens) str
|
+decode(tokens, skip_special_tokens) str
|
||||||
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||||
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
|
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
|
||||||
|
|
@ -346,14 +399,20 @@ classDiagram
|
||||||
+create(name, *args, **kwargs) T
|
+create(name, *args, **kwargs) T
|
||||||
+list_registered() list
|
+list_registered() list
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class MaskBuilderFactory {
|
||||||
|
+Registry _registry
|
||||||
|
+register(name) decorator
|
||||||
|
+create(input_type, config, tokenizer) BaseMaskBuilder
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace trainer {
|
namespace trainer {
|
||||||
class Trainer {
|
class Trainer {
|
||||||
+TrainConfig train_config
|
+TrainConfig train_config
|
||||||
+List[TrainCallback] callbacks
|
+List[TrainCallback] callbacks
|
||||||
+train(checkpoint)
|
+train(resume_dir)
|
||||||
+_get_default_callbacks() List[TrainCallback]
|
-_get_default_callbacks() List[TrainCallback]
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainContext {
|
class TrainContext {
|
||||||
|
|
@ -383,8 +442,12 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseStrategy {
|
class BaseStrategy {
|
||||||
+Union[Callable, nn.Module] model
|
+Callable model
|
||||||
|
+Optional[BaseExecutor] executor
|
||||||
|
+Optional[Callable] model_fn
|
||||||
|
+dict extra_kwargs
|
||||||
+str device
|
+str device
|
||||||
|
+__call__(batch) Tensor
|
||||||
+compute_loss(batch) Tensor
|
+compute_loss(batch) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -425,6 +488,8 @@ classDiagram
|
||||||
class BaseScheduler {
|
class BaseScheduler {
|
||||||
+get_lr() List[float]
|
+get_lr() List[float]
|
||||||
+step()
|
+step()
|
||||||
|
+state_dict() dict
|
||||||
|
+load_state_dict(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
class SchedulerFactory {
|
class SchedulerFactory {
|
||||||
|
|
@ -436,6 +501,7 @@ classDiagram
|
||||||
class CosineScheduler {
|
class CosineScheduler {
|
||||||
+int warmup_steps
|
+int warmup_steps
|
||||||
+int lr_decay_steps
|
+int lr_decay_steps
|
||||||
|
+int total_steps
|
||||||
+float min_rate
|
+float min_rate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -474,11 +540,11 @@ classDiagram
|
||||||
+int interval
|
+int interval
|
||||||
+bool weight_only
|
+bool weight_only
|
||||||
+Callable save_extra_fn
|
+Callable save_extra_fn
|
||||||
+_save_checkpoint(context)
|
-_save_checkpoint(context)
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
+on_error(context)
|
+on_error(context)
|
||||||
+save_extra(context)$
|
+save_extra(context) dict$
|
||||||
}
|
}
|
||||||
|
|
||||||
class ProgressBarCallback {
|
class ProgressBarCallback {
|
||||||
|
|
@ -491,7 +557,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class MetricLoggerCallback {
|
class MetricLoggerCallback {
|
||||||
+str log_dir
|
+Path log_dir
|
||||||
+int save_interval
|
+int save_interval
|
||||||
+int log_interval
|
+int log_interval
|
||||||
+List[str] metrics
|
+List[str] metrics
|
||||||
|
|
@ -501,7 +567,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ValidationCallback {
|
class ValidationCallback {
|
||||||
+_run_validation(context)
|
-_run_validation(context)
|
||||||
+on_optimizer_step(context)
|
+on_optimizer_step(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -517,7 +583,7 @@ classDiagram
|
||||||
+float weight_decay
|
+float weight_decay
|
||||||
+bool nesterov
|
+bool nesterov
|
||||||
+int ns_steps
|
+int ns_steps
|
||||||
+float adamw_lr
|
+Optional[float] adamw_lr
|
||||||
+tuple adamw_betas
|
+tuple adamw_betas
|
||||||
+float adamw_eps
|
+float adamw_eps
|
||||||
+float adamw_wd
|
+float adamw_wd
|
||||||
|
|
@ -634,7 +700,7 @@ classDiagram
|
||||||
class Task {
|
class Task {
|
||||||
+str task_id
|
+str task_id
|
||||||
+List prompt_ids
|
+List prompt_ids
|
||||||
+int max_tokens
|
+Optional[int] max_tokens
|
||||||
+float temperature
|
+float temperature
|
||||||
+float top_p
|
+float top_p
|
||||||
+int top_k
|
+int top_k
|
||||||
|
|
@ -643,8 +709,8 @@ classDiagram
|
||||||
+int input_tokens
|
+int input_tokens
|
||||||
+int output_tokens
|
+int output_tokens
|
||||||
+float arrival_time
|
+float arrival_time
|
||||||
+float finish_time
|
+Optional[float] finish_time
|
||||||
+Callable stream_callback
|
+Optional[Callable] stream_callback
|
||||||
+int next_pos
|
+int next_pos
|
||||||
+is_finished(stop_ids) bool
|
+is_finished(stop_ids) bool
|
||||||
}
|
}
|
||||||
|
|
@ -671,6 +737,11 @@ classDiagram
|
||||||
+activate(task)
|
+activate(task)
|
||||||
+return_to_waiting(tasks)
|
+return_to_waiting(tasks)
|
||||||
+get_active_tasks() List[Task]
|
+get_active_tasks() List[Task]
|
||||||
|
+has_work() bool
|
||||||
|
+wait_for_tasks(timeout)
|
||||||
|
+get_waiting_tasks() List[Task]
|
||||||
|
+clear_queues()
|
||||||
|
+wake()
|
||||||
+get_stats() Dict
|
+get_stats() Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -760,7 +831,7 @@ classDiagram
|
||||||
|
|
||||||
class ResponseBuilder {
|
class ResponseBuilder {
|
||||||
<<abstract>>
|
<<abstract>>
|
||||||
+prepare(request, engine) Tuple[str, GenContext, List[str]]
|
+prepare(request, tokenizer) Tuple[str, GenContext, List[str]]
|
||||||
+format_stream_start(ctx) List[str]
|
+format_stream_start(ctx) List[str]
|
||||||
+format_chunk(token) str
|
+format_chunk(token) str
|
||||||
+format_stream_end(ctx, stop) List[str]
|
+format_stream_end(ctx, stop) List[str]
|
||||||
|
|
@ -768,7 +839,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class OpenAIResponseBuilder {
|
class OpenAIResponseBuilder {
|
||||||
+prepare(request, engine) Tuple
|
+prepare(request, tokenizer) Tuple
|
||||||
+format_stream_start(ctx) List[str]
|
+format_stream_start(ctx) List[str]
|
||||||
+format_chunk(token) str
|
+format_chunk(token) str
|
||||||
+format_stream_end(ctx, stop) List[str]
|
+format_stream_end(ctx, stop) List[str]
|
||||||
|
|
@ -776,7 +847,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class AnthropicResponseBuilder {
|
class AnthropicResponseBuilder {
|
||||||
+prepare(request, engine) Tuple
|
+prepare(request, tokenizer) Tuple
|
||||||
+format_stream_start(ctx) List[str]
|
+format_stream_start(ctx) List[str]
|
||||||
+format_chunk(token) str
|
+format_chunk(token) str
|
||||||
+format_stream_end(ctx, stop) List[str]
|
+format_stream_end(ctx, stop) List[str]
|
||||||
|
|
@ -787,12 +858,13 @@ classDiagram
|
||||||
+request
|
+request
|
||||||
+engine
|
+engine
|
||||||
+builder: ResponseBuilder
|
+builder: ResponseBuilder
|
||||||
+handle() Union[StreamingResponse, Dict]
|
+async handle() Union[StreamingResponse, Dict]
|
||||||
-_handle_stream(agen, ctx, stops) StreamingResponse
|
-_handle_stream(agen, ctx, stop_sequences) StreamingResponse
|
||||||
-_handle_non_stream(agen, ctx, stops) Dict
|
-async _handle_non_stream(agen, ctx, stop_sequences) Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class StopChecker {
|
class StopChecker {
|
||||||
|
+__init__(sequences)
|
||||||
+check(text) Optional[str]
|
+check(text) Optional[str]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -804,6 +876,12 @@ classDiagram
|
||||||
+int completion_tokens
|
+int completion_tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class StopInfo {
|
||||||
|
+Optional[str] matched
|
||||||
|
+str body
|
||||||
|
+str yielded
|
||||||
|
}
|
||||||
|
|
||||||
class app {
|
class app {
|
||||||
<<singleton>>
|
<<singleton>>
|
||||||
+FastAPI app
|
+FastAPI app
|
||||||
|
|
@ -829,14 +907,14 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class Functions {
|
class setup {
|
||||||
<<module>>
|
<<module>>
|
||||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) contextmanager
|
||||||
+get_current_device() str
|
+get_current_device() str
|
||||||
+get_world_size() int
|
+get_world_size() int
|
||||||
+get_rank() int
|
+get_rank() int
|
||||||
+only_on_rank(rank, sync) decorator
|
+only_on_rank(rank, sync=False) decorator
|
||||||
}
|
}
|
||||||
|
|
||||||
class GradientState {
|
class GradientState {
|
||||||
|
|
@ -847,6 +925,7 @@ classDiagram
|
||||||
class AccumOptimizer {
|
class AccumOptimizer {
|
||||||
+Optimizer optimizer
|
+Optimizer optimizer
|
||||||
+GradientState gradient_state
|
+GradientState gradient_state
|
||||||
|
+param_groups (property)
|
||||||
+step(closure)
|
+step(closure)
|
||||||
+zero_grad()
|
+zero_grad()
|
||||||
+state_dict() dict
|
+state_dict() dict
|
||||||
|
|
@ -867,7 +946,7 @@ classDiagram
|
||||||
+prepare(model, optimizer, dataloader, scheduler) tuple
|
+prepare(model, optimizer, dataloader, scheduler) tuple
|
||||||
+accumulate(model) context manager
|
+accumulate(model) context manager
|
||||||
+backward(loss)
|
+backward(loss)
|
||||||
+unwrap_model(model) nn.Module
|
+unwrap_model(model) dict
|
||||||
+sync_gradients (property) bool
|
+sync_gradients (property) bool
|
||||||
+grad_accum_steps (property) int
|
+grad_accum_steps (property) int
|
||||||
}
|
}
|
||||||
|
|
@ -876,14 +955,14 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class DDPExecutor {
|
class DDPExecutor {
|
||||||
+_prepare_model(model) nn.Module
|
-_prepare_model(model) nn.Module
|
||||||
+_no_sync(model) context manager
|
-_no_sync(model) context manager
|
||||||
+unwrap_model(model) nn.Module
|
+unwrap_model(model) dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class FSDPExecutor {
|
class FSDPExecutor {
|
||||||
+_prepare_model(model) nn.Module
|
-_prepare_model(model) nn.Module
|
||||||
+unwrap_model(model) nn.Module
|
+unwrap_model(model) dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class ExecutorFactory {
|
class ExecutorFactory {
|
||||||
|
|
@ -899,11 +978,25 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ColumnParallelLinear {
|
class ColumnParallelLinear {
|
||||||
|
+int in_features
|
||||||
|
+int out_features
|
||||||
|
+int out_features_per_rank
|
||||||
|
+bool gather_results
|
||||||
|
+Parameter weight
|
||||||
|
+Optional[Parameter] bias
|
||||||
+forward(x) Tensor
|
+forward(x) Tensor
|
||||||
|
+load_state_dict(state_dict)
|
||||||
}
|
}
|
||||||
|
|
||||||
class RowParallelLinear {
|
class RowParallelLinear {
|
||||||
|
+int in_features
|
||||||
|
+int out_features
|
||||||
|
+int in_features_per_rank
|
||||||
|
+bool reduce_results
|
||||||
|
+Parameter weight
|
||||||
|
+Optional[Parameter] bias
|
||||||
+forward(x) Tensor
|
+forward(x) Tensor
|
||||||
|
+load_state_dict(state_dict)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -938,6 +1031,10 @@ classDiagram
|
||||||
AutoModel <|-- EmbeddingEncoder
|
AutoModel <|-- EmbeddingEncoder
|
||||||
BaseConfig <|-- BaseModelConfig
|
BaseConfig <|-- BaseModelConfig
|
||||||
BaseConfig <|-- TrainConfig
|
BaseConfig <|-- TrainConfig
|
||||||
|
BaseConfig <|-- InputConfig
|
||||||
|
BaseConfig <|-- ProcessingConfig
|
||||||
|
BaseConfig <|-- OutputConfig
|
||||||
|
BaseConfig <|-- PipelineConfig
|
||||||
BaseModelConfig <|-- AutoRegressiveLMConfig
|
BaseModelConfig <|-- AutoRegressiveLMConfig
|
||||||
BaseModelConfig <|-- EncoderConfig
|
BaseModelConfig <|-- EncoderConfig
|
||||||
BaseFactory <|-- AutoModel
|
BaseFactory <|-- AutoModel
|
||||||
|
|
@ -950,11 +1047,15 @@ classDiagram
|
||||||
BaseFactory <|-- StoreFactory
|
BaseFactory <|-- StoreFactory
|
||||||
BaseFactory <|-- ExecutorFactory
|
BaseFactory <|-- ExecutorFactory
|
||||||
BaseFactory <|-- ConfigFactory
|
BaseFactory <|-- ConfigFactory
|
||||||
|
BaseFactory <|-- MaskBuilderFactory
|
||||||
BaseExecutor <|-- NoneExecutor
|
BaseExecutor <|-- NoneExecutor
|
||||||
BaseExecutor <|-- DDPExecutor
|
BaseExecutor <|-- DDPExecutor
|
||||||
BaseExecutor <|-- FSDPExecutor
|
BaseExecutor <|-- FSDPExecutor
|
||||||
ResponseBuilder <|-- OpenAIResponseBuilder
|
ResponseBuilder <|-- OpenAIResponseBuilder
|
||||||
ResponseBuilder <|-- AnthropicResponseBuilder
|
ResponseBuilder <|-- AnthropicResponseBuilder
|
||||||
|
BaseMaskBuilder <|-- ChatMaskBuilder
|
||||||
|
BaseMaskBuilder <|-- InstructionMaskBuilder
|
||||||
|
BaseMaskBuilder <|-- TextMaskBuilder
|
||||||
|
|
||||||
%% --- Composition (strong ownership, part destroyed with whole) ---
|
%% --- Composition (strong ownership, part destroyed with whole) ---
|
||||||
KVCache *-- PagePool
|
KVCache *-- PagePool
|
||||||
|
|
@ -994,6 +1095,8 @@ classDiagram
|
||||||
|
|
||||||
%% --- Dependency (uses temporarily) ---
|
%% --- Dependency (uses temporarily) ---
|
||||||
TrainConfig ..> BaseStrategy : selects
|
TrainConfig ..> BaseStrategy : selects
|
||||||
|
PipelineConfig ..> MaskBuilderFactory : selects
|
||||||
|
MaskBuilderFactory ..> BaseMaskBuilder : creates
|
||||||
StrategyFactory ..> BaseStrategy : creates
|
StrategyFactory ..> BaseStrategy : creates
|
||||||
SchedulerFactory ..> BaseScheduler : creates
|
SchedulerFactory ..> BaseScheduler : creates
|
||||||
DatasetFactory ..> BaseDataset : creates
|
DatasetFactory ..> BaseDataset : creates
|
||||||
|
|
@ -1046,7 +1149,8 @@ classDiagram
|
||||||
|
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) |
|
||||||
|
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, ChatMaskBuilder, InstructionMaskBuilder, TextMaskBuilder, Pipeline, filter_by_length, dedup_signature | Declarative JSON-driven data preprocessing |
|
||||||
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
|
|
@ -1070,14 +1174,14 @@ classDiagram
|
||||||
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
||||||
| **Context** | `TrainContext` | Unified training state bag |
|
| **Context** | `TrainContext` | Unified training state bag |
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||||
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor`, `FSDPExecutor` | Gradient accumulation & model distribution |
|
||||||
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||||
|
|
||||||
## Core Relationships
|
## Core Relationships
|
||||||
|
|
||||||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
|
1. **Config → Training**: `TrainConfig` holds `model_fn`, `dataset`, `optimizer_fn`, `scheduler_fn`, `parallel_mode`, `executor_kwargs`
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||||
|
|
@ -1089,4 +1193,4 @@ classDiagram
|
||||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
||||||
|
|
||||||
> Document Update Time: 2026-05-28
|
> Document Update Time: 2026-05-30
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ This document describes the data pipeline: from raw text to model input tensors.
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
```
|
```
|
||||||
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference
|
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
|
||||||
```
|
```
|
||||||
|
|
||||||
## Data Preparation
|
## Data Preparation
|
||||||
|
|
@ -33,14 +33,21 @@ H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS pag
|
||||||
## Dataset Architecture
|
## Dataset Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
|
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
|
||||||
→ StoreFactory.create(detect_format(path))
|
→ BaseDataset.load(load_path, storage_type=None)
|
||||||
|
→ detect_format(load_path)
|
||||||
|
→ StoreFactory.create(storage_type)
|
||||||
|
→ Store.load(load_path)
|
||||||
|
→ H5Store._normalize() / MmapStore._normalize()
|
||||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||||
→ BaseDataset.__getitem__(idx)
|
→ BaseDataset.__getitem__(idx)
|
||||||
→ sliding window [begin, end) via get_index(idx)
|
→ get_index(idx) → [begin, end)
|
||||||
|
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
|
||||||
```
|
```
|
||||||
|
|
||||||
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
|
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
|
||||||
|
|
||||||
|
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
|
||||||
|
|
||||||
## Sampler
|
## Sampler
|
||||||
|
|
||||||
|
|
@ -54,4 +61,4 @@ DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
|
||||||
|
|
||||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||||
|
|
||||||
> Document Update Time: 2026-05-28
|
> Document Update Time: 2026-05-30
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ RoPE is applied **before** KV cache write, not after — otherwise position enco
|
||||||
|
|
||||||
## KVCache System
|
## KVCache System
|
||||||
|
|
||||||
Six classes working together:
|
Six classes (plus two helpers) working together:
|
||||||
|
|
||||||
```
|
```
|
||||||
KVCache (facade)
|
KVCache (facade)
|
||||||
|
|
@ -43,7 +43,8 @@ KVCache (facade)
|
||||||
BaseSamplingStrategy (ABC)
|
BaseSamplingStrategy (ABC)
|
||||||
├── TemperatureStrategy
|
├── TemperatureStrategy
|
||||||
├── TopKStrategy
|
├── TopKStrategy
|
||||||
└── TopPStrategy
|
├── TopPStrategy
|
||||||
|
└── SamplingPipeline
|
||||||
```
|
```
|
||||||
|
|
||||||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||||
|
|
@ -73,7 +74,9 @@ Adding a protocol = one builder file, no handler subclassing needed.
|
||||||
InferenceEngine
|
InferenceEngine
|
||||||
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
||||||
├── generate_with_request(req) → same
|
├── generate_with_request(req) → same
|
||||||
└── generate_async(prompt, ...) → AsyncGenerator
|
├── generate_async(prompt, ...) → AsyncGenerator
|
||||||
|
├── get_stats() → Dict
|
||||||
|
└── shutdown()
|
||||||
```
|
```
|
||||||
|
|
||||||
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
||||||
|
|
@ -124,9 +127,9 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
||||||
| Param | Type | Default | Description |
|
| Param | Type | Default | Description |
|
||||||
|-------|------|---------|-------------|
|
|-------|------|---------|-------------|
|
||||||
| `messages` | List[dict] | required | Chat messages (role, content) |
|
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||||
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
|
|
||||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
|
||||||
| `top_k` | int | 50 | Top-k count |
|
| `top_k` | int | 50 | Top-k count |
|
||||||
|
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||||
|
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
|
||||||
| `max_tokens` | Optional[int] | None | Max generation length |
|
| `max_tokens` | Optional[int] | None | Max generation length |
|
||||||
| `stream` | bool | False | Stream output |
|
| `stream` | bool | False | Stream output |
|
||||||
|
|
||||||
|
|
@ -142,7 +145,8 @@ engine.generate("Hello", stream=True) # -> Generator[str]
|
||||||
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||||
|
|
||||||
# Async
|
# Async
|
||||||
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
|
||||||
|
print(token)
|
||||||
```
|
```
|
||||||
|
|
||||||
> Document Update Time: 2026-05-28
|
> Document Update Time: 2026-05-30
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
|
--parallel_mode=ddp \
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
|
|
|
||||||
|
|
@ -147,10 +147,11 @@ For instruction mode, keys are `"prompt"` and `"response"`.
|
||||||
|
|
||||||
For each message in the `messages` array:
|
For each message in the `messages` array:
|
||||||
|
|
||||||
1. Render through the chat template for that single message
|
1. Prepend BOS token (position 0, always masked)
|
||||||
2. Encode the rendered text, record token span `(start, end, role)`
|
2. Render through the chat template for that single message
|
||||||
3. Concatenate all spans -- special tokens from the chat template naturally prevent BPE merging across message boundaries
|
3. Encode the rendered text, record token span `(start, end, role)`
|
||||||
4. Fill `loss_mask` from the mask rules
|
4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries
|
||||||
|
5. Fill `loss_mask` from the mask rules
|
||||||
|
|
||||||
**Multi-turn example**:
|
**Multi-turn example**:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,14 +36,16 @@ Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_
|
||||||
|
|
||||||
```
|
```
|
||||||
on_train_begin
|
on_train_begin
|
||||||
|
model.train()
|
||||||
on_epoch_begin
|
on_epoch_begin
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
on_batch_begin
|
on_batch_begin
|
||||||
with executor.accumulate(model):
|
with executor.accumulate(model):
|
||||||
loss = strategy(batch)
|
loss = strategy.compute_loss(batch)
|
||||||
|
context.loss = loss.item()
|
||||||
stand_loss = loss / executor.grad_accum_steps
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
executor.backward(stand_loss)
|
executor.backward(stand_loss)
|
||||||
iteration += 1
|
context.iteration += 1
|
||||||
on_batch_end
|
on_batch_end
|
||||||
|
|
||||||
if executor.sync_gradients:
|
if executor.sync_gradients:
|
||||||
|
|
@ -61,9 +63,13 @@ on_train_end
|
||||||
| Hook | Fires | Default callback |
|
| Hook | Fires | Default callback |
|
||||||
|------|-------|-----------------|
|
|------|-------|-----------------|
|
||||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||||
|
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
|
||||||
|
| `on_batch_begin` | Every batch | — |
|
||||||
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
||||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
|
||||||
|
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
|
||||||
|
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
|
||||||
|
|
||||||
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||||
|
|
||||||
|
|
@ -77,7 +83,7 @@ $$
|
||||||
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
||||||
$$
|
$$
|
||||||
|
|
||||||
Keys: `input_ids`, `target_ids`
|
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
|
||||||
|
|
||||||
### SFT (Supervised Fine-Tuning)
|
### SFT (Supervised Fine-Tuning)
|
||||||
|
|
||||||
|
|
@ -87,7 +93,7 @@ $$
|
||||||
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
||||||
$$
|
$$
|
||||||
|
|
||||||
Keys: `input_ids`, `target_ids`, `loss_mask`
|
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
|
||||||
|
|
||||||
### DPO (Direct Preference Optimization)
|
### DPO (Direct Preference Optimization)
|
||||||
|
|
||||||
|
|
@ -97,7 +103,7 @@ $$
|
||||||
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
||||||
$$
|
$$
|
||||||
|
|
||||||
Parameters: `beta=0.1`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
||||||
|
|
||||||
### GRPO (Group Relative Policy Optimization)
|
### GRPO (Group Relative Policy Optimization)
|
||||||
|
|
||||||
|
|
@ -111,7 +117,7 @@ $$
|
||||||
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
||||||
$$
|
$$
|
||||||
|
|
||||||
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`.
|
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
|
||||||
|
|
||||||
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
|
|
||||||
|
|
@ -122,7 +128,7 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
||||||
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
||||||
|
|
||||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
||||||
|
|
||||||
## Gradient Checkpointing
|
## Gradient Checkpointing
|
||||||
|
|
||||||
|
|
@ -139,8 +145,8 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
|
||||||
|
|
||||||
```
|
```
|
||||||
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
||||||
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
||||||
└── load(save_dir) broadcasts metadata from rank-0
|
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
|
||||||
```
|
```
|
||||||
|
|
||||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||||
|
|
@ -161,7 +167,7 @@ context = (
|
||||||
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
||||||
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
||||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||||
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
|
||||||
|
|
||||||
## Training CLI
|
## Training CLI
|
||||||
|
|
||||||
|
|
@ -170,6 +176,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
|
--parallel_mode=ddp \
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
|
|
@ -191,4 +198,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-28
|
> Document Update Time: 2026-05-30
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue