AstrAI/assets/docs/design.md

908 lines
30 KiB
Markdown

## 1. Why I Created This Project
There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible to ordinary developers. I thought: **Can we create a model that is both useful and able to run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that keeps all data fully private while still offering a useful level of intelligence.
Thus, the AstrAI project was born - a 1B-parameter, Chinese-English bilingual model that supports dialogue and text generation, with fully open-source training code!
## 2. System Architecture
```mermaid
classDiagram
namespace config {
class ModelConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+bool tie_weight
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+str attn_type
+str ffn_type
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+str moe_topk_method
+load(config_path) ModelConfig
+save(config_path)
}
class TrainConfig {
+nn.Module model
+str strategy
+Dataset dataset
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_size
+int accumulation_steps
+float max_grad_norm
+int start_epoch
+int start_batch
+str ckpt_dir
+int ckpt_interval
+int random_seed
+int num_workers
+Optional[int] prefetch_factor
+bool pin_memory
+int nprocs
+str backend
+str master_addr
+str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+str device_type
+dict extra_kwargs
+validate()
}
}
namespace dataset {
class BaseDataset {
+int window_size
+int stride
+BaseStorage storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
}
class SEQDataset {
+__getitem__(index) Dict
}
class SFTDataset {
+__getitem__(index) Dict
}
class DPODataset {
+__getitem__(index) Dict
}
class GRPODataset {
+__getitem__(index) Dict
}
class BaseSegmentFetcher {
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
+keys (property)
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
}
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
}
class ResumableDistributedSampler {
+int start_epoch
+int start_iter
}
class DatasetFactory {
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
}
}
namespace serialization {
class Checkpoint {
+dict state_dict
+int epoch
+int iteration
+dict extra
+save(save_dir)
+load(save_dir) Checkpoint
}
}
namespace model {
class AutoModel {
+ModelConfig config
+Registry _registry
+register(model_type) decorator
+get_component_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
}
class Transformer {
+ModelConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict
+load_state_dict(state_dict)
+state_dict()
}
class DecoderBlock {
+nn.Module attention # GQA or MLA via AttnFactory
+RMSNorm input_norm
+nn.Module mlp # MLP or DeepSeekMoE via FFNFactory
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
}
class GQA {
+int n_heads
+int n_kv_heads
+int head_dim
+int n_rep
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, k_proj, v_proj, o_proj
+Linear gate # only if use_gated_attention
+RMSNorm q_norm, k_norm # only if use_qk_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLA {
+int n_heads
+int n_kv_heads
+int head_dim
+int kv_lora_rank
+int qk_nope_head_dim
+int qk_rope_head_dim
+int n_rep
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+Linear gate # only if use_gated_attention
+RMSNorm kv_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLP {
+Linear up, gate, down
+forward(x) Tensor
}
class DeepSeekMoE {
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+str topk_method
+Linear router
+ModuleList shared_experts
+ModuleList routed_experts
+forward(x) Tensor
}
class AttnFactory {
+create(attn_type, **kwargs) nn.Module
}
class FFNFactory {
+create(ffn_type, dim, dim_ffn, **kwargs) nn.Module
}
class RMSNorm {
+Parameter weight
+float norm_eps
+forward(x) Tensor
}
class Linear {
+Parameter weight
+Optional[Parameter] bias # only if bias=True
+forward(x) Tensor
}
class RotaryEmbedding {
+int dim
+int max_len
+float base
+forward(x, position_ids=None) Tensor
}
class Embedding {
+Parameter weight
+forward(x) Tensor
}
}
namespace tokenize {
class AutoTokenizer {
+int vocab_size
+encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
+save_pretrained(save_path)
}
class ChatTemplate {
+String template_str
+render(messages, system_prompt, **extra_variables) str
+from_string(template) ChatTemplate
}
}
namespace factory {
class Registry {
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List[str]
}
class BaseFactory {
+Registry _registry
+register(name, category, priority) decorator
+create(name, *args, **kwargs) T
+list_registered() list
}
}
namespace trainer {
class Trainer {
+TrainConfig train_config
+List[TrainCallback] callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List[TrainCallback]
}
class TrainContext {
+nn.Module model
+BaseStrategy strategy
+DataLoader dataloader
+Optimizer optimizer
+LRScheduler scheduler
+Checkpoint checkpoint
+int epoch
+int iteration
+float loss
+int world_size
+int rank
}
class TrainContextBuilder {
+TrainConfig config
+with_checkpoint(checkpoint) TrainContextBuilder
+build() TrainContext
}
class BaseStrategy {
+nn.Module model
+str device
+compute_loss(batch) Tensor
}
class StrategyFactory {
+Registry _registry
+register(name) decorator
+create(model, train_type, device, **kwargs) BaseStrategy
}
class SEQStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class SFTStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class DPOStrategy {
+nn.Module ref_model
+float beta
+str reduction
+compute_loss(batch) Tensor
}
class GRPOStrategy {
+nn.Module ref_model
+float clip_eps
+float kl_coef
+int group_size
+str reduction
+int sync_interval
+compute_loss(batch) Tensor
}
class BaseScheduler {
+get_lr() List[float]
+step()
}
class SchedulerFactory {
+Registry _registry
+register(name) decorator
+create(optimizer, schedule_type, **kwargs) BaseScheduler
}
class CosineScheduler {
+int warmup_steps
+int lr_decay_steps
+float min_rate
}
class SGDRScheduler {
+int warmup_steps
+int cycle_length
+float min_rate
+int t_mult
}
class TrainCallback {
+on_train_begin(context)
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_step_end(context)
}
class CheckpointCallback {
+str save_dir
+int interval
+_save_checkpoint(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
}
class ProgressBarCallback {
+int num_epoch
+on_epoch_begin(context)
+on_batch_end(context)
+on_epoch_end(context)
}
class MetricLoggerCallback {
+str log_dir
+int save_interval
+on_batch_end(context)
+on_train_end(context)
}
class CallbackFactory {
+Registry _registry
+register(name) decorator
+create(name, **kwargs) TrainCallback
}
}
namespace inference {
class InferenceEngine {
+nn.Module model
+AutoTokenizer tokenizer
+InferenceScheduler scheduler
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict
+shutdown()
}
class Executor {
+AutoModel model
+AutoTokenizer tokenizer
+KVCache page_cache
+execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int]
}
class InferenceScheduler {
+KVCache _page_cache
+Executor _executor
+TaskManager _task_mgr
+bool _running
+Thread _loop_thread
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
+stop()
+get_stats() Dict
}
class Allocator {
+int _free_mask
+List[int] _refs
+OrderedDict _lru
+alloc() int
+free(idx, keep_cached)
+inc_ref(idx)
+touch(idx)
+ref_count(idx) int
}
class PrefixCache {
+int _page_size
+evict(page_idx)
+has_page(idx) bool
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class PagePool {
-Allocator _alloc
-PrefixCache _prefix
+alloc() int
+free(idx)
+inc_ref(idx)
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class Storage {
+int n_layers
+int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache
+Tensor v_cache
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
}
class KVCache {
-PagePool _pool
-Storage _storage
-TaskTable _table
+int page_size
+task_alloc(task_id, prompt_ids) bool
+task_free(task_id)
+task_extend(task_id, pos) bool
+task_cached(task_id) int
+task_record_hashes(task_id, prompt_ids, start_logical_page)
+make_table_tensor(task_ids, device) Tensor
+bind(page_table, total_len) KvcacheView
}
class KvcacheView {
-Storage _storage
+Tensor _page_table
+int _total_len
+write(layer_id, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}
class TaskTable {
+set(task_id, page_table, cached)
+get(task_id) List[int]
+get_cached(task_id) int
+get_ref(task_id) List[int]
+pop(task_id) Tuple[List[int], int]
+table_tensor(task_ids, device) Tensor
}
class Task {
+str task_id
+List prompt_ids
+int max_tokens
+float temperature
+float top_p
+int top_k
+TaskStatus status
+List output_ids
+int input_tokens
+int output_tokens
+float arrival_time
+float finish_time
+Callable stream_callback
+int next_pos
+is_finished(stop_ids) bool
}
class TaskStatus {
<<enumeration>>
PENDING
RUNNING
FINISHED
ABORTED
}
class GenerationRequest {
+List[Dict] messages
+int top_k
+float top_p
+float temperature
+Optional[int] max_tokens
+bool stream
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
}
class GenerateResult {
+List[Tuple[int, str]] tokens
+List[str] results
+List[bool] _done
+append(token, idx)
+get_results() List[str]
+pop_all() List[Tuple[int, str]]
+wait(timeout) bool
+wait_completion(timeout)
}
class ChatMessage {
+str role
+str content
}
class ChatCompletionRequest {
+List[ChatMessage] messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[str] stop
+Optional[int] n
}
class AnthropicMessage {
+str role
+Union[str, List[Dict]] content
}
class MessagesRequest {
+List[AnthropicMessage] messages
+Optional[str] system
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[List[str]] stop_sequences
}
class ProtocolHandler {
<<abstract>>
+build_prompt() str
+create_response_id() str
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+handle() Union[StreamingResponse, Dict]
}
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+List[str] stop_sequences
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
}
class StopChecker {
+check(text) Optional[str]
+trim(text, matched) str
}
class StreamContext {
+str resp_id
+int created
+str model
+int prompt_tokens
+int completion_tokens
+str accumulated
+Optional[str] stop_matched
}
class app {
<<singleton>>
+FastAPI app
}
}
namespace parallel {
class Functions {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
}
class ParallelModel {
+dist.ProcessGroup process_group
+int rank
+int world_size
}
class ColumnParallelLinear {
+forward(x) Tensor
}
class RowParallelLinear {
+forward(x) Tensor
}
}
%% Relationships — UML notation: <|-- generalization, *-- composition, o-- aggregation, --> association, ..> dependency
%% --- Generalization (inheritance) ---
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- Transformer
BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
BaseFactory <|-- FFNFactory
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
%% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
TrainContextBuilder *-- TrainContext
Transformer *-- DecoderBlock
Transformer *-- RotaryEmbedding
Transformer *-- Embedding
DecoderBlock *-- RMSNorm
BaseDataset *-- BaseStorage
%% --- Aggregation (weak ownership) ---
AutoModel o-- ModelConfig
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint
AutoTokenizer o-- ChatTemplate
KvcacheView o-- Storage
BaseFactory o-- Registry
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
SchedulerFactory ..> BaseScheduler : creates
DatasetFactory ..> BaseDataset : creates
CallbackFactory ..> TrainCallback : creates
AttnFactory ..> GQA : creates
AttnFactory ..> MLA : creates
FFNFactory ..> MLP : creates
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
Trainer ..> TrainContextBuilder : uses
Trainer ..> Functions : spawns
TrainContextBuilder ..> StrategyFactory : uses
TrainContextBuilder ..> ResumableDistributedSampler : creates
Checkpoint ..> Checkpoint : serializes
CheckpointCallback ..> Checkpoint : creates
KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates
%% --- Association (general usage) ---
Trainer --> TrainConfig
DPOStrategy --> Transformer
GRPOStrategy --> Transformer
InferenceScheduler --> Task
InferenceScheduler --> TaskStatus
Task --> TaskStatus
InferenceEngine --> Transformer
Executor --> Transformer
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
```
### Module Overview
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5, save_json, load_json, create_storage, detect_format | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample, ChatMessage, ChatCompletionRequest, AnthropicMessage, MessagesRequest, OpenAIHandler, AnthropicHandler, ProtocolHandler, StreamContext, StopChecker, app, run_server | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, only_on_rank, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory[T] | Generic component registration with decorator pattern |
### Design Patterns
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, gradient clipping, metrics) |
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | `handle()` template with stream/non-stream branches, protocol-specific format hooks |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage`, `_STORAGE_REGISTRY` | Format-agnostic data access with registry-dispatch (HDF5 / JSON) |
### Core Relationships
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming generation
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 and JSON loading via `BaseStorage` (`H5Storage` / `JSONStorage`) with `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
## 3. Training Process
The common training process for large language models (LLM) typically includes three stages: **Pre-training (SEQ)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (DPO/GRPO)**. This system is designed to support seamless end-to-end flow, achieving efficient switching and state management of different training stages through modular strategies.
### Core Formulas
**Pre-training (SEQ):**
$$
L_{\text{PT}} = - \sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
**SFT:**
$$
L_{\text{SFT}} = - \sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$
**DPO:**
$$
L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
$$
**GRPO:**
GRPO (Group Relative Policy Optimization) computes advantages from multiple responses to the same prompt, then optimizes using a PPO-style clipped objective:
$$
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
$$
Where $r_i$ is the reward for the $i$-th response, $\mu$ and $\sigma$ are the mean and standard deviation of group rewards.
$$
L_{\text{policy}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right]
$$
The KL divergence term uses a mean-squared-error approximation:
$$
L_{KL} = \lambda \cdot \mathbb{E} \left[ (\log \pi_\theta - \log \pi_{\text{ref}})^2 \right]
$$
The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-05-15