35 KiB
35 KiB
AstrAI Architecture
Class Diagram
classDiagram
namespace config {
class BaseConfig {
+to_dict() Dict
+from_dict(d) Self
}
class BaseModelConfig {
+Optional[str] model_type
+from_file(config_path) Self
+to_file(config_path)
}
class AutoRegressiveLMConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+Optional[bool] tie_weight
+Optional[dict] rope_scaling
+int max_len
+float rope_theta
+str attn_type
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+str ffn_type
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+Optional[str] topk_method
}
class EncoderConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[dict] rope_scaling
+Optional[str] pooling_type
+Optional[bool] normalize_embeddings
}
class ConfigFactory {
+Registry _registry
+register(name) decorator
+load(raw) BaseConfig
}
class TrainConfig {
+nn.Module model
+str strategy
+Dataset dataset
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_per_device
+int grad_accum_steps
+float max_grad_norm
+list gradient_checkpointing_modules
+int start_epoch
+int start_batch
+str ckpt_dir
+int ckpt_interval
+str log_dir
+int log_interval
+List[str] metrics
+Optional[LoRAConfig] lora
+int random_seed
+int num_workers
+Optional[int] prefetch_factor
+bool pin_memory
+int nprocs
+str backend
+str master_addr
+str master_port
+str start_method
+str device_type
+Optional[Dataset] val_dataset
+int val_step
+str parallel_mode
+dict executor_kwargs
+dict extra_kwargs
+validate()
}
}
namespace dataset {
class BaseDataset {
+int window_size
+int stride
+Optional[BaseStorage] storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
}
class SEQDataset {
+__getitem__(index) Dict
}
class SFTDataset {
+__getitem__(index) Dict
}
class DPODataset {
+__getitem__(index) Dict
}
class GRPODataset {
+__getitem__(index) Dict
}
class BaseSegmentFetcher {
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
+keys (property)
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
}
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
}
class ResumableDistributedSampler {
+int epoch
+int iter
}
class StorageFactory {
+Registry _registry
+register(name) decorator
+create(storage_type) BaseStorage
}
class DatasetFactory {
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
}
}
namespace serialization {
class Checkpoint {
+dict state_dict
+int epoch
+int iteration
+dict extra
+dict meta
+save(save_dir)
+load(save_dir) Checkpoint
}
}
namespace model {
class AutoModel {
+BaseModelConfig config
+Registry _registry
+register(model_type) decorator
+get_component_class(model_type) Type
+from_pretrained(path, disable_random_init, strict) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
}
class AutoRegressiveLM {
+AutoRegressiveLMConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+load_state_dict(state_dict)
+state_dict()
}
class EmbeddingEncoder {
+EncoderConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+str pooling_type
+bool normalize_embeddings
+forward(input_ids, input_mask, position_ids) Tensor
+load_state_dict(state_dict)
}
class DecoderBlock {
+nn.Module attention # GQA or MLA via AttnFactory
+RMSNorm input_norm
+nn.Module mlp # MLP or DeepSeekMoE via FFNFactory
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
}
class GQA {
+int n_heads
+int n_kv_heads
+int head_dim
+int n_rep
+int layer_id
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, k_proj, v_proj, o_proj
+Linear gate # only if use_gated_attention
+RMSNorm q_norm, k_norm # only if use_qk_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLA {
+int n_heads
+int n_kv_heads
+int head_dim
+int kv_lora_rank
+int qk_nope_head_dim
+int qk_rope_head_dim
+int n_rep
+int layer_id
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+Linear gate # only if use_gated_attention
+RMSNorm kv_norm
+RMSNorm q_norm, k_norm # only if use_qk_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLP {
+Linear up, gate, down
+forward(x) Tensor
}
class DeepSeekMoE {
+int dim
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+str topk_method
+Linear router
+ModuleList shared_experts
+ModuleList routed_experts
+forward(x) Tensor
}
class AttnFactory {
+create(attn_type, **kwargs) nn.Module
}
class FFNFactory {
+create(ffn_type, dim, dim_ffn, **kwargs) nn.Module
}
class RMSNorm {
+Parameter weight
+float norm_eps
+tuple normalized_shape
+forward(x) Tensor
}
class Linear {
+Parameter weight
+Optional[Parameter] bias # only if bias=True
+forward(x) Tensor
}
class RotaryEmbedding {
+int dim
+int max_len
+float base
+forward(x, position_ids=None) Tensor
}
class Embedding {
+Parameter weight
+forward(x) Tensor
}
}
namespace tokenize {
class AutoTokenizer {
+int vocab_size
+encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
+save_pretrained(save_path)
}
class ChatTemplate {
+str template_str
+render(messages, system_prompt, **extra_variables) str
+from_string(template) ChatTemplate
}
}
namespace factory {
class Registry {
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List[str]
}
class BaseFactory {
+Registry _registry
+register(name, category, priority) decorator
+create(name, *args, **kwargs) T
+list_registered() list
}
}
namespace trainer {
class Trainer {
+TrainConfig train_config
+List[TrainCallback] callbacks
+train(checkpoint)
+_get_default_callbacks() List[TrainCallback]
}
class TrainContext {
+nn.Module model
+BaseStrategy strategy
+DataLoader dataloader
+OptimizerProtocol optimizer
+SchedulerProtocol scheduler
+Checkpoint checkpoint
+TrainConfig config
+BaseExecutor executor
+int epoch
+int iteration
+float loss
+DataLoader val_dataloader
+float val_loss
+int world_size
+int rank
+dict kwargs
}
class TrainContextBuilder {
+TrainConfig config
+with_checkpoint(checkpoint) TrainContextBuilder
+build() TrainContext
}
class BaseStrategy {
+Union[Callable, nn.Module] model
+str device
+compute_loss(batch) Tensor
}
class StrategyFactory {
+Registry _registry
+register(name) decorator
+create(train_type, model, device, **kwargs) BaseStrategy
}
class SEQStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class SFTStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class DPOStrategy {
+nn.Module ref_model
+float beta
+str reduction
+compute_loss(batch) Tensor
}
class GRPOStrategy {
+nn.Module ref_model
+float clip_eps
+float kl_coef
+int group_size
+str reduction
+int sync_interval
+compute_loss(batch) Tensor
+sync_ref_model()
}
class BaseScheduler {
+get_lr() List[float]
+step()
}
class SchedulerFactory {
+Registry _registry
+register(name) decorator
+create(optimizer, schedule_type, **kwargs) BaseScheduler
}
class CosineScheduler {
+int warmup_steps
+int lr_decay_steps
+float min_rate
}
class SGDRScheduler {
+int warmup_steps
+int cycle_length
+float min_rate
+int t_mult
}
class TrainCallback {
<<protocol>>
+on_train_begin(context)
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_optimizer_step(context)
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_optimizer_step(context)
}
class GradientCheckpointingCallback {
+tuple modules
+on_train_begin(context)
+on_train_end(context)
}
class CheckpointCallback {
+str save_dir
+int interval
+bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
+save_extra(context)$
+load_extra(extra, context)$
}
class ProgressBarCallback {
+int num_epoch
+int log_interval
+IO file
+on_epoch_begin(context)
+on_batch_end(context)
+on_epoch_end(context)
}
class MetricLoggerCallback {
+str log_dir
+int save_interval
+int log_interval
+List[str] metrics
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
}
class ValidationCallback {
+_run_validation(context)
+on_optimizer_step(context)
}
class CallbackFactory {
+Registry _registry
+register(name) decorator
+create(name, **kwargs) TrainCallback
}
class Muon {
+float lr
+float momentum
+float weight_decay
+int ns_steps
+step(closure) Optional[float]
}
}
namespace inference {
class InferenceEngine {
+nn.Module model
+AutoTokenizer tokenizer
+InferenceScheduler scheduler
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict
+shutdown()
}
class Executor {
+AutoModel model
+AutoTokenizer tokenizer
+KVCache page_cache
+execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int]
}
class InferenceScheduler {
+KVCache _page_cache
+Executor _executor
+TaskManager _task_mgr
+bool _running
+Thread _loop_thread
+int max_seq_len
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
+stop()
+get_stats() Dict
}
class Allocator {
+int _free_mask
+List[int] _refs
+OrderedDict _lru
+alloc() int
+free(idx, keep_cached)
+inc_ref(idx)
+touch(idx)
+ref_count(idx) int
}
class PrefixCache {
+int _page_size
+evict(page_idx)
+has_page(idx) bool
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class PagePool {
-Allocator _alloc
-PrefixCache _prefix
+alloc() int
+free(idx)
+inc_ref(idx)
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class Storage {
+int page_size
+Tensor k_cache
+Tensor v_cache
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
}
class KVCache {
-PagePool _pool
-Storage _storage
-TaskTable _table
+int page_size
+task_alloc(task_id, prompt_ids) bool
+task_free(task_id)
+task_extend(task_id, pos) bool
+task_cached(task_id) int
+task_record_hashes(task_id, prompt_ids, start_logical_page)
+make_table_tensor(task_ids, device) Tensor
+bind(page_table, total_len) KvcacheView
}
class KvcacheView {
-Storage _storage
+Tensor _page_table
+int _total_len
+write(layer_id, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}
class TaskTable {
+set(task_id, page_table, cached)
+get(task_id) List[int]
+get_cached(task_id) int
+get_ref(task_id) List[int]
+pop(task_id) Tuple[List[int], int]
+table_tensor(task_ids, device) Tensor
}
class Task {
+str task_id
+List prompt_ids
+int max_tokens
+float temperature
+float top_p
+int top_k
+TaskStatus status
+List output_ids
+int input_tokens
+int output_tokens
+float arrival_time
+float finish_time
+Callable stream_callback
+int next_pos
+is_finished(stop_ids) bool
}
class TaskStatus {
<<enumeration>>
PENDING
RUNNING
FINISHED
ABORTED
}
class TaskManager {
+AutoTokenizer tokenizer
+Deque waiting_queue
+List active_tasks
+add_task(prompt, **kwargs) str
+remove_task(task_id) List[Task]
+remove_finished_tasks(stop_ids) List[Task]
+pull_candidates(n) List[Task]
+activate(task)
+return_to_waiting(tasks)
+get_active_tasks() List[Task]
}
class GenerationRequest {
+List[Dict] messages
+int top_k
+float top_p
+float temperature
+Optional[int] max_tokens
+bool stream
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List[BaseSamplingStrategy] strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
}
class GenerateResult {
+List[Tuple[int, str]] tokens
+List[str] results
+List[bool] _done
+append(token, idx)
+get_results() List[str]
+pop_all() List[Tuple[int, str]]
+wait(timeout) bool
+wait_completion(timeout)
}
class ChatMessage {
+str role
+str content
}
class ChatCompletionRequest {
+str model
+List[ChatMessage] messages
+Optional[float] temperature
+Optional[float] top_p
+Optional[int] top_k
+Optional[int] max_tokens
+Optional[bool] stream
+Optional[Union[str, List[str]]] stop
+Optional[int] n
+Optional[float] presence_penalty
+Optional[float] frequency_penalty
+Optional[Dict[int, float]] logit_bias
+Optional[str] user
}
class AnthropicMessage {
+str role
+Union[str, List[Dict]] content
}
class MessagesRequest {
+str model
+List[AnthropicMessage] messages
+Optional[str] system
+Optional[float] temperature
+Optional[float] top_p
+Optional[int] top_k
+int max_tokens
+Optional[bool] stream
+Optional[List[str]] stop_sequences
}
class ResponseBuilder {
<<abstract>>
+prepare(request, engine) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class OpenAIResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class AnthropicResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class ProtocolHandler {
+request
+engine
+ResponseBuilder builder
+handle() Union[StreamingResponse, Dict]
-_handle_stream(agen, ctx, stops) StreamingResponse
-_handle_non_stream(agen, ctx, stops) Dict
}
class StopChecker {
+check(text) Optional[str]
}
class GenContext {
+str resp_id
+int created
+str model
+int prompt_tokens
+int completion_tokens
}
class app {
<<singleton>>
+FastAPI app
}
}
namespace protocols {
class OptimizerProtocol {
<<protocol>>
+step(closure)
+zero_grad()
+state_dict() dict
+load_state_dict(d)
}
class SchedulerProtocol {
<<protocol>>
+step()
+state_dict() dict
+load_state_dict(d)
+get_last_lr()
}
}
namespace parallel {
class Functions {
<<module>>
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
+only_on_rank(rank, sync) decorator
}
class GradientState {
+int num_steps
+sync_gradients (property) bool
}
class AccumOptimizer {
+Optimizer optimizer
+GradientState gradient_state
+step(closure)
+zero_grad()
+state_dict() dict
+load_state_dict(d)
}
class AccumScheduler {
+LRScheduler scheduler
+GradientState gradient_state
+step()
+state_dict() dict
+load_state_dict(d)
+get_last_lr()
}
class BaseExecutor {
+GradientState gradient_state
+prepare(model, optimizer, dataloader, scheduler) tuple
+accumulate(model) context manager
+backward(loss)
+unwrap_model(model) nn.Module
+sync_gradients (property) bool
+grad_accum_steps (property) int
}
class NoneExecutor {
}
class DDPExecutor {
+_prepare_model(model) nn.Module
+_no_sync(model) context manager
+unwrap_model(model) nn.Module
}
class FSDPExecutor {
+_prepare_model(model) nn.Module
+unwrap_model(model) nn.Module
}
class ExecutorFactory {
+Registry _registry
+register(name) decorator
+create(parallel_mode, **kwargs) BaseExecutor
}
class ParallelModel {
+dist.ProcessGroup process_group
+int rank
+int world_size
}
class ColumnParallelLinear {
+forward(x) Tensor
}
class RowParallelLinear {
+forward(x) Tensor
}
}
%% Relationships — UML notation: <|-- generalization, *-- composition, o-- aggregation, --> association, ..> dependency
%% --- Generalization (inheritance) ---
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- GradientCheckpointingCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- ValidationCallback
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
BaseSamplingStrategy <|-- SamplingPipeline
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- AutoRegressiveLM
AutoModel <|-- EmbeddingEncoder
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseModelConfig <|-- AutoRegressiveLMConfig
BaseModelConfig <|-- EncoderConfig
BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
BaseFactory <|-- FFNFactory
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory
BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor
BaseExecutor <|-- FSDPExecutor
ResponseBuilder <|-- OpenAIResponseBuilder
ResponseBuilder <|-- AnthropicResponseBuilder
%% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager
AutoRegressiveLM *-- DecoderBlock
AutoRegressiveLM *-- RotaryEmbedding
AutoRegressiveLM *-- Embedding
EmbeddingEncoder *-- DecoderBlock
EmbeddingEncoder *-- RotaryEmbedding
EmbeddingEncoder *-- Embedding
DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
BaseFactory *-- Registry
BaseExecutor *-- GradientState
%% --- Aggregation (weak ownership) ---
AccumOptimizer o-- GradientState
AccumScheduler o-- GradientState
AutoModel o-- BaseModelConfig
AutoTokenizer o-- ChatTemplate
PagePool o-- Allocator
PagePool o-- PrefixCache
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint
TrainContext o-- BaseExecutor
KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
SchedulerFactory ..> BaseScheduler : creates
DatasetFactory ..> BaseDataset : creates
CallbackFactory ..> TrainCallback : creates
AttnFactory ..> GQA : creates
AttnFactory ..> MLA : creates
FFNFactory ..> MLP : creates
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
StorageFactory ..> H5Storage : creates
StorageFactory ..> JSONStorage : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates
ExecutorFactory ..> FSDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : uses
Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates
Trainer ..> Functions : spawns
TrainContextBuilder ..> StrategyFactory : uses
TrainContextBuilder ..> ResumableDistributedSampler : creates
Checkpoint ..> Checkpoint : serializes
CheckpointCallback ..> Checkpoint : creates
KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
AnthropicResponseBuilder ..> MessagesRequest : receives
ProtocolHandler ..> StopChecker : creates
ProtocolHandler ..> GenContext : creates
%% --- Association (general usage) ---
Trainer --> TrainConfig
DPOStrategy --> AutoModel
GRPOStrategy --> AutoModel
InferenceScheduler --> Task
InferenceScheduler --> TaskStatus
Task --> TaskStatus
InferenceEngine --> AutoModel
Executor --> AutoModel
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
Module Overview
| Module | Components | Description |
|---|---|---|
| astrai.config | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| astrai.dataset | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| astrai.serialization | Checkpoint | Model serialization |
| astrai.model | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| astrai.tokenize | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| astrai.trainer | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
| astrai.inference | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, TaskTable, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
| astrai.parallel | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| astrai.factory | Registry, BaseFactory[T] | Component registration |
| astrai.protocols | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
Design Patterns
| Pattern | Classes | Purpose |
|---|---|---|
| Factory | AttnFactory, FFNFactory, StrategyFactory, DatasetFactory, SchedulerFactory, CallbackFactory, StorageFactory, ConfigFactory, ExecutorFactory | Decorator-based component creation |
| Registry | BaseFactory, Registry | Component registration with category/priority |
| Strategy | SEQStrategy, SFTStrategy, DPOStrategy, GRPOStrategy | Training strategy switching |
| Strategy (Sampling) | TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline | Composable logit transformations |
| Strategy (API) | ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder | HTTP API handler with format hooks |
| Builder | TrainContextBuilder | Chain-building training context |
| Observer | TrainCallback, callback implementations | Training process monitoring |
| Context | TrainContext | Unified training state bag |
| Object Pool | Allocator, PagePool | Page-based KV cache with LRU eviction |
| Executor | BaseExecutor, NoneExecutor, DDPExecutor, FSDPExecutor | Gradient accumulation & model distribution |
| Storage | BaseStorage, H5Storage, JSONStorage | Format-agnostic data access |
| Producer-Consumer | InferenceScheduler, Task, queues | Continuous batching |
| AutoModel Registry | AutoModel, AutoRegressiveLM, EmbeddingEncoder | Model-type dynamic loading |
Core Relationships
- Config → Training: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
- Training Flow: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
- Strategy Selection: `StrategyFactory` creates strategy by `train_type`
- Executor Selection: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
- Inference Flow: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
- Distributed: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
- Dataset Loading: `DatasetFactory` creates datasets, `BaseStorage` (`H5Storage` / `JSONStorage`) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
- Checkpoint: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
- Scheduler: `SchedulerFactory` creates `CosineScheduler` / `SGDRScheduler`
- AutoModel: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
- Protocols: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
Document Update Time: 2026-05-24