AstrAI Data Flow Documentation
This document describes the data flow of the AstrAI project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
Overview
AstrAI adopts a modular design with the following main components:
- Dataset Module (`astrai/dataset/`): Dataset, sampler, storage backends
- Model Module (`astrai/model/`): AutoModel, Transformer model and its submodules
- Training Module (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- Inference Module (`astrai/inference/`): Inference engine with continuous batching and streaming generation
- Config Module (`astrai/config/`): ModelConfig, TrainConfig
- Factory Module (`astrai/factory/`): Registry, BaseFactory for component registration
- Parallel Module (`astrai/parallel/`): Distributed training support
- Serialization (`astrai/serialization.py`): Checkpoint management with safetensors
Data Flow Diagram
```mermaid
flowchart LR
    subgraph A[Data Preparation]
        direction TB
        A1[Raw Text] --> A2[AutoTokenizer]
        A2 --> A3[Tokenized .h5 files]
        A3 --> A4[BaseDataset]
        A4 --> A5[ResumableDistributedSampler]
        A5 --> A6[DataLoader]
    end
    subgraph B[Training]
        direction TB
        B1[DataLoader] --> B2[BaseStrategy]
        B2 --> B3[Transformer Forward]
        B3 --> B4[Loss + Backward]
        B4 --> B5[Gradient Accumulation]
        B5 -->|every accum_steps| B6[Optimizer Step]
        B6 --> B7[LR Scheduler]
        B7 -->|next batch| B2
        B6 --> B8[CheckpointCallback]
    end
    subgraph C[Inference]
        direction TB
        C1[Checkpoint] --> C2[AutoModel]
        C1 --> C3[AutoTokenizer]
        C2 --> C4[InferenceEngine]
        C3 --> C4
        C4 --> C5[InferenceScheduler]
        C5 --> C6[Transformer Forward]
        C6 --> C7[sample]
        C7 --> C8{End?}
        C8 -->|No| C6
        C8 -->|Yes| C9[Generated Text]
    end
    A --> B
    B --> C
```
Detailed Module Descriptions
1. Data Serialization (astrai/dataset/storage.py & astrai/serialization.py)
- `save_h5`: Saves tensors by group as HDF5 files (`.h5`); each key maps to a list of tensors
- `load_h5`: Loads `.h5` files and returns `Dict[str, List[Tensor]]`; supports shared memory
- `Checkpoint`: Encapsulates the model state dict + epoch + iteration; uses safetensors
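A safetensors file is an 8-byte little-endian header length, a JSON header (per-tensor `dtype`, `shape`, and `data_offsets`, plus an optional `__metadata__` map whose keys and values are strings), and then the raw tensor bytes. The stdlib sketch below writes and reads that layout to show where epoch/iteration metadata could live. Storing `epoch` and `iteration` as strings in `__metadata__` is an assumption, not a description of AstrAI's `Checkpoint` internals; real code would use the `safetensors` package.

```python
import json
import os
import struct
import tempfile
from typing import Dict, List, Tuple


def write_safetensors(path: str, tensors: Dict[str, List[float]],
                      metadata: Dict[str, str]) -> None:
    """Write 1-D float32 tensors plus string-only metadata in the safetensors layout."""
    header: Dict[str, object] = {"__metadata__": metadata}  # values must be strings
    payload = b""
    for name, values in tensors.items():
        begin = len(payload)
        payload += struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {"dtype": "F32", "shape": [len(values)],
                        "data_offsets": [begin, len(payload)]}  # relative to data start
    header_bytes = json.dumps(header).encode("utf-8")
    header_bytes += b" " * (-len(header_bytes) % 8)  # space-pad so data is 8-byte aligned
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # u64 little-endian header size
        f.write(header_bytes)
        f.write(payload)


def read_safetensors(path: str) -> Tuple[Dict[str, List[float]], Dict[str, str]]:
    """Return (tensors, metadata) from a file written by write_safetensors."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        payload = f.read()
    metadata = header.pop("__metadata__", {})
    tensors = {}
    for name, info in header.items():
        begin, end = info["data_offsets"]
        tensors[name] = list(struct.unpack(f"<{(end - begin) // 4}f", payload[begin:end]))
    return tensors, metadata


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "ckpt.safetensors")
    write_safetensors(path, {"lm_head.weight": [0.5, -1.0]}, {"epoch": "3", "iteration": "1200"})
    tensors, meta = read_safetensors(path)
    print(tensors, int(meta["epoch"]), int(meta["iteration"]))
```

Because metadata values must be strings, numeric fields such as the iteration count have to be converted back with `int()` on load.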
2. Dataset Module
2.1 Dataset (dataset.py)
- `BaseDataset`: Abstract base class for windowed sequence sampling
- `BaseSegmentFetcher` / `MultiSegmentFetcher`: Fetch tensor segments by index range
- `DatasetFactory`: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen"`/`"rejected"` (DPO), `"prompts"`/`"responses"`/`"masks"`/`"rewards"` (GRPO)
- Storage backends: HDF5 (`.h5`) or JSON (`.json`/`.jsonl`), auto-detected by `detect_format()`
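Because `load_h5()` returns a *list* of tensors per key, a fetcher has to map a global index range onto possibly several segments. A minimal sketch of that idea, assuming the segments form one concatenated index space and using plain Python lists in place of tensors; the class `ConcatSegments` and its `fetch` method are hypothetical, not AstrAI's API:

```python
import bisect
from itertools import accumulate
from typing import List


class ConcatSegments:
    """Hypothetical stand-in for a multi-segment fetcher: one flat index space over many segments."""

    def __init__(self, segments: List[List[int]]):
        self.segments = segments
        # starts[i] = global index of the first element of segment i; starts[-1] = total length
        self.starts = [0] + list(accumulate(len(s) for s in segments))

    def __len__(self) -> int:
        return self.starts[-1]

    def fetch(self, begin: int, end: int) -> List[int]:
        """Return elements [begin, end), even when the range spans several segments."""
        if not 0 <= begin <= end <= len(self):
            raise IndexError(f"range [{begin}, {end}) out of bounds for length {len(self)}")
        out: List[int] = []
        seg = bisect.bisect_right(self.starts, begin) - 1  # segment that contains `begin`
        pos = begin
        while pos < end:
            local = pos - self.starts[seg]
            take = min(end, self.starts[seg + 1]) - pos  # stop at segment end or range end
            out.extend(self.segments[seg][local:local + take])
            pos += take
            seg += 1
        return out
```

For example, with segments `[[0, 1, 2], [3, 4], [5, 6, 7, 8]]`, `fetch(2, 6)` stitches pieces of all three segments into `[2, 3, 4, 5]`. The binary search keeps lookups at O(log #segments) however many files were loaded.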
2.2 Sampler (sampler.py)
- `ResumableDistributedSampler`: Tracks `epoch` and `iter` so interrupted training can resume where it stopped; supports `shuffle` and `drop_last`
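Mid-epoch resume works when the shuffled order is a deterministic function of `(seed, epoch)`: a restarted run regenerates the same order and skips the `iter` indices it already consumed. A stdlib sketch of that idea; the class name and the `seed`/`num_replicas`/`rank` parameters are assumptions modelled on PyTorch's `DistributedSampler`, and here `iter` counts per-rank samples, which may not match AstrAI's exact unit:

```python
import random
from typing import Iterator, List


class ResumableSamplerSketch:
    """Deterministic per-epoch shuffle, sharded by rank, resumable from a sample offset."""

    def __init__(self, dataset_len: int, num_replicas: int = 1, rank: int = 0,
                 shuffle: bool = True, drop_last: bool = False, seed: int = 0):
        self.dataset_len = dataset_len
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed
        self.epoch = 0
        self.iter = 0  # samples already yielded on this rank in the current epoch

    def _indices(self) -> List[int]:
        indices = list(range(self.dataset_len))
        if self.shuffle:
            # Same seed on every rank, so all ranks agree on the global order.
            random.Random(self.seed + self.epoch).shuffle(indices)
        if self.drop_last:
            indices = indices[:len(indices) - len(indices) % self.num_replicas]
        else:
            # Pad by repeating the head (assumes dataset_len >= num_replicas).
            indices += indices[:-len(indices) % self.num_replicas]
        return indices[self.rank::self.num_replicas]

    def __iter__(self) -> Iterator[int]:
        for idx in self._indices()[self.iter:]:  # skip what was consumed before the restart
            self.iter += 1
            yield idx
        self.epoch += 1  # epoch finished: advance and reset the offset
        self.iter = 0

    def state_dict(self) -> dict:
        return {"epoch": self.epoch, "iter": self.iter}

    def load_state_dict(self, state: dict) -> None:
        self.epoch, self.iter = state["epoch"], state["iter"]
```

Saving `state_dict()` with the checkpoint and calling `load_state_dict()` before iterating is enough to continue the same epoch without repeating or skipping samples.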
3. Model Module
3.1 Transformer / AutoModel
- `AutoModel`: Base class with `from_pretrained()` / `save_pretrained()`
- `Transformer`: Decoder-only architecture, registered via `@AutoModel.register('transformer')`
  - Embedding → N × DecoderBlock → RMSNorm → Linear `lm_head`
  - RoPE position encoding, optional weight tying
3.2 Submodules (module.py)
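- `DecoderBlock`: Pre-LN (norm → attention → residual, then norm → MLP → residual); uses `AttnFactory` / `FFNFactory`
- `GQA`: Grouped Query Attention (`n_rep = q_heads / kv_heads`, so each KV head is shared by `n_rep` query heads)
- `MLA`: Multi-head Latent Attention with KV compression (`kv_lora_rank`)
- `MLP`: Gated feed-forward, `down(SiLU(gate(x)) * up(x))`
- `RotaryEmbedding`: RoPE complex cache (`freqs_cis`)
- `RMSNorm`: Root-mean-square normalization (rescales by the RMS of the features, without subtracting the mean)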
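`RotaryEmbedding` caches `freqs_cis`, i.e. complex phases e^(i·pos·θ_k). A pure-Python sketch assuming the common LLaMA-style convention: θ_k = base^(-2k/d) with base 10000, and adjacent features (x[2k], x[2k+1]) treated as one complex number. AstrAI's exact base and pairing may differ. The sketch also illustrates why RoPE is used: after rotation, the query-key dot product depends only on the relative offset between positions.

```python
import cmath
from typing import List


def precompute_freqs_cis(dim: int, max_len: int, base: float = 10000.0) -> List[List[complex]]:
    """freqs_cis[pos][k] = exp(i * pos * base^(-2k/dim)) for k in [0, dim/2)."""
    inv_freq = [base ** (-2 * k / dim) for k in range(dim // 2)]
    return [[cmath.exp(1j * pos * f) for f in inv_freq] for pos in range(max_len)]


def apply_rope(x: List[float], rot: List[complex]) -> List[float]:
    """Rotate each adjacent pair (x[2k], x[2k+1]) by the unit complex number rot[k]."""
    out: List[float] = []
    for k, r in enumerate(rot):
        z = complex(x[2 * k], x[2 * k + 1]) * r
        out.extend([z.real, z.imag])
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


if __name__ == "__main__":
    fc = precompute_freqs_cis(dim=8, max_len=32)
    q = [0.1, -0.4, 0.3, 0.8, -0.2, 0.5, 0.9, -0.7]
    k = [0.6, 0.2, -0.3, 0.1, 0.4, -0.9, 0.05, 0.3]
    # Same offset (3) at different absolute positions gives the same score.
    print(dot(apply_rope(q, fc[5]), apply_rope(k, fc[2])),
          dot(apply_rope(q, fc[20]), apply_rope(k, fc[17])))
```

Since each pair is multiplied by a unit-modulus complex number, the rotation preserves vector norms, and the score for positions m and n reduces to Re(q·conj(k)·e^(i(m−n)θ)).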
4. Training Module
4.1 Training Context (train_context.py)
- `TrainContext`: Dataclass holding the model, optimizer, dataloader, strategy, scheduler, and checkpoint state
- `TrainContextBuilder`: Builder pattern; accepts a checkpoint to resume from and builds all components
4.2 Trainer (trainer.py)
The training loop is nested as epoch → step (one gradient-accumulation window of batches) → batch:
```text
on_train_begin
for each epoch:
    on_epoch_begin
    for each accumulation window of batches:          ← step phase
        on_step_begin
        for each batch in window:                     ← batch phase
            on_batch_begin → strategy(batch) → loss
            (loss / step_batch_nums).backward() → on_batch_end
            iteration += 1
        on_step_end
        optimizer.step() → zero_grad
        scheduler.step()                              ← per step, not per batch
    on_epoch_end
on_train_end
```
Key points:
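- `on_step_*` fires once every `accumulation_steps` batches; `on_step_begin`/`on_step_end` bracket the batch phase, and the optimizer step runs AFTER `on_step_end`
- `on_batch_*` fires every batch, wrapping the loss computation
- `GradientClippingCallback` fires on `on_step_end`, so gradients are clipped before `optimizer.step()`
- The LR scheduler steps inline (there is no `SchedulerCallback` class), once per optimizer step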
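The ordering above can be checked with a framework-free simulation. Plain floats stand in for tensors, a running sum of scaled losses stands in for the accumulated `.grad` buffers, and hooks are recorded as strings. The hook names, `accumulation_steps`, and `step_batch_nums` come from this document; `run_epoch` itself is hypothetical. It shows that the last, shorter window is divided by its actual batch count and that the scheduler steps once per optimizer step.

```python
from typing import Callable, List


def run_epoch(batches: List[float], accumulation_steps: int,
              loss_fn: Callable[[float], float]) -> dict:
    """Simulate one epoch: hooks, gradient accumulation, per-step optimizer/scheduler."""
    events: List[str] = []
    grad = 0.0                 # stand-in for accumulated gradients (sum of scaled losses)
    applied: List[float] = []  # what each optimizer.step() would consume
    iteration = 0
    scheduler_steps = 0

    events.append("on_epoch_begin")
    for start in range(0, len(batches), accumulation_steps):
        window = batches[start:start + accumulation_steps]
        step_batch_nums = len(window)         # the last window may be shorter
        events.append("on_step_begin")
        for batch in window:
            events.append("on_batch_begin")
            loss = loss_fn(batch)
            grad += loss / step_batch_nums    # (loss / step_batch_nums).backward()
            events.append("on_batch_end")
            iteration += 1
        events.append("on_step_end")          # gradient clipping would run here
        applied.append(grad)                  # optimizer.step()
        grad = 0.0                            # zero_grad()
        scheduler_steps += 1                  # scheduler.step(): once per optimizer step
    events.append("on_epoch_end")
    return {"events": events, "applied": applied,
            "iteration": iteration, "scheduler_steps": scheduler_steps}


if __name__ == "__main__":
    result = run_epoch([1.0, 2.0, 3.0, 4.0, 5.0], accumulation_steps=2, loss_fn=lambda b: b)
    print(result["applied"], result["scheduler_steps"])
```

With five batches and `accumulation_steps=2` there are three optimizer steps, and each applied value is the mean loss of its own window (including the single-batch final window), which is exactly what dividing by `step_batch_nums` rather than a fixed constant buys.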
4.3 Strategy (strategy.py)
- `SEQStrategy`: Next-token prediction; cross-entropy with label smoothing
- `SFTStrategy`: Supervised fine-tuning with loss masking
- `DPOStrategy`: Direct Preference Optimization with a reference model
- `GRPOStrategy`: Group Relative Policy Optimization with a clipped probability ratio
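For intuition, here is a scalar sketch of the standard formulas behind the SEQ, SFT, and DPO losses: label-smoothed cross-entropy (PyTorch's convention, where the smoothing mass ε is spread uniformly over all K classes), an SFT-style mean over positions where the loss mask is 1, and the DPO objective −log σ(β[(log π(y_w) − log π_ref(y_w)) − (log π(y_l) − log π_ref(y_l))]). The function names and `beta=0.1` are assumptions; AstrAI's strategies work on batched tensors and may differ in reduction details.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def smoothed_ce(logits: Sequence[float], target: int, eps: float = 0.0) -> float:
    """Cross-entropy against the target distribution (1 - eps) * one_hot + eps / K."""
    logp = log_softmax(logits)
    k = len(logits)
    return -(1 - eps) * logp[target] - (eps / k) * sum(logp)


def masked_mean(losses: Sequence[float], mask: Sequence[int]) -> float:
    """SFT-style reduction: average only the positions where mask == 1."""
    kept = [loss for loss, m in zip(losses, mask) if m]
    return sum(kept) / max(len(kept), 1)


def dpo_loss(pi_chosen: float, pi_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """Inputs are summed sequence log-probs; returns -log(sigmoid(beta * margin))."""
    x = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    # -log(sigmoid(x)) = softplus(-x), written in an overflow-safe form
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))
```

When the policy equals the reference model the DPO margin is zero and the loss is log 2; it falls below log 2 once the policy raises the chosen response's log-probability relative to the rejected one, compared with the reference.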
4.4 Scheduler (schedule.py)
- `CosineScheduler`: Linear warmup followed by cosine decay
- `SGDRScheduler`: Cosine annealing with warm restarts
- Created by `SchedulerFactory` and bound to the optimizer
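Both schedules can be written as pure functions of the optimizer-step index. The sketch assumes linear warmup from 0 to `base_lr` over `warmup_steps` followed by cosine decay to `min_lr` at `total_steps`, and, for SGDR, cycles of length `t0`, `t0·t_mult`, `t0·t_mult²`, …; these parameter names and the exact warmup shape are assumptions rather than the actual arguments of `CosineScheduler`/`SGDRScheduler`. In PyTorch, such a curve is commonly wrapped in `torch.optim.lr_scheduler.LambdaLR` as a multiplier on the base LR.

```python
import math


def cosine_with_warmup(step: int, base_lr: float, warmup_steps: int,
                       total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup, then cosine decay from base_lr to min_lr at total_steps."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def sgdr(step: int, base_lr: float, t0: int, t_mult: int = 1, min_lr: float = 0.0) -> float:
    """Cosine annealing with warm restarts (assumes t0 >= 1 and t_mult >= 1)."""
    cycle_len, t = t0, step
    while t >= cycle_len:  # walk forward to the cycle that contains `step`
        t -= cycle_len
        cycle_len *= t_mult
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t / cycle_len))
```

Because the scheduler is stepped once per optimizer step rather than per batch, `total_steps` here would count optimizer steps, i.e. roughly batches divided by `accumulation_steps`.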
4.5 Callbacks
- `CheckpointCallback`: Saves safetensors every `ckpt_interval` iterations
- `ProgressBarCallback`: tqdm progress display
- `MetricLoggerCallback`: Writes JSONL metrics to `{ckpt_dir}/logs/`
- `GradientClippingCallback`: Calls `clip_grad_norm_` on `on_step_end`
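`clip_grad_norm_` rescales all gradients together when their global L2 norm exceeds `max_norm`: PyTorch computes a coefficient `max_norm / (total_norm + 1e-6)`, clamps it to at most 1, multiplies every gradient by it, and returns the pre-clipping norm. A pure-Python sketch of that rule on lists of floats, for illustration only:

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their global L2 norm is at most ~max_norm.

    Returns the norm measured before clipping, like torch.nn.utils.clip_grad_norm_.
    """
    total_norm = math.sqrt(sum(g * g for param in grads for g in param))
    coef = min(1.0, max_norm / (total_norm + eps))  # never scales gradients up
    for param in grads:
        for i in range(len(param)):
            param[i] *= coef
    return total_norm
```

Clipping the global norm (rather than each parameter separately) preserves the direction of the full update vector, which is why it belongs on `on_step_end`, after all batches of the window have been accumulated.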
5. Inference Module
5.1 Inference Engine (engine.py)
- `InferenceEngine`: Facade over the scheduler; provides `generate()`, `generate_with_request()`, `generate_async()`
- Accepts `prompt: str | List[str]`; returns a generator (streaming) or a string (non-streaming)
5.2 Scheduler 4-Phase Loop (scheduler.py)
A background thread runs these four phases in a continuous loop:
1. Cleanup → Remove finished tasks, free KV cache pages
2. Refill → Pop from waiting_queue, alloc pages, add to active
3. Prefill → Group active tasks by prompt_len, run full forward pass
4. Decode → Pick largest same-position group, run single-token forward
- `Task`: Tracks `prompt_ids`, `output_ids`, and status (`PENDING`/`RUNNING`/`FINISHED`/`ABORTED`)
- `KVCache`: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for the paged KV cache
- `KvcacheView`: Batch view bundling the cache and page table for attention layers
- `sample()`: Temperature → top-k → top-p → multinomial
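The `sample()` pipeline (temperature, then top-k, then top-p, then a multinomial draw) can be sketched for a single logit vector in pure Python. Parameter names, the convention that `top_k=0` and `top_p=1.0` disable a filter, and renormalizing after top-k before applying top-p are assumptions; `temperature` must be positive here, since greedy decoding would be a separate argmax path.

```python
import math
import random
from typing import Optional, Sequence


def sample(logits: Sequence[float], temperature: float = 1.0, top_k: int = 0,
           top_p: float = 1.0, rng: Optional[random.Random] = None) -> int:
    rng = rng or random.Random()
    # 1. Temperature: scale logits, then softmax (max-subtracted for stability).
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # 2. Top-k: keep only the k most likely token ids.
    if top_k > 0:
        order = order[:top_k]
    # 3. Top-p: keep the smallest prefix whose renormalized mass reaches top_p.
    if top_p < 1.0:
        kept_mass = sum(probs[i] for i in order)
        kept, cum = [], 0.0
        for i in order:
            kept.append(i)
            cum += probs[i] / kept_mass
            if cum >= top_p:
                break
        order = kept
    # 4. Multinomial draw; random.choices normalizes the weights itself.
    return rng.choices(order, weights=[probs[i] for i in order], k=1)[0]
```

Passing a seeded `random.Random` makes draws reproducible, which is convenient when testing the decode loop.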
5.3 Server (server.py)
- FastAPI with OpenAI-compatible `/v1/chat/completions` and Anthropic-compatible `/v1/messages` endpoints
- Streaming via SSE; health check at `/health`, stats at `/stats`
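In the OpenAI streaming format, each token is sent as a Server-Sent Events `data:` line carrying a `chat.completion.chunk` JSON object followed by a blank line, and the stream ends with `data: [DONE]`. A stdlib sketch of that framing; the default `model` and `request_id` values, and exactly which fields AstrAI populates, are assumptions. The Anthropic `/v1/messages` stream uses named `event:` lines instead and is not shown.

```python
import json
import time
from typing import Iterable, Iterator, Optional


def openai_sse_stream(tokens: Iterable[str], model: str = "astrai",
                      request_id: str = "chatcmpl-demo") -> Iterator[str]:
    """Yield SSE frames in the OpenAI chat.completion.chunk style."""
    created = int(time.time())

    def frame(delta: dict, finish_reason: Optional[str] = None) -> str:
        chunk = {
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return f"data: {json.dumps(chunk)}\n\n"  # blank line terminates the SSE event

    yield frame({"role": "assistant"})          # first chunk announces the role
    for tok in tokens:
        yield frame({"content": tok})           # one chunk per streamed token
    yield frame({}, finish_reason="stop")
    yield "data: [DONE]\n\n"                    # sentinel that closes the stream
```

In a FastAPI handler, such a generator would typically be returned through a streaming response with media type `text/event-stream`, fed by the engine's token callback.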
6. Tokenizer Module
- `AutoTokenizer`: Wraps HuggingFace `tokenizers.Tokenizer` (not `transformers`); provides `encode` / `decode` / `apply_chat_template`
- `ChatTemplate`: Jinja2-based template rendering for multi-turn chat
7. Factory & Parallel
- `Registry` / `BaseFactory`: Decorator-based component registration
- `spawn_parallel_fn`: Multi-process DDP launcher with the NCCL backend
- `ParallelModel` / `ColumnParallelLinear` / `RowParallelLinear`: Tensor model parallelism
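Decorator-based registration, as in `@AutoModel.register('transformer')`, can be as small as a dictionary from a string key to a class. A minimal sketch; `Registry`, `register`, `create`, `MODELS`, and `TinyTransformer` are illustrative names and not necessarily the signatures used in `astrai/factory/`:

```python
from typing import Callable, Dict, TypeVar

T = TypeVar("T", bound=type)


class Registry:
    """Maps string keys to classes; classes add themselves via a decorator."""

    def __init__(self, name: str):
        self.name = name
        self._entries: Dict[str, type] = {}

    def register(self, key: str) -> Callable[[T], T]:
        def decorator(cls: T) -> T:
            if key in self._entries:
                raise KeyError(f"{self.name}: '{key}' is already registered")
            self._entries[key] = cls
            return cls  # return the class unchanged so it is still usable directly
        return decorator

    def create(self, key: str, *args, **kwargs):
        """Instantiate the class registered under `key` (e.g. from a config string)."""
        try:
            cls = self._entries[key]
        except KeyError:
            raise KeyError(f"{self.name}: unknown key '{key}', "
                           f"available: {sorted(self._entries)}") from None
        return cls(*args, **kwargs)


MODELS = Registry("model")


@MODELS.register("transformer")
class TinyTransformer:
    def __init__(self, n_layers: int = 2):
        self.n_layers = n_layers
```

The design lets a config value such as `"transformer"` or a `train_type` string select a class at runtime without the factory importing every implementation by name; registration happens as a side effect of importing the module that defines the class.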
Training Data Flow — Detailed Steps
- Data Preparation
  - Raw text → token IDs via `AutoTokenizer.encode()`
  - Save as `.h5` files (groups of tensor lists per data key)
- Dataset Loading
  - `BaseDataset.load()` calls `load_h5()` and builds a `MultiSegmentFetcher`
  - A sliding window of `window_size` with `stride` determines sample boundaries
- Sampling & Batching
  - `ResumableDistributedSampler` produces shuffled index sequences
  - `DataLoader` fetches `[batch_size, window_size]` tensors via `__getitem__`
- Strategy Forward
  - The strategy receives a batch and calls `Transformer.forward()` for logits
  - Computes the task-specific loss (cross-entropy, DPO, GRPO)
- Backward & Accumulation
  - `stand_loss = loss / step_batch_nums` (divide by the actual number of batches in this window)
  - `stand_loss.backward()` accumulates gradients
  - Every `accumulation_steps` batches: `optimizer.step()` → `zero_grad()`
  - Every optimizer step: `scheduler.step()` updates the learning rate
- Checkpoint
  - `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors every `ckpt_interval` iterations
  - Does NOT save optimizer/scheduler state by default; `Checkpoint.extra` or `save_extra_fn` can store arbitrary additional data
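The sliding-window rule in the Dataset Loading step fixes how many samples a token stream yields. A sketch assuming each sample spans `window_size + 1` tokens (inputs plus next-token targets shifted by one) and that a trailing partial window is dropped; whether `BaseDataset` counts the extra target token this way is an assumption.

```python
from typing import List, Tuple


def window_bounds(num_tokens: int, window_size: int, stride: int) -> List[Tuple[int, int]]:
    """[begin, end) token ranges; each range holds window_size + 1 tokens."""
    span = window_size + 1  # inputs x[b:b+W] plus targets x[b+1:b+W+1]
    if num_tokens < span:
        return []
    count = (num_tokens - span) // stride + 1  # trailing partial window is dropped
    return [(i * stride, i * stride + span) for i in range(count)]


def make_sample(tokens: List[int], begin: int, end: int) -> Tuple[List[int], List[int]]:
    """Split one window into (input_ids, target_ids), each window_size long."""
    chunk = tokens[begin:end]
    return chunk[:-1], chunk[1:]
```

With `stride == window_size` consecutive samples share only one boundary token; a smaller stride yields overlapping samples (more samples per token, more repetition), and the sample count is `(num_tokens - window_size - 1) // stride + 1`.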
Inference Data Flow — Detailed Steps
- Model Loading
  - `AutoModel.from_pretrained(path)` loads weights from safetensors
  - `torch.inference_mode()` wraps generation
- Prompt Construction
  - Messages → `apply_chat_template(messages, tokenize=False)` → prompt string
  - `tokenizer.encode(prompt)` → token IDs (truncated to `max_prompt_len`)
- Continuous Batching Loop
  - Cleanup: Finished tasks → `stream_callback(STOP)`, free KV pages
  - Refill: Pop from the waiting queue, `PagePool.task_alloc()` for prompt pages
  - Prefill: Group by prompt length, run a full forward pass with `start_pos=0`
  - Decode: Pick the position group with the most tasks and run a single-token forward:
    - Model forward → `logits` → `sample()` → next token ID
    - Append to `output_ids`, update `output_tokens`
    - `PagePool.task_alloc()` allocates pages as needed
    - `stream_callback(token)` for streaming clients
- Output
  - `tokenizer.decode(output_ids)` → text
  - Return to the caller (streaming: token by token; non-streaming: complete string)
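Two bookkeeping parts of the continuous-batching loop can be sketched without a model: page allocation (a task holding `n` tokens needs `ceil(n / page_size)` pages, so decode only allocates when a task crosses a page boundary) and the decode rule of picking the largest group of tasks at the same position. `ToyPagePool`, `page_size`, `task_free`, and `pick_decode_group` are hypothetical names; only the behaviour described in this document is assumed.

```python
from collections import defaultdict
from typing import Dict, List


class ToyPagePool:
    """Free-list page allocator; each task owns a page table of physical page ids."""

    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free: List[int] = list(range(num_pages))
        self.tables: Dict[str, List[int]] = {}  # task id -> page table

    def task_alloc(self, task_id: str, num_tokens: int) -> bool:
        """Grow the task's page table to cover num_tokens; False if the pool is exhausted."""
        table = self.tables.get(task_id, [])
        needed = -(-num_tokens // self.page_size) - len(table)  # ceil division
        if needed > len(self.free):
            return False  # caller keeps the task waiting; nothing is allocated
        for _ in range(max(needed, 0)):
            table.append(self.free.pop())
        self.tables[task_id] = table
        return True

    def task_free(self, task_id: str) -> None:
        """Cleanup phase: return a finished task's pages to the free list."""
        self.free.extend(self.tables.pop(task_id, []))


def pick_decode_group(positions: Dict[str, int]) -> List[str]:
    """Group active tasks by current position and return the largest group."""
    groups: Dict[int, List[str]] = defaultdict(list)
    for task_id, pos in positions.items():
        groups[pos].append(task_id)
    return max(groups.values(), key=len) if groups else []
```

Decoding a same-position group lets one single-token forward share a common `start_pos`; tasks at other positions simply wait for a later iteration of the loop.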
Checkpoint & Serialization
- Training Checkpoint: safetensors weights + epoch/iteration metadata + optional extras. Optimizer/scheduler state is NOT persisted by default but can be stored via `extra`.
- Inference Loading: `AutoModel.from_pretrained()` loads from the same safetensors format.
- Dataset Serialization: HDF5 with shared-memory support for large-scale pre-training data.
Document Update Time: 2026-05-15