Compare commits

..

No commits in common. "0a708fff24b32f17a859bc6fcf3ba8f03b644aab" and "65ab69543b4da3afc440a1efd6005bb4cbcfda22" have entirely different histories.

34 changed files with 475 additions and 885 deletions

View File

@ -22,8 +22,7 @@ classDiagram
+int n_layers
+float norm_eps
+int dim_ffn
+Optional[bool] tie_weight
+Optional[dict] rope_scaling
+bool tie_weight
+int max_len
+float rope_theta
+str attn_type
@ -53,7 +52,6 @@ classDiagram
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[dict] rope_scaling
+Optional[str] pooling_type
+Optional[bool] normalize_embeddings
}
@ -82,7 +80,6 @@ classDiagram
+str log_dir
+int log_interval
+List[str] metrics
+Optional[LoRAConfig] lora
+int random_seed
+int num_workers
+Optional[int] prefetch_factor
@ -107,7 +104,7 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+Optional[Store] storage
+Optional[BaseStorage] storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
@ -129,29 +126,38 @@ classDiagram
+__getitem__(index) Dict
}
class Store {
+Dict[str, List[Tensor]] _data
+Dict[str, List[int]] _cum
+int _length
class BaseSegmentFetcher {
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
+keys (property)
+load(path, tokenizer)
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
-_fetch_key(key, begin, end) Tensor
-_normalize(raw)
}
class H5Store {
+load(path, tokenizer)
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class JSONStore {
+load(path, tokenizer)
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MmapStore {
+List _mmap_refs
+load(path, tokenizer)
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
}
class ResumableDistributedSampler {
@ -159,10 +165,10 @@ classDiagram
+int iter
}
class StoreFactory {
class StorageFactory {
+Registry _registry
+register(name) decorator
+create(storage_type) Store
+create(storage_type) BaseStorage
}
class DatasetFactory {
@ -451,15 +457,16 @@ classDiagram
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_optimizer_step(context)
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_optimizer_step(context)
+on_step_begin(context)
}
class GradientCheckpointingCallback {
@ -505,7 +512,7 @@ classDiagram
class ValidationCallback {
+_run_validation(context)
+on_optimizer_step(context)
+on_step_end(context)
}
class CallbackFactory {
@ -740,58 +747,56 @@ classDiagram
+str model
+List[AnthropicMessage] messages
+Optional[str] system
+Optional[float] temperature
+Optional[float] top_p
+Optional[int] top_k
+float temperature
+float top_p
+int top_k
+int max_tokens
+Optional[bool] stream
+bool stream
+Optional[List[str]] stop_sequences
}
class ResponseBuilder {
<<abstract>>
+prepare(request, engine) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class OpenAIResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class AnthropicResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class ProtocolHandler {
<<abstract>>
+request
+engine
+builder: ResponseBuilder
+build_prompt() str
+create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+handle() Union[StreamingResponse, Dict]
-_handle_stream(agen, ctx, stops) StreamingResponse
-_handle_non_stream(agen, ctx, stops) Dict
}
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
}
class StopChecker {
+has_sequences (property) bool
+check(text) Optional[str]
+trim(text, matched) str
}
class GenContext {
class StreamContext {
+str resp_id
+int created
+str model
+int prompt_tokens
+int completion_tokens
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
}
class app {
@ -871,11 +876,6 @@ classDiagram
+unwrap_model(model) nn.Module
}
class FSDPExecutor {
+_prepare_model(model) nn.Module
+unwrap_model(model) nn.Module
}
class ExecutorFactory {
+Registry _registry
+register(name) decorator
@ -911,14 +911,12 @@ classDiagram
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- ValidationCallback
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
Store <|-- H5Store
Store <|-- JSONStore
Store <|-- MmapStore
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
@ -938,19 +936,20 @@ classDiagram
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
BaseFactory <|-- StoreFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory
BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor
BaseExecutor <|-- FSDPExecutor
ResponseBuilder <|-- OpenAIResponseBuilder
ResponseBuilder <|-- AnthropicResponseBuilder
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
%% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
@ -964,6 +963,7 @@ classDiagram
DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState
@ -971,9 +971,6 @@ classDiagram
%% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig
AutoTokenizer o-- ChatTemplate
PagePool o-- Allocator
PagePool o-- PrefixCache
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
@ -981,7 +978,7 @@ classDiagram
TrainContext o-- BaseExecutor
KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- Store
BaseDataset o-- BaseStorage
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects
@ -995,14 +992,12 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
StoreFactory ..> H5Store : creates
StoreFactory ..> JSONStore : creates
StoreFactory ..> MmapStore : creates
StorageFactory ..> H5Storage : creates
StorageFactory ..> JSONStorage : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates
ExecutorFactory ..> FSDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates
@ -1014,10 +1009,10 @@ classDiagram
KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
AnthropicResponseBuilder ..> MessagesRequest : receives
OpenAIHandler ..> ChatCompletionRequest : receives
AnthropicHandler ..> MessagesRequest : receives
ProtocolHandler ..> StopChecker : creates
ProtocolHandler ..> GenContext : creates
ProtocolHandler ..> StreamContext : creates
%% --- Association (general usage) ---
Trainer --> TrainConfig
@ -1030,6 +1025,8 @@ classDiagram
Executor --> AutoModel
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
```
@ -1039,13 +1036,13 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, StoreMmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.dataset** | BaseDataset … GRPODataset, BaseStorage … JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy … GRPOStrategy, StrategyFactory, BaseScheduler … SGDRScheduler, SchedulerFactory, TrainCallback(Protocol) … ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache … KvcacheView, Allocator … Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy … SamplingPipeline, ProtocolHandler … AnthropicHandler, StopChecker, StreamContext, ChatMessage … MessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
@ -1053,17 +1050,17 @@ classDiagram
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
| **Builder** | `TrainContextBuilder` | Chain-building training context |
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1072,10 +1069,10 @@ classDiagram
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed)
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops

View File

@ -5,22 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview
```
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
```
## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
StoreFactory.create("h5") → H5Store
StoreFactory.create("json") → JSONStore
StoreFactory.create("bin") → MmapStore
StorageFactory.create("h5") → H5Storage
StorageFactory.create("json") → JSONStorage
```
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
Both support shared memory via `.share_memory_()`.
## Data Keys by Training Type
@ -34,14 +33,14 @@ H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) us
## Dataset Architecture
```
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
→ StoreFactory.create(detect_format(path))
Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
DatasetFactory.load(train_type, path, window_size, stride)
→ StorageFactory.create(detect_format(path))
MultiSegmentFetcher(BaseSegmentFetcher per key)
→ BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx)
```
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
`window_size` = max input length, `stride` = step between consecutive samples.
## Sampler

View File

@ -46,22 +46,20 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Strategy Pattern)
## Protocol Handlers (Template Method)
```python
class ProtocolHandler: # concrete orchestrator
def handle(self, request):
prompt, ctx, stops = builder.prepare(request, engine)
class ProtocolHandler(ABC):
def handle(self):
ctx = StreamContext(...)
agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops)
else: self._handle_non_stream(agen, ctx, stops)
if stream: self._handle_stream(agen, ctx)
else: self._handle_non_stream(agen, ctx)
```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
`OpenAIResponseBuilder``/v1/chat/completions`, `AnthropicResponseBuilder``/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`.
## Engine & GenerateResult
@ -118,7 +116,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) |
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) |
| `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length |

View File

@ -53,7 +53,7 @@
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
| `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |

View File

@ -82,7 +82,8 @@ on_train_begin
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
scheduler.step()
scheduler.step() # called every iteration
on_epoch_end
on_train_end
```
@ -189,7 +190,7 @@ context = (
```
- Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)`

View File

@ -17,8 +17,8 @@ def required(**kw):
@dataclass
class TrainConfig(BaseConfig):
# basic setting
model_fn: Callable[[], nn.Module] = field(
default=None, metadata=required(help="Model factory for training.")
model: nn.Module = field(
default=None, metadata=required(help="Model for training.")
)
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(

View File

@ -4,17 +4,15 @@ from astrai.dataset.dataset import (
)
from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
H5Store,
JSONStore,
MmapStore,
Store,
StoreFactory,
BaseSegmentFetcher,
BaseStorage,
H5Storage,
JSONStorage,
MultiSegmentFetcher,
StorageFactory,
detect_format,
json_to_bin,
load_bin,
load_h5,
load_json,
save_bin,
save_h5,
save_json,
)
@ -22,18 +20,16 @@ from astrai.dataset.storage import (
__all__ = [
"BaseDataset",
"DatasetFactory",
"Store",
"StoreFactory",
"H5Store",
"JSONStore",
"MmapStore",
"BaseSegmentFetcher",
"MultiSegmentFetcher",
"BaseStorage",
"H5Storage",
"JSONStorage",
"StorageFactory",
"detect_format",
"save_h5",
"load_h5",
"save_json",
"load_json",
"save_bin",
"load_bin",
"json_to_bin",
"ResumableDistributedSampler",
]

View File

@ -8,8 +8,8 @@ from torch import Tensor
from torch.utils.data import Dataset
from astrai.dataset.storage import (
Store,
StoreFactory,
BaseStorage,
StorageFactory,
detect_format,
)
from astrai.factory import BaseFactory
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
super().__init__()
self.window_size = window_size
self.stride = stride
self.storage: Optional[Store] = None
self.storage: Optional[BaseStorage] = None
@property
def required_keys(self) -> List[str]:
@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC):
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = StoreFactory.create(storage_type)
self.storage = StorageFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys()
@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
"""
@classmethod
def _validate_component(cls, dataset_cls: type):
def _validate_component(cls, dataset_cls: type) -> None:
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")

View File

@ -1,20 +1,7 @@
"""Storage backends for different data formats.
Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
a uniform interface for data access and length observation via fetchers.
"""
import bisect
@ -25,7 +12,6 @@ from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import h5py
import numpy as np
import torch
from torch import Tensor
@ -118,38 +104,6 @@ def load_json(
return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
meta = {}
for key, tensors in tensor_group.items():
cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
save_json(meta, os.path.join(file_path, "meta.json"))
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
meta = load_json(os.path.join(file_path, "meta.json"))
segments: Dict[str, List[Tensor]] = {}
for key, info in meta.items():
arr = np.memmap(
os.path.join(file_path, f"{key}.bin"),
dtype=info["dtype"],
mode="r",
shape=tuple(info["shape"]),
)
segments[key] = [torch.from_numpy(arr)]
return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
merged = {}
for key, tensors in segments.items():
merged[key] = [torch.cat(tensors, dim=0)]
save_bin(bin_path, merged)
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
@ -157,7 +111,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path
Returns:
Format string ("h5", "bin", or "json")
Format string ("h5" or "json")
Raises:
FileNotFoundError: If no supported data files are found
@ -174,118 +128,166 @@ def detect_format(load_path: str) -> str:
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files:
return "h5"
bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists():
return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class Store(ABC):
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Each key maps to one or more tensor segments (no forced concatenation).
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
total element count across all keys.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
via ``_normalize()``.
def __init__(self, segments: List[Tensor]):
    """Index ``segments`` by their cumulative lengths.

    Args:
        segments: Tensors that together form one logical sequence, in
            storage order.

    Lengths are measured along dim 0 because ``fetch_data`` slices and
    clamps each segment along dim 0. Counting every element (``numel``)
    would overstate the length of multi-dimensional segments and let
    out-of-range requests pass the bounds check. For 1-D segments the
    two measures are identical.
    """
    self.segments = segments
    self.cum_lengths = []
    running = 0
    for seg in segments:
        running += seg.shape[0]
        # cum_lengths[i] is the exclusive end offset of segment i.
        self.cum_lengths.append(running)
    self.total_length = running
def __len__(self) -> int:
    """Return the total length accumulated over all segments at construction."""
    return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
    """Fetch data in the range [begin_idx, end_idx).

    The range may span several segments; the matching slices are
    concatenated along dim 0.

    Args:
        begin_idx: Starting index (inclusive); must be in
            ``[0, total_length)``.
        end_idx: Ending index (exclusive); must be in
            ``[0, total_length]``.

    Returns:
        A tensor holding the requested elements. An empty range
        (``begin_idx >= end_idx``) yields a zero-length tensor with the
        dtype, device and trailing shape of the stored segments.

    Raises:
        ValueError: If either index is out of bounds.
    """
    total = self.total_length
    if not (0 <= begin_idx < total and 0 <= end_idx <= total):
        raise ValueError("begin_idx or end_idx out of bounds")

    if begin_idx >= end_idx:
        # Match the stored data instead of hard-coding a 1-D long tensor,
        # so callers that stack or cat results get consistent dtypes.
        head = self.segments[0]
        return head.new_empty((0, *head.shape[1:]))

    # First segment whose end offset lies beyond begin_idx, and first
    # segment whose end offset reaches end_idx.
    first = bisect.bisect_right(self.cum_lengths, begin_idx)
    last = bisect.bisect_left(self.cum_lengths, end_idx)

    pieces = []
    for i in range(first, last + 1):
        offset = self.cum_lengths[i - 1] if i > 0 else 0
        seg = self.segments[i]
        lo = max(begin_idx - offset, 0)
        hi = min(end_idx - offset, len(seg))
        pieces.append(seg[lo:hi])

    return torch.cat(pieces, dim=0)
class MultiSegmentFetcher:
    """Manages multiple segment fetchers for different data keys."""

    def __init__(self, multi_segments: Dict):
        """Build one ``BaseSegmentFetcher`` per key.

        Args:
            multi_segments: Mapping from data key to the segments passed
                to that key's ``BaseSegmentFetcher``.
        """
        self.multi_keys = list(multi_segments.keys())
        self.multi_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in multi_segments.items()
        }

    def __len__(self) -> int:
        """Returns the minimum length across all fetchers (0 if there are none)."""
        return min(
            (len(fetcher) for fetcher in self.multi_fetchers.values()),
            default=0,
        )

    def key_fetch(
        self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
    ) -> Dict:
        """Fetch data for specific keys.

        Args:
            begin_idx: Starting index (inclusive).
            end_idx: Ending index (exclusive).
            keys: A single key or a sequence of keys.

        Returns:
            The fetched value itself when exactly one key is requested
            (as a string or a one-element sequence); otherwise a dict
            mapping each key to its fetched value. An empty sequence
            yields an empty dict.

        Raises:
            KeyError: If a requested key has no fetcher.
        """
        key_list = [keys] if isinstance(keys, str) else list(keys)
        fetched = {
            key: self.multi_fetchers[key].fetch_data(begin_idx, end_idx)
            for key in key_list
        }
        # Exactly one requested key unwraps to the bare value (existing
        # contract); zero keys must not index into an empty list.
        if len(key_list) == 1:
            return fetched[key_list[0]]
        return fetched

    def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
        """Fetch all keys."""
        return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
"""
def __init__(self):
self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
self._fetcher: Optional[MultiSegmentFetcher] = None
@abstractmethod
def load(self, path: str, tokenizer=None) -> None:
def load(self, load_path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
raise NotImplementedError
def __len__(self) -> int:
    """Length reported by the internal fetcher, or 0 when no fetcher is set."""
    fetcher = self._fetcher
    return 0 if fetcher is None else len(fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
    """Delegate a ranged read to the internal fetcher's ``key_fetch``.

    Args:
        begin_idx: Starting index (inclusive)
        end_idx: Ending index (exclusive)
        keys: Single key or list of keys to fetch

    Returns:
        Whatever ``key_fetch`` returns for the given range and keys.

    Raises:
        RuntimeError: If no fetcher has been set on this storage.
    """
    fetcher = self._fetcher
    if fetcher is not None:
        return fetcher.key_fetch(begin_idx, end_idx, keys)
    raise RuntimeError("Storage not loaded")
@property
def keys(self) -> List[str]:
return list(self._data.keys())
def __len__(self) -> int:
return self._length
def fetch(
self,
begin: int,
end: int,
keys: Union[str, List[str]],
):
if not self._data:
raise RuntimeError("Store not loaded")
if not (0 <= begin < self._length and 0 <= end <= self._length):
raise ValueError(
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
)
if isinstance(keys, str):
return self._fetch_key(keys, begin, end)
return {k: self._fetch_key(k, begin, end) for k in keys}
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
"""Fetch slice [begin, end) across potentially multiple segments."""
segments = self._data[key]
cum = self._cum[key]
seg_start = bisect.bisect_right(cum, begin)
seg_end = bisect.bisect_left(cum, end)
results = []
for i in range(seg_start, seg_end + 1):
prev = cum[i - 1] if i > 0 else 0
s = max(begin - prev, 0)
e = min(end - prev, segments[i].shape[0])
results.append(segments[i][s:e])
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
"""Register segments and pre-compute cumulative lengths.
Does NOT concatenate segments are kept as-is to avoid OOM on
large datasets. Sets ``self._length`` to the minimum total
element count across all keys.
"""
for key, tensors in raw.items():
self._data[key] = tensors
cum = []
total = 0
for t in tensors:
total += t.shape[0]
cum.append(total)
self._cum[key] = cum
self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
"""Return the data keys available in this storage."""
if self._fetcher is None:
return []
return self._fetcher.multi_keys
class StoreFactory(BaseFactory["Store"]):
"""Factory for creating Store instances by type name.
class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating storage backends by type name.
Example::
@StoreFactory.register("custom")
class CustomStore(Store):
Example:
@StorageFactory.register("custom")
class CustomStorage(BaseStorage):
...
storage = StorageFactory.create("custom")
"""
@classmethod
def _validate_component(cls, store_cls: type):
if not issubclass(store_cls, Store):
raise TypeError(f"{store_cls.__name__} must inherit from Store")
def _validate_component(cls, storage_cls: type) -> None:
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@StoreFactory.register("h5")
class H5Store(Store):
@StorageFactory.register("h5")
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, path: str, tokenizer=None):
self._normalize(load_h5(path))
def load(self, load_path: str, tokenizer=None) -> None:
    """Read ``load_path`` with ``load_h5`` and wrap the result in a fetcher.

    ``tokenizer`` is accepted but not used by this method.
    """
    self._fetcher = MultiSegmentFetcher(load_h5(load_path))
@StoreFactory.register("json")
class JSONStore(Store):
@StorageFactory.register("json")
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
Supports two modes:
@ -294,28 +296,6 @@ class JSONStore(Store):
callable (str -> List[int]) at load time.
"""
def load(self, path: str, tokenizer=None):
self._normalize(load_json(path, tokenizer=tokenizer))
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
No per-process memory duplication all DataLoader workers share the
same OS page-cache pages.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
"""
def load(self, path: str, tokenizer=None):
self._mmap_refs = []
raw = load_bin(path)
self._normalize(raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)

View File

@ -23,7 +23,7 @@ class Registry:
component_cls: Type,
category: Optional[str] = None,
priority: int = 0,
):
) -> None:
"""Register a component class with optional category and priority."""
if name in self._entries:
raise ValueError(f"Component '{name}' is already registered")
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
return component_cls(*args, **kwargs)
@classmethod
def _validate_component(cls, component_cls: Type[T]):
def _validate_component(cls, component_cls: Type[T]) -> None:
"""Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation.

View File

@ -42,7 +42,7 @@ class Allocator:
return idx
return -1
def free(self, idx: int, keep_cached: bool = False):
def free(self, idx: int, keep_cached: bool = False) -> None:
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
@ -51,7 +51,7 @@ class Allocator:
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int):
def inc_ref(self, idx: int) -> None:
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
@ -60,7 +60,7 @@ class Allocator:
with self._lock:
return self._refs[idx]
def touch(self, idx: int):
def touch(self, idx: int) -> None:
with self._lock:
self._lru.move_to_end(idx)
@ -74,7 +74,7 @@ class PrefixCache:
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int):
def evict(self, idx: int) -> None:
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
@ -96,7 +96,9 @@ class PrefixCache:
hits.append(p)
return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
@ -125,13 +127,13 @@ class PagePool:
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int):
def free(self, idx: int) -> None:
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int):
def inc_ref(self, idx: int) -> None:
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
@ -140,7 +142,9 @@ class PagePool:
self._alloc.touch(p)
return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
self._prefix.record(page_idx, token_ids, logical_page_idx)
@ -153,7 +157,7 @@ class TaskTable:
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int):
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
@ -216,7 +220,7 @@ class Storage:
start_pos: int,
k: Tensor,
v: Tensor,
):
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
@ -282,7 +286,7 @@ class KvcacheView:
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor):
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
@ -335,7 +339,7 @@ class KVCache:
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str):
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
@ -355,7 +359,7 @@ class KVCache:
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
):
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):

View File

@ -29,7 +29,9 @@ class Executor:
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
if start_pos >= prompt_len:
return

View File

@ -75,14 +75,14 @@ class InferenceScheduler:
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str):
def remove_task(self, task_id: str) -> None:
for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
def _run_generation_loop(self):
def _run_generation_loop(self) -> None:
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
@ -186,14 +186,14 @@ class InferenceScheduler:
self._task_mgr.clear_queues()
raise
def start(self):
def start(self) -> None:
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self):
def stop(self) -> None:
self._running = False
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):

View File

@ -172,12 +172,12 @@ class TaskManager:
to_add.append(self.waiting_queue.popleft())
return to_add
def activate(self, task: Task):
def activate(self, task: Task) -> None:
task.status = TaskStatus.RUNNING
with self._lock:
self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]):
def return_to_waiting(self, tasks: List[Task]) -> None:
with self._lock:
for task in reversed(tasks):
self.waiting_queue.appendleft(task)
@ -185,7 +185,7 @@ class TaskManager:
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0):
def wait_for_tasks(self, timeout: float = 1.0) -> None:
self._task_event.clear()
self._task_event.wait(timeout=timeout)
@ -197,10 +197,10 @@ class TaskManager:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self):
def clear_queues(self) -> None:
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self):
def wake(self) -> None:
self._task_event.set()

View File

@ -48,7 +48,7 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0):
def wait_completion(self, timeout: float = 300.0) -> None:
with self._cond:
if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout
@ -281,7 +281,7 @@ class InferenceEngine:
def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats()
def shutdown(self):
def shutdown(self) -> None:
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -15,11 +15,7 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod
@contextmanager
def _disable_random_init(enable: bool = True):
if not enable:
yield
return
names = (
init_functions = [
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
@ -29,15 +25,18 @@ def _disable_random_init(enable: bool = True):
"constant_",
"normal_",
"uniform_",
)
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
for n in orig:
setattr(nn.init, n, lambda *a, **kw: None)
]
original_funcs = {}
for name in init_functions:
if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
for n, fn in orig.items():
setattr(nn.init, n, fn)
if enable:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
class AutoModel(BaseFactory["AutoModel"], nn.Module):
@ -83,7 +82,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
def save_pretrained(
self,
save_directory: Union[str, Path],
):
) -> None:
save_model(
config=self.config.to_dict(),
state_dict=self.state_dict(),

View File

@ -68,6 +68,9 @@ class EmbeddingEncoder(AutoModel):
x = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Mapping, Optional
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
@ -136,7 +136,7 @@ class AutoRegressiveLM(AutoModel):
input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None,
) -> Dict[str, Tensor]:
) -> Tensor:
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)

View File

@ -203,45 +203,9 @@ class DDPExecutor(BaseExecutor):
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states=None,
device_mesh=None,
):
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = {
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._fsdp_kwargs = fsdp_kwargs
self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module:

View File

@ -1,9 +1,8 @@
import io
import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict
import safetensors.torch as st
import torch
@ -12,11 +11,11 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank
_META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors"
_MODEL_CONFIG_FILE = "config.json"
def save_safetensors(state_dict: dict, path: str | Path):
def save_safetensors(state_dict: dict, path: str | Path) -> None:
st.save_file(state_dict, str(path))
@ -24,7 +23,7 @@ def load_safetensors(path: str | Path) -> dict:
return st.load_file(str(path))
def save_json(data: dict, path: str | Path):
def save_json(data: dict, path: str | Path) -> None:
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
@ -34,92 +33,13 @@ def load_json(path: str | Path) -> dict:
return json.load(f)
def save_torch(obj: Any, path: str | Path):
def save_torch(obj: Any, path: str | Path) -> None:
torch.save(obj, str(path))
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
def load_torch(path: str | Path) -> Any:
return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
def _get_meta(save_path: Path) -> dict:
meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
return meta
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return load_safetensors(save_path / _WEIGHTS_FILE)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
specs: List[Tuple[str, List[int], str]] = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
else:
state_dict = {}
specs = []
specs_list = [specs]
dist.broadcast_object_list(specs_list, src=0)
specs = specs_list[0]
for key, shape, dtype_name in specs:
dtype = getattr(torch, dtype_name)
if rank != 0:
tensor = torch.empty(shape, dtype=dtype, device="cpu")
else:
tensor = state_dict[key].contiguous().cpu()
dist.broadcast(tensor, src=0)
if rank != 0:
state_dict[key] = tensor
return state_dict
@dataclass
class Checkpoint:
@ -129,7 +49,7 @@ class Checkpoint:
extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str):
def save(self, save_dir: str) -> None:
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
@ -148,16 +68,24 @@ class Checkpoint:
save_torch(value, save_path / f"{key}.pt")
@classmethod
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
def load(cls, save_dir: str) -> "Checkpoint":
save_path = Path(save_dir)
meta = _get_meta(save_path)
state_dict = _load_state_dict(save_path, broadcast=broadcast)
meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
extra = {}
for f in sorted(save_path.iterdir()):
for f in save_path.iterdir():
if f.suffix == ".pt":
extra[f.stem] = load_torch(f, broadcast=broadcast)
extra[f.stem] = load_torch(f)
return cls(
state_dict=state_dict,
@ -165,3 +93,18 @@ class Checkpoint:
iteration=meta.get("iteration", 0),
extra=extra,
)
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _MODEL_CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)

View File

@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

View File

@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
"""
@classmethod
def _validate_component(cls, strategy_cls: type):
def _validate_component(cls, strategy_cls: type) -> None:
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device, get_rank
from astrai.parallel.setup import get_current_device
from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import (
ctx_get_grad_max,
@ -139,27 +139,27 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext):
# All ranks gather state_dict — collective for FSDP, local for DDP
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
self.last_ckpt_iter = context.iteration
if get_rank() == 0:
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint(
state_dict=state_dict,
@ -168,7 +168,13 @@ class CheckpointCallback(TrainCallback):
extra=extra,
meta=context.config.to_dict(),
)
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
@ -190,6 +196,12 @@ class CheckpointCallback(TrainCallback):
extra[name] = obj.state_dict()
return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self
import torch.nn as nn
@ -11,7 +10,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_model_weights
from astrai.serialization import Checkpoint
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -43,10 +42,10 @@ class TrainContextBuilder:
config: TrainConfig,
):
self.config = config
self._resume_dir: Optional[str] = None
self._checkpoint: Optional[Checkpoint] = None
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
self._resume_dir = resume_dir
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
return self
def build(self) -> TrainContext:
@ -59,40 +58,36 @@ class TrainContextBuilder:
**cfg.executor_kwargs,
)
model = cfg.model_fn()
model = model.to(device=device)
context = TrainContext(
model=model,
model=cfg.model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
executor=executor,
)
if self._resume_dir is not None:
resume_path = Path(self._resume_dir)
if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict
context.model = context.model.to(device=device)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
if self._checkpoint.state_dict:
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else:
checkpoint = None
state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False)
if checkpoint is not None:
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
context.iteration = max(checkpoint.iteration, cfg.start_batch)
context.checkpoint = checkpoint
context.checkpoint = Checkpoint(
state_dict=context.model.state_dict(),
)
if cfg.lora is not None:
inject_lora(
model,
context.model,
r=cfg.lora.r,
alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules),
)
context.optimizer = cfg.optimizer_fn(model)
context.optimizer = cfg.optimizer_fn(context.model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
sampler_offset = context.iteration * cfg.batch_per_device
@ -130,21 +125,13 @@ class TrainContextBuilder:
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
model,
context.model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create(
model=context.model,
train_type=cfg.strategy,

View File

@ -3,6 +3,7 @@ from typing import List, Optional
from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import (
CallbackFactory,
TrainCallback,
@ -53,9 +54,9 @@ class Trainer:
if method:
method(context)
def _trainer_loop(self, resume_dir: Optional[str] = None):
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
context = (
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
executor = context.executor
self._call_callbacks("on_train_begin", context)
@ -89,13 +90,13 @@ class Trainer:
self._call_callbacks("on_epoch_end", context)
except Exception as e:
logger.error("Training failed: %s", str(e), exc_info=True)
logger.error(f"Training failed: {str(e)}", exc_info=True)
self._call_callbacks("on_error", context)
raise
finally:
self._call_callbacks("on_train_end", context)
def train(self, resume_dir: Optional[str] = None):
def train(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
@ -105,5 +106,5 @@ class Trainer:
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
resume_dir=resume_dir,
checkpoint=checkpoint,
)

View File

@ -1,279 +0,0 @@
"""MMLU evaluation via log-likelihood ranking."""
import argparse
import csv
import json
import os
import shutil
import urllib.request
import zipfile
import torch
import torch.nn.functional as F
import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
# Zip archive of the master branch of the hendrycks/test GitHub repository.
# download_mmlu() fetches it and keeps the contents of its "test-master/data"
# directory.
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
# Subject identifiers evaluated by default (used when --subjects is not given).
# Each name is used to build the CSV file names "<subject>_dev.csv" and
# "<subject>_<split>.csv" that main() looks up under the data directory.
MMLU_SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def _download_and_extract(url: str, data_dir: str):
    """Download a zip archive from *url* and unpack it into *data_dir*.

    The archive is stored temporarily as ``<data_dir>/mmlu.zip`` and is
    removed again once extraction finishes. The temporary file is also
    removed when the download or the extraction raises, so a partial archive
    is never left behind in the data directory.

    Args:
        url: Location of the zip archive (any scheme ``urllib`` can open).
        data_dir: Target directory; created if it does not exist.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or ``zipfile.ZipFile`` raise,
        e.g. ``urllib.error.URLError`` or ``zipfile.BadZipFile``.
    """
    zip_path = os.path.join(data_dir, "mmlu.zip")
    os.makedirs(data_dir, exist_ok=True)
    print(f"Downloading MMLU data from {url}...")
    try:
        urllib.request.urlretrieve(url, zip_path)
        print("Extracting...")
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(data_dir)
    finally:
        # Clean up on success and on failure alike; the file may not exist
        # at all when the download itself failed.
        if os.path.exists(zip_path):
            os.remove(zip_path)
def download_mmlu(data_dir: str):
    """Download the MMLU archive and flatten its data into *data_dir*.

    After extraction, every entry of ``<data_dir>/test-master/data`` is moved
    directly into *data_dir* and the ``test-master`` directory is deleted.
    If the extracted archive has no ``test-master/data`` directory, nothing
    is moved or deleted.

    Entries that already exist in *data_dir* (for example from an earlier
    download) are replaced. Without this, ``os.rename`` fails when the
    destination is an existing non-empty directory, which made a forced
    re-download crash after the archive had already been fetched.

    Args:
        data_dir: Directory that receives the MMLU data.
    """
    _download_and_extract(MMLU_URL, data_dir)
    extracted_root = os.path.join(data_dir, "test-master")
    src = os.path.join(extracted_root, "data")
    if os.path.exists(src):
        for item in os.listdir(src):
            target = os.path.join(data_dir, item)
            # Clear a stale copy first so the rename below cannot collide.
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target)
            elif os.path.lexists(target):
                os.remove(target)
            os.rename(os.path.join(src, item), target)
        shutil.rmtree(extracted_root)
    print(f"MMLU data saved to {data_dir}")
def _strip_prefix(text: str, prefix: str) -> str:
    """Drop a leading *prefix* from *text* and trim the remainder.

    When *text* does not start with *prefix* it is returned unchanged; no
    whitespace trimming is applied in that case.
    """
    if not text.startswith(prefix):
        return text
    remainder = text[len(prefix):]
    return remainder.strip()
def load_csv(path: str) -> list[dict]:
    """Read an MMLU-style CSV file into a list of question records.

    Each usable row must have at least six columns: the question, the four
    answer options and the answer label. Rows with fewer columns are skipped,
    as is any row whose first cell is the literal header text ``question``
    (case-insensitive). A leading ``"A)"`` / ``"B)"`` / ``"C)"`` / ``"D)"``
    marker is removed from the corresponding option.

    Args:
        path: Path of the UTF-8 encoded CSV file.

    Returns:
        A list of dicts with the keys ``"question"``, ``"A"``, ``"B"``,
        ``"C"``, ``"D"`` and ``"answer"``; all values are stripped strings.
    """
    data = []
    # newline="" hands line-ending handling to the csv module, which is
    # required for quoted fields that contain embedded newlines.
    with open(path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 6:
                continue
            if row[0].strip().lower() == "question":
                continue
            record = {"question": row[0].strip()}
            # Columns 1..4 hold options A..D in order.
            for column, letter in enumerate(("A", "B", "C", "D"), start=1):
                record[letter] = _strip_prefix(row[column].strip(), f"{letter})")
            record["answer"] = row[5].strip()
            data.append(record)
    return data
def build_prompt(
    question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
    """Build the multiple-choice prompt for one question.

    When ``n_shot`` is positive and ``dev_data`` is non-empty, the prompt
    starts with a subject header followed by up to ``n_shot`` solved examples
    taken from the front of ``dev_data``. The target question and its four
    options follow, and the prompt ends with ``"Answer:"`` and no trailing
    space.

    Args:
        question: Text of the question to be answered.
        choices: Mapping that provides the option texts under the keys
            ``"A"``, ``"B"``, ``"C"`` and ``"D"``; extra keys are ignored.
        subject: Subject name inserted verbatim into the header line.
        n_shot: Maximum number of few-shot examples to include.
        dev_data: Example records with the keys ``"question"``, ``"A"`` to
            ``"D"`` and ``"answer"``.

    Returns:
        The complete prompt string.
    """
    letters = ("A", "B", "C", "D")
    # Collect the pieces and join once instead of growing a string with "+=".
    parts = []
    if n_shot > 0 and dev_data:
        parts.append(
            f"The following are multiple choice questions (with answers) about {subject}.\n\n"
        )
        for item in dev_data[:n_shot]:
            parts.append(f"Question: {item['question']}\n")
            parts.extend(f"{k}. {item[k]}\n" for k in letters)
            parts.append(f"Answer: {item['answer']}\n\n")
    parts.append(f"Question: {question}\n")
    parts.extend(f"{k}. {choices[k]}\n" for k in letters)
    parts.append("Answer:")
    return "".join(parts)
def choice_logprob(
    model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
    """Return the summed log-probability of an answer letter after a context.

    The choice is encoded as ``" <letter>"`` (leading space, no special
    tokens) and appended to ``context_ids``. If the combined sequence exceeds
    ``model.config.max_len``, tokens are dropped from the left so that the
    choice tokens at the end are kept. The model is run once and the
    log-softmax value of every choice token is read from the logits of the
    position immediately before it.

    Args:
        model: Callable taking a ``(1, seq)`` long tensor and returning a
            mapping whose ``"logits"`` entry is indexed as
            ``[batch][position][token_id]``; must expose ``config.max_len``.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``
            returning a list of token ids.
        context_ids: Token ids of the prompt.
        choice_letter: Answer label to score, e.g. ``"A"``.
        device: Device on which the input tensor is created.

    Returns:
        Sum of the log-probabilities of the choice tokens. A choice token
        that has no preceding token in the (possibly truncated) input has no
        prediction and contributes nothing.
    """
    choice_text = f" {choice_letter}"
    choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
    input_ids = context_ids + choice_ids
    max_len = model.config.max_len
    if len(input_ids) > max_len:
        # Truncate from the left: the answer tokens must stay in the window.
        overflow = len(input_ids) - max_len
        input_ids = input_ids[overflow:]
        ctx_len = len(input_ids) - len(choice_ids)
    else:
        ctx_len = len(context_ids)
    input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
    with torch.inference_mode():
        logits = model(input_tensor)["logits"][0]
    score = 0.0
    for i, tid in enumerate(choice_ids):
        # Logits at position p predict the token at position p + 1.
        pos = ctx_len - 1 + i
        if pos < 0:
            # No token precedes this one; a negative index would silently
            # read logits from the end of the sequence instead.
            continue
        if pos >= len(logits):
            break
        score += F.log_softmax(logits[pos], dim=-1)[tid].item()
    return score
def evaluate_subject(
    model,
    tokenizer,
    subject: str,
    test_data: list[dict],
    dev_data: list[dict] | None,
    device: str,
    n_shot: int,
) -> tuple[float, int, int]:
    """Score one subject by ranking the four answer letters per question.

    For every record in ``test_data`` a prompt is built with
    ``build_prompt``, encoded with ``tokenizer.encode`` and each of the
    letters ``A`` to ``D`` is scored with ``choice_logprob``. The question
    counts as correct when the highest-scoring letter equals the record's
    ``"answer"`` value.

    Args:
        model: Model forwarded to ``choice_logprob``.
        tokenizer: Tokenizer used for the prompt and forwarded to
            ``choice_logprob``.
        subject: Subject name, used in the prompt and the progress bar label.
        test_data: Question records to evaluate.
        dev_data: Few-shot example records, or ``None`` for none.
        device: Device forwarded to ``choice_logprob``.
        n_shot: Maximum number of few-shot examples per prompt.

    Returns:
        ``(accuracy, correct, total)``. ``accuracy`` is ``0.0`` when
        ``test_data`` is empty, rather than raising ``ZeroDivisionError``.
    """
    correct = 0
    total = 0
    for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
        prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or [])
        context_ids = tokenizer.encode(prompt)
        scores = {
            c: choice_logprob(model, tokenizer, context_ids, c, device)
            for c in ("A", "B", "C", "D")
        }
        if max(scores, key=scores.get) == item["answer"]:
            correct += 1
        total += 1
    # An empty or fully filtered-out CSV yields no questions.
    accuracy = correct / total if total else 0.0
    return accuracy, correct, total
def main():
    """Command-line entry point: evaluate a model on MMLU.

    Parses the command-line options and downloads the data when requested or
    when the data directory is missing. It then loads the model and tokenizer
    from ``--param_path`` and evaluates every requested subject. Per-subject
    and overall accuracy are printed and, when ``--output`` is given, written
    to that path as JSON. Subjects whose CSV file for the chosen split does
    not exist are skipped with a message.
    """
    parser = argparse.ArgumentParser(description="MMLU evaluation")
    parser.add_argument(
        "--param_path", type=str, default="./params", help="Model directory"
    )
    parser.add_argument(
        "--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
    )
    parser.add_argument("--download", action="store_true", help="Download MMLU data")
    parser.add_argument(
        "--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
    )
    parser.add_argument(
        "--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
    )
    parser.add_argument("--output", type=str, help="Output JSON path")
    parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16" if torch.cuda.is_available() else "float32",
        help="Torch dtype",
    )
    args = parser.parse_args()

    if args.download or not os.path.exists(args.data_dir):
        download_mmlu(args.data_dir)

    model = AutoModel.from_pretrained(args.param_path)
    tokenizer = AutoTokenizer.from_pretrained(args.param_path)
    device = args.device
    dtype = getattr(torch, args.dtype)
    model.to(device=device, dtype=dtype)
    # Score with train-only behaviour (e.g. dropout) switched off; harmless
    # if from_pretrained already returned the module in eval mode.
    model.eval()

    subjects = args.subjects or MMLU_SUBJECTS
    results = {}
    total_correct = 0
    total_questions = 0
    for subject in subjects:
        # Few-shot examples always come from the "dev" split.
        dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
        test_path = os.path.join(
            args.data_dir, args.split, f"{subject}_{args.split}.csv"
        )
        if not os.path.exists(test_path):
            print(f" Skipping {subject}: test file not found")
            continue
        dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
        test_data = load_csv(test_path)
        acc, corr, tot = evaluate_subject(
            model, tokenizer, subject, test_data, dev_data, device, args.n_shot
        )
        results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
        total_correct += corr
        total_questions += tot
        print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")

    # Overall accuracy is weighted by question count across subjects.
    overall = total_correct / total_questions if total_questions else 0
    print(f"\n{'=' * 70}")
    print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
    results["_overall"] = {
        "accuracy": round(overall, 4),
        "correct": total_correct,
        "total": total_questions,
    }
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {args.output}")
if __name__ == "__main__":
main()

View File

@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
def process_file(
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
):
# Load model and tokenizer
model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f:
@ -44,8 +44,8 @@ def process_file(
for seq in batch_encoded:
pad_len = max_len - len(seq)
padded_seq = seq + [tokenizer.pad_id] * pad_len
mask = [True] * len(seq) + [False] * pad_len
padded_seq = [tokenizer.pad_id] * pad_len + seq
mask = [False] * pad_len + [True] * len(seq)
padded_ids.append(padded_seq)
masks.append(mask)
@ -88,7 +88,7 @@ def process_file(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument(
"--param_path", type=str, required=True, help="Path to the model directory."
"--model_dir", type=str, required=True, help="Path to the model directory."
)
parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file."

View File

@ -18,7 +18,7 @@ def main():
"--reload", action="store_true", help="Enable auto-reload for development"
)
parser.add_argument(
"--param_path",
"--param-path",
type=Path,
default=None,
help="Path to model parameters (default: project_root/params)",

View File

@ -8,6 +8,7 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.serialization import Checkpoint
from astrai.trainer import SchedulerFactory, Trainer
@ -146,8 +147,8 @@ def parse_args() -> argparse.Namespace:
"--parallel_mode",
type=str,
default="none",
choices=["none", "ddp", "fsdp"],
help="Parallel training strategy (none, ddp, fsdp).",
choices=["none", "ddp"],
help="Parallel training strategy.",
)
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
@ -165,10 +166,6 @@ def parse_args() -> argparse.Namespace:
return args
def create_model(config):
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
def create_optimizer(model, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs)
@ -231,8 +228,6 @@ def train(
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config
config_path = os.path.join(param_path, "config.json")
@ -241,6 +236,15 @@ def train(
if window_size is None:
window_size = config.max_len
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
checkpoint = Checkpoint.load(param_path)
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
model.load_state_dict(checkpoint.state_dict, strict=False)
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
# (model weights already loaded into model above)
checkpoint.state_dict = {}
strategy_kwargs = {
"beta": dpo_beta,
"label_smoothing": label_smoothing,
@ -255,7 +259,6 @@ def train(
"broadcast_buffers": False,
}
model_fn = partial(create_model, config)
dataset = DatasetFactory.load(
train_type=train_type,
load_path=data_root_path,
@ -287,7 +290,7 @@ def train(
)
train_config = TrainConfig(
model_fn=model_fn,
model=model,
strategy=train_type,
dataset=dataset,
optimizer_fn=optimizer_fn,
@ -312,7 +315,7 @@ def train(
)
trainer = Trainer(train_config)
trainer.train(resume_dir=param_path)
trainer.train(checkpoint=checkpoint)
if __name__ == "__main__":

View File

@ -7,8 +7,10 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import (
H5Store,
StoreFactory,
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
StorageFactory,
detect_format,
load_json,
save_h5,
@ -316,48 +318,37 @@ def test_unloaded_dataset_len():
assert len(dataset) == 0
def test_store_unloaded_len():
"""Unloaded Store has __len__ == 0"""
store = H5Store()
assert len(store) == 0
assert store.keys == []
def test_base_segment_fetcher_empty():
"""BaseSegmentFetcher with empty segments list"""
fetcher = BaseSegmentFetcher([])
assert len(fetcher) == 0
with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_store_fetch_begin_equals_end(base_test_env):
"""Store.fetch with begin == end returns empty tensor"""
def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
result = dataset.storage.fetch(10, 10, "sequence")
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
assert result.numel() == 0
def test_store_empty_data_len(base_test_env):
"""Store loaded with empty data has __len__ == 0"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "empty_store")
os.makedirs(data_dir, exist_ok=True)
with open(os.path.join(data_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3]]}, f)
store = StoreFactory.create("json")
store.load(data_dir)
assert len(store) > 0
empty_store = H5Store()
assert len(empty_store) == 0
def test_multi_segment_fetcher_empty_dict():
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError"""
store = H5Store()
def test_storage_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError"""
storage = H5Storage()
with pytest.raises(RuntimeError, match="not loaded"):
store.fetch(0, 10, "sequence")
storage.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
@ -376,10 +367,10 @@ def test_detect_format_unsupported_file(base_test_env):
detect_format(path)
def test_create_store_invalid_type():
"""StoreFactory.create raises ValueError for unknown type"""
def test_create_storage_invalid_type():
"""StorageFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StoreFactory.create("parquet")
StorageFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
@ -416,23 +407,14 @@ def test_load_json_skips_config_file(base_test_env):
assert len(result["sequence"]) == 1
def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "multi_seg")
os.makedirs(data_dir, exist_ok=True)
def test_base_segment_fetcher_multi_segment():
"""fetch_data across multiple segment boundaries"""
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
save_h5(data_dir, "data", {"sequence": segs})
store = StoreFactory.create("h5")
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7)
assert result.tolist() == [3, 4, 5, 6, 7]

View File

@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
def create_train_config(
model_fn,
model: torch.nn.Module,
dataset: Dataset,
test_dir: str,
device: str,
@ -43,7 +43,7 @@ def create_train_config(
"""Factory function to create common TrainConfig for tests.
Args:
model_fn: Model factory (callable returning nn.Module)
model: The model to train
dataset: Training dataset
test_dir: Checkpoint directory
device: Device type ("cuda" or "cpu")
@ -70,7 +70,7 @@ def create_train_config(
return TrainConfig(
strategy=strategy,
model_fn=model_fn,
model=model,
dataset=dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,

View File

@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
)
train_config = TrainConfig(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
)
train_config = TrainConfig(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,

View File

@ -4,6 +4,7 @@ import numpy as np
import torch
from astrai.config.train_config import TrainConfig
from astrai.serialization import Checkpoint
from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer
@ -23,7 +24,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
strategy="seq",
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
@ -38,20 +39,17 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
trainer = Trainer(train_config)
# Should handle early stopping gracefully
checkpoint = None
try:
trainer.train()
checkpoint = trainer.train()
except Exception:
# Handle any exceptions
pass
# Resume from latest checkpoint
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
trainer = Trainer(train_config)
trainer.train(resume_dir=load_dir)
checkpoint = Checkpoint.load(load_dir)
trainer.train(checkpoint)
# Verify checkpoint was saved at expected iteration
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
import json
with open(os.path.join(load_dir, "meta.json")) as f:
meta = json.load(f)
assert meta["iteration"] == 10
checkpoint = Checkpoint.load(load_dir)
assert checkpoint.iteration == 10

View File

@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
for batch_per_device in batch_sizes:
train_config = train_config_factory(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
for config in small_batch_configs:
train_config = train_config_factory(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],