# AstrAI Architecture ## Class Diagram ```mermaid classDiagram namespace config { class BaseConfig { +to_dict() Dict +from_dict(d) Self } class BaseModelConfig { +Optional[str] model_type +from_file(config_path) Self +to_file(config_path) } class AutoRegressiveLMConfig { +int vocab_size +int dim +int n_layers +float norm_eps +int dim_ffn +Optional[bool] tie_weight +Optional[dict] rope_scaling +int max_len +float rope_theta +str attn_type +int n_heads +int n_kv_heads +bool use_qk_norm +bool use_gated_attention +Optional[int] kv_lora_rank +Optional[int] qk_nope_head_dim +Optional[int] qk_rope_head_dim +str ffn_type +int n_routed_experts +int n_shared_experts +int n_activated_experts +Optional[str] topk_method } class EncoderConfig { +int vocab_size +int dim +int n_layers +float norm_eps +int dim_ffn +int max_len +float rope_theta +int n_heads +int n_kv_heads +bool use_qk_norm +bool use_gated_attention +Optional[dict] rope_scaling +Optional[str] pooling_type +Optional[bool] normalize_embeddings } class ConfigFactory { +Registry _registry +register(name) decorator +load(raw) BaseConfig } class TrainConfig { +nn.Module model +str strategy +Dataset dataset +Callable optimizer_fn +Callable scheduler_fn +int n_epoch +int batch_per_device +int grad_accum_steps +float max_grad_norm +list gradient_checkpointing_modules +int start_epoch +int start_batch +str ckpt_dir +int ckpt_interval +str log_dir +int log_interval +List[str] metrics +Optional[LoRAConfig] lora +int random_seed +int num_workers +Optional[int] prefetch_factor +bool pin_memory +int nprocs +str backend +str master_addr +str master_port +str start_method +str device_type +Optional[Dataset] val_dataset +int val_step +str parallel_mode +dict executor_kwargs +dict extra_kwargs +validate() } } namespace dataset { class BaseDataset { +int window_size +int stride +Optional[BaseStorage] storage +load(load_path, storage_type, tokenizer) +__getitem__(index) +__len__() } class SEQDataset { +__getitem__(index) Dict } class SFTDataset { +__getitem__(index) Dict } class DPODataset { +__getitem__(index) Dict } class GRPODataset { +__getitem__(index) Dict } class BaseSegmentFetcher { +List[Tensor] segments +List[int] cum_lengths +int total_length +fetch_data(begin_idx, end_idx) Tensor } class BaseStorage { +MultiSegmentFetcher _fetcher +keys (property) +load(load_path, tokenizer) +fetch(begin, end, keys) +__len__() } class H5Storage { +load(load_path, tokenizer) +fetch(begin, end, keys) Dict +keys() List } class JSONStorage { +load(load_path, tokenizer) +fetch(begin, end, keys) Dict +keys() List } class MultiSegmentFetcher { +Dict multi_fetchers +List multi_keys +key_fetch(begin_idx, end_idx, keys) Dict +fetch_data(begin_idx, end_idx) Dict } class ResumableDistributedSampler { +int epoch +int iter } class StorageFactory { +Registry _registry +register(name) decorator +create(storage_type) BaseStorage } class DatasetFactory { +Registry _registry +register(name) decorator +create(train_type, window_size, stride) BaseDataset +load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset } } namespace serialization { class Checkpoint { +dict state_dict +int epoch +int iteration +dict extra +dict meta +save(save_dir) +load(save_dir) Checkpoint } } namespace model { class AutoModel { +BaseModelConfig config +Registry _registry +register(model_type) decorator +get_component_class(model_type) Type +from_pretrained(path, disable_random_init, strict) nn.Module +save_pretrained(save_directory) +to(*args, **kwargs) Self } class AutoRegressiveLM { +AutoRegressiveLMConfig config +RotaryEmbedding rotary_embedding +Embedding embed_tokens +ModuleList layers +RMSNorm norm +Linear lm_head +forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor] +load_state_dict(state_dict) +state_dict() } class EmbeddingEncoder { +EncoderConfig config +RotaryEmbedding rotary_embedding +Embedding embed_tokens +ModuleList layers +RMSNorm norm +str pooling_type +bool normalize_embeddings +forward(input_ids, input_mask, position_ids) Tensor +load_state_dict(state_dict) } class DecoderBlock { +nn.Module attention # GQA or MLA via AttnFactory +RMSNorm input_norm +nn.Module mlp # MLP or DeepSeekMoE via FFNFactory +RMSNorm post_attention_norm +forward(x, rotary_emb, attention_mask, paged_cache) Tensor } class GQA { +int n_heads +int n_kv_heads +int head_dim +int n_rep +int layer_id +bool use_qk_norm +bool use_gated_attention +Linear q_proj, k_proj, v_proj, o_proj +Linear gate # only if use_gated_attention +RMSNorm q_norm, k_norm # only if use_qk_norm +forward(x, rotary_emb, attn_mask, paged_cache) Tensor } class MLA { +int n_heads +int n_kv_heads +int head_dim +int kv_lora_rank +int qk_nope_head_dim +int qk_rope_head_dim +int n_rep +int layer_id +bool use_qk_norm +bool use_gated_attention +Linear q_proj, kv_a_proj, kv_b_proj +Linear o_proj +Linear gate # only if use_gated_attention +RMSNorm kv_norm +RMSNorm q_norm, k_norm # only if use_qk_norm +forward(x, rotary_emb, attn_mask, paged_cache) Tensor } class MLP { +Linear up, gate, down +forward(x) Tensor } class DeepSeekMoE { +int dim +int n_routed_experts +int n_shared_experts +int n_activated_experts +str topk_method +Linear router +ModuleList shared_experts +ModuleList routed_experts +forward(x) Tensor } class AttnFactory { +create(attn_type, **kwargs) nn.Module } class FFNFactory { +create(ffn_type, dim, dim_ffn, **kwargs) nn.Module } class RMSNorm { +Parameter weight +float norm_eps +tuple normalized_shape +forward(x) Tensor } class Linear { +Parameter weight +Optional[Parameter] bias # only if bias=True +forward(x) Tensor } class RotaryEmbedding { +int dim +int max_len +float base +forward(x, position_ids=None) Tensor } class Embedding { +Parameter weight +forward(x) Tensor } } namespace tokenize { class AutoTokenizer { +vocab_size int +encode(tokens, out_ids, add_special_tokens) List[int] +decode(tokens, skip_special_tokens) str +__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids) +apply_chat_template(messages, tokenize) Union[str, List[int]] +set_chat_template(template) +load(path) +from_pretrained(path) AutoTokenizer +save_pretrained(save_path) } class ChatTemplate { +String template_str +render(messages, system_prompt, **extra_variables) str +from_string(template) ChatTemplate } } namespace factory { class Registry { +Dict _entries +register(name, component_cls, category, priority) +get(name) Type +list_names() List[str] } class BaseFactory { +Registry _registry +register(name, category, priority) decorator +create(name, *args, **kwargs) T +list_registered() list } } namespace trainer { class Trainer { +TrainConfig train_config +List[TrainCallback] callbacks +train(checkpoint) +_get_default_callbacks() List[TrainCallback] } class TrainContext { +nn.Module model +BaseStrategy strategy +DataLoader dataloader +OptimizerProtocol optimizer +SchedulerProtocol scheduler +Checkpoint checkpoint +TrainConfig config +BaseExecutor executor +int epoch +int iteration +float loss +DataLoader val_dataloader +float val_loss +int world_size +int rank +dict kwargs } class TrainContextBuilder { +TrainConfig config +with_checkpoint(checkpoint) TrainContextBuilder +build() TrainContext } class BaseStrategy { +Union[Callable, nn.Module] model +str device +compute_loss(batch) Tensor } class StrategyFactory { +Registry _registry +register(name) decorator +create(train_type, model, device, **kwargs) BaseStrategy } class SEQStrategy { +float label_smoothing +compute_loss(batch) Tensor } class SFTStrategy { +float label_smoothing +compute_loss(batch) Tensor } class DPOStrategy { +nn.Module ref_model +float beta +str reduction +compute_loss(batch) Tensor } class GRPOStrategy { +nn.Module ref_model +float clip_eps +float kl_coef +int group_size +str reduction +int sync_interval +compute_loss(batch) Tensor +sync_ref_model() } class BaseScheduler { +get_lr() List[float] +step() } class SchedulerFactory { +Registry _registry +register(name) decorator +create(optimizer, schedule_type, **kwargs) BaseScheduler } class CosineScheduler { +int warmup_steps +int lr_decay_steps +float min_rate } class SGDRScheduler { +int warmup_steps +int cycle_length +float min_rate +int t_mult } class TrainCallback { <> +on_train_begin(context) +on_train_end(context) +on_epoch_begin(context) +on_epoch_end(context) +on_batch_begin(context) +on_batch_end(context) +on_optimizer_step(context) +on_error(context) } class GradientClippingCallback { +float max_grad_norm +on_optimizer_step(context) } class GradientCheckpointingCallback { +tuple modules +on_train_begin(context) +on_train_end(context) } class CheckpointCallback { +str save_dir +int interval +bool weight_only +Callable state_dict_fn +Callable save_extra_fn +Callable load_extra_fn +_save_checkpoint(context) +on_train_begin(context) +on_batch_end(context) +on_train_end(context) +on_error(context) +save_extra(context)$ +load_extra(extra, context)$ } class ProgressBarCallback { +int num_epoch +int log_interval +IO file +on_epoch_begin(context) +on_batch_end(context) +on_epoch_end(context) } class MetricLoggerCallback { +str log_dir +int save_interval +int log_interval +List[str] metrics +on_batch_end(context) +on_train_end(context) +on_error(context) } class ValidationCallback { +_run_validation(context) +on_optimizer_step(context) } class CallbackFactory { +Registry _registry +register(name) decorator +create(name, **kwargs) TrainCallback } class Muon { +float lr +float momentum +float weight_decay +int ns_steps +step(closure) Optional[float] } } namespace inference { class InferenceEngine { +nn.Module model +AutoTokenizer tokenizer +InferenceScheduler scheduler +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]] +generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator +get_stats() Dict +shutdown() } class Executor { +AutoModel model +AutoTokenizer tokenizer +KVCache page_cache +execute_prefill(tasks, prompt_len, start_pos) +execute_decode(tasks) List[int] } class InferenceScheduler { +KVCache _page_cache +Executor _executor +TaskManager _task_mgr +bool _running +Thread _loop_thread +int max_seq_len +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +remove_task(task_id) +start() +stop() +get_stats() Dict } class Allocator { +int _free_mask +List[int] _refs +OrderedDict _lru +alloc() int +free(idx, keep_cached) +inc_ref(idx) +touch(idx) +ref_count(idx) int } class PrefixCache { +int _page_size +evict(page_idx) +has_page(idx) bool +lookup(token_ids) List[int] +record(page_idx, token_ids, logical_page_idx) } class PagePool { -Allocator _alloc -PrefixCache _prefix +alloc() int +free(idx) +inc_ref(idx) +lookup(token_ids) List[int] +record(page_idx, token_ids, logical_page_idx) } class Storage { +int page_size +Tensor k_cache +Tensor v_cache +write(layer_id, page_table, start_pos, k, v) +gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor] } class KVCache { -PagePool _pool -Storage _storage -TaskTable _table +int page_size +task_alloc(task_id, prompt_ids) bool +task_free(task_id) +task_extend(task_id, pos) bool +task_cached(task_id) int +task_record_hashes(task_id, prompt_ids, start_logical_page) +make_table_tensor(task_ids, device) Tensor +bind(page_table, total_len) KvcacheView } class KvcacheView { -Storage _storage +Tensor _page_table +int _total_len +write(layer_id, k, v) +gather(layer_id) Tuple[Tensor, Tensor] } class TaskTable { +set(task_id, page_table, cached) +get(task_id) List[int] +get_cached(task_id) int +get_ref(task_id) List[int] +pop(task_id) Tuple[List[int], int] +table_tensor(task_ids, device) Tensor } class Task { +str task_id +List prompt_ids +int max_tokens +float temperature +float top_p +int top_k +TaskStatus status +List output_ids +int input_tokens +int output_tokens +float arrival_time +float finish_time +Callable stream_callback +int next_pos +is_finished(stop_ids) bool } class TaskStatus { <> PENDING RUNNING FINISHED ABORTED } class TaskManager { +AutoTokenizer tokenizer +Deque waiting_queue +List active_tasks +add_task(prompt, **kwargs) str +remove_task(task_id) List[Task] +remove_finished_tasks(stop_ids) List[Task] +pull_candidates(n) List[Task] +activate(task) +return_to_waiting(tasks) +get_active_tasks() List[Task] } class GenerationRequest { +List[Dict] messages +int top_k +float top_p +float temperature +Optional[int] max_tokens +bool stream } class BaseSamplingStrategy { <> +apply(logits, filter_value) Tensor } class TemperatureStrategy { +float temperature +apply(logits, filter_value) Tensor } class TopKStrategy { +int top_k +apply(logits, filter_value) Tensor } class TopPStrategy { +float top_p +apply(logits, filter_value) Tensor } class SamplingPipeline { +List[BaseSamplingStrategy] strategies +apply(logits, filter_value) Tensor +sample(logits, filter_value) Tensor } class GenerateResult { +List[Tuple[int, str]] tokens +List[str] results +List[bool] _done +append(token, idx) +get_results() List[str] +pop_all() List[Tuple[int, str]] +wait(timeout) bool +wait_completion(timeout) } class ChatMessage { +str role +str content } class ChatCompletionRequest { +str model +List[ChatMessage] messages +Optional[float] temperature +Optional[float] top_p +Optional[int] top_k +Optional[int] max_tokens +Optional[bool] stream +Optional[Union[str, List[str]]] stop +Optional[int] n +Optional[float] presence_penalty +Optional[float] frequency_penalty +Optional[Dict[int, float]] logit_bias +Optional[str] user } class AnthropicMessage { +str role +Union[str, List[Dict]] content } class MessagesRequest { +str model +List[AnthropicMessage] messages +Optional[str] system +Optional[float] temperature +Optional[float] top_p +Optional[int] top_k +int max_tokens +Optional[bool] stream +Optional[List[str]] stop_sequences } class ResponseBuilder { <> +prepare(request, engine) Tuple[str, GenContext, List[str]] +format_stream_start(ctx) List[str] +format_chunk(token) str +format_stream_end(ctx, stop) List[str] +format_response(ctx, content, stop) Dict } class OpenAIResponseBuilder { +prepare(request, engine) Tuple +format_stream_start(ctx) List[str] +format_chunk(token) str +format_stream_end(ctx, stop) List[str] +format_response(ctx, content, stop) Dict } class AnthropicResponseBuilder { +prepare(request, engine) Tuple +format_stream_start(ctx) List[str] +format_chunk(token) str +format_stream_end(ctx, stop) List[str] +format_response(ctx, content, stop) Dict } class ProtocolHandler { +request +engine +builder: ResponseBuilder +handle() Union[StreamingResponse, Dict] -_handle_stream(agen, ctx, stops) StreamingResponse -_handle_non_stream(agen, ctx, stops) Dict } class StopChecker { +check(text) Optional[str] } class GenContext { +str resp_id +int created +str model +int prompt_tokens +int completion_tokens } class app { <> +FastAPI app } } namespace protocols { class OptimizerProtocol { <> +step(closure) +zero_grad() +state_dict() dict +load_state_dict(d) } class SchedulerProtocol { <> +step() +state_dict() dict +load_state_dict(d) +get_last_lr() } } namespace parallel { class Functions { <> +spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs) +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) +get_current_device() str +get_world_size() int +get_rank() int +only_on_rank(rank, sync) decorator } class GradientState { +int num_steps +sync_gradients (property) bool } class AccumOptimizer { +Optimizer optimizer +GradientState gradient_state +step(closure) +zero_grad() +state_dict() dict +load_state_dict(d) } class AccumScheduler { +LRScheduler scheduler +GradientState gradient_state +step() +state_dict() dict +load_state_dict(d) +get_last_lr() } class BaseExecutor { +GradientState gradient_state +prepare(model, optimizer, dataloader, scheduler) tuple +accumulate(model) context manager +backward(loss) +unwrap_model(model) nn.Module +sync_gradients (property) bool +grad_accum_steps (property) int } class NoneExecutor { } class DDPExecutor { +_prepare_model(model) nn.Module +_no_sync(model) context manager +unwrap_model(model) nn.Module } class FSDPExecutor { +_prepare_model(model) nn.Module +unwrap_model(model) nn.Module } class ExecutorFactory { +Registry _registry +register(name) decorator +create(parallel_mode, **kwargs) BaseExecutor } class ParallelModel { +dist.ProcessGroup process_group +int rank +int world_size } class ColumnParallelLinear { +forward(x) Tensor } class RowParallelLinear { +forward(x) Tensor } } %% Relationships — UML notation: <|-- generalization, *-- composition, o-- aggregation, --> association, ..> dependency %% --- Generalization (inheritance) --- BaseStrategy <|-- SEQStrategy BaseStrategy <|-- SFTStrategy BaseStrategy <|-- DPOStrategy BaseStrategy <|-- GRPOStrategy BaseScheduler <|-- CosineScheduler BaseScheduler <|-- SGDRScheduler TrainCallback <|-- GradientClippingCallback TrainCallback <|-- GradientCheckpointingCallback TrainCallback <|-- CheckpointCallback TrainCallback <|-- ProgressBarCallback TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- ValidationCallback BaseDataset <|-- SEQDataset BaseDataset <|-- SFTDataset BaseDataset <|-- DPODataset BaseDataset <|-- GRPODataset BaseStorage <|-- H5Storage BaseStorage <|-- JSONStorage BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopPStrategy BaseSamplingStrategy <|-- SamplingPipeline ParallelModel <|-- RowParallelLinear ParallelModel <|-- ColumnParallelLinear AutoModel <|-- AutoRegressiveLM AutoModel <|-- EmbeddingEncoder BaseConfig <|-- BaseModelConfig BaseConfig <|-- TrainConfig BaseModelConfig <|-- AutoRegressiveLMConfig BaseModelConfig <|-- EncoderConfig BaseFactory <|-- AutoModel BaseFactory <|-- AttnFactory BaseFactory <|-- FFNFactory BaseFactory <|-- DatasetFactory BaseFactory <|-- StrategyFactory BaseFactory <|-- SchedulerFactory BaseFactory <|-- CallbackFactory BaseFactory <|-- StorageFactory BaseFactory <|-- ExecutorFactory BaseFactory <|-- ConfigFactory BaseExecutor <|-- NoneExecutor BaseExecutor <|-- DDPExecutor BaseExecutor <|-- FSDPExecutor ResponseBuilder <|-- OpenAIResponseBuilder ResponseBuilder <|-- AnthropicResponseBuilder %% --- Composition (strong ownership, part destroyed with whole) --- KVCache *-- PagePool KVCache *-- Storage KVCache *-- TaskTable InferenceEngine *-- InferenceScheduler InferenceScheduler *-- KVCache InferenceScheduler *-- Executor InferenceScheduler *-- TaskManager AutoRegressiveLM *-- DecoderBlock AutoRegressiveLM *-- RotaryEmbedding AutoRegressiveLM *-- Embedding EmbeddingEncoder *-- DecoderBlock EmbeddingEncoder *-- RotaryEmbedding EmbeddingEncoder *-- Embedding DecoderBlock *-- RMSNorm ChatCompletionRequest *-- ChatMessage MessagesRequest *-- AnthropicMessage BaseFactory *-- Registry BaseExecutor *-- GradientState AccumOptimizer o-- GradientState AccumScheduler o-- GradientState %% --- Aggregation (weak ownership) --- AutoModel o-- BaseModelConfig AutoTokenizer o-- ChatTemplate PagePool o-- Allocator PagePool o-- PrefixCache Trainer o-- TrainCallback TrainContext o-- BaseStrategy TrainContext o-- BaseScheduler TrainContext o-- Checkpoint TrainContext o-- BaseExecutor KvcacheView o-- Storage SamplingPipeline o-- BaseSamplingStrategy BaseDataset o-- BaseStorage %% --- Dependency (uses temporarily) --- TrainConfig ..> BaseStrategy : selects StrategyFactory ..> BaseStrategy : creates SchedulerFactory ..> BaseScheduler : creates DatasetFactory ..> BaseDataset : creates CallbackFactory ..> TrainCallback : creates AttnFactory ..> GQA : creates AttnFactory ..> MLA : creates FFNFactory ..> MLP : creates FFNFactory ..> DeepSeekMoE : creates DecoderBlock ..> AttnFactory : uses DecoderBlock ..> FFNFactory : uses StorageFactory ..> H5Storage : creates StorageFactory ..> JSONStorage : creates ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> EncoderConfig : creates ExecutorFactory ..> NoneExecutor : creates ExecutorFactory ..> DDPExecutor : creates ExecutorFactory ..> FSDPExecutor : creates TrainContextBuilder ..> ExecutorFactory : creates Trainer ..> TrainContextBuilder : uses TrainContextBuilder ..> TrainContext : creates Trainer ..> Functions : spawns TrainContextBuilder ..> StrategyFactory : uses TrainContextBuilder ..> ResumableDistributedSampler : creates Checkpoint ..> Checkpoint : serializes CheckpointCallback ..> Checkpoint : creates KVCache ..> KvcacheView : binds InferenceEngine ..> GenerationRequest : uses InferenceEngine ..> GenerateResult : creates OpenAIResponseBuilder ..> ChatCompletionRequest : receives AnthropicResponseBuilder ..> MessagesRequest : receives ProtocolHandler ..> StopChecker : creates ProtocolHandler ..> GenContext : creates %% --- Association (general usage) --- Trainer --> TrainConfig DPOStrategy --> AutoModel GRPOStrategy --> AutoModel InferenceScheduler --> Task InferenceScheduler --> TaskStatus Task --> TaskStatus InferenceEngine --> AutoModel Executor --> AutoModel Executor --> AutoTokenizer TaskManager --> AutoTokenizer MultiSegmentFetcher --> BaseSegmentFetcher ``` ## Module Overview | Module | Components | Description | |--------|------------|-------------| | **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | | **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.serialization** | Checkpoint | Model serialization | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation | | **astrai.factory** | Registry, BaseFactory[T] | Component registration | | **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers | ## Design Patterns | Pattern | Classes | Purpose | |---------|---------|---------| | **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | | **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks | | **Builder** | `TrainContextBuilder` | Chain-building training context | | **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Context** | `TrainContext` | Unified training state bag | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | | **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | ## Core Relationships 1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs` 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers > Document Update Time: 2026-05-24