Compare commits

..

No commits in common. "0a708fff24b32f17a859bc6fcf3ba8f03b644aab" and "65ab69543b4da3afc440a1efd6005bb4cbcfda22" have entirely different histories.

34 changed files with 475 additions and 885 deletions

View File

@ -22,8 +22,7 @@ classDiagram
+int n_layers +int n_layers
+float norm_eps +float norm_eps
+int dim_ffn +int dim_ffn
+Optional[bool] tie_weight +bool tie_weight
+Optional[dict] rope_scaling
+int max_len +int max_len
+float rope_theta +float rope_theta
+str attn_type +str attn_type
@ -53,7 +52,6 @@ classDiagram
+int n_kv_heads +int n_kv_heads
+bool use_qk_norm +bool use_qk_norm
+bool use_gated_attention +bool use_gated_attention
+Optional[dict] rope_scaling
+Optional[str] pooling_type +Optional[str] pooling_type
+Optional[bool] normalize_embeddings +Optional[bool] normalize_embeddings
} }
@ -82,7 +80,6 @@ classDiagram
+str log_dir +str log_dir
+int log_interval +int log_interval
+List[str] metrics +List[str] metrics
+Optional[LoRAConfig] lora
+int random_seed +int random_seed
+int num_workers +int num_workers
+Optional[int] prefetch_factor +Optional[int] prefetch_factor
@ -107,7 +104,7 @@ classDiagram
class BaseDataset { class BaseDataset {
+int window_size +int window_size
+int stride +int stride
+Optional[Store] storage +Optional[BaseStorage] storage
+load(load_path, storage_type, tokenizer) +load(load_path, storage_type, tokenizer)
+__getitem__(index) +__getitem__(index)
+__len__() +__len__()
@ -129,29 +126,38 @@ classDiagram
+__getitem__(index) Dict +__getitem__(index) Dict
} }
class Store { class BaseSegmentFetcher {
+Dict[str, List[Tensor]] _data +List[Tensor] segments
+Dict[str, List[int]] _cum +List[int] cum_lengths
+int _length +int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
+keys (property) +keys (property)
+load(path, tokenizer) +load(load_path, tokenizer)
+fetch(begin, end, keys) +fetch(begin, end, keys)
+__len__() +__len__()
-_fetch_key(key, begin, end) Tensor
-_normalize(raw)
} }
class H5Store { class H5Storage {
+load(path, tokenizer) +load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
} }
class JSONStore { class JSONStorage {
+load(path, tokenizer) +load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
} }
class MmapStore { class MultiSegmentFetcher {
+List _mmap_refs +Dict multi_fetchers
+load(path, tokenizer) +List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
} }
class ResumableDistributedSampler { class ResumableDistributedSampler {
@ -159,10 +165,10 @@ classDiagram
+int iter +int iter
} }
class StoreFactory { class StorageFactory {
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
+create(storage_type) Store +create(storage_type) BaseStorage
} }
class DatasetFactory { class DatasetFactory {
@ -451,15 +457,16 @@ classDiagram
+on_train_end(context) +on_train_end(context)
+on_epoch_begin(context) +on_epoch_begin(context)
+on_epoch_end(context) +on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context) +on_batch_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_optimizer_step(context)
+on_error(context) +on_error(context)
} }
class GradientClippingCallback { class GradientClippingCallback {
+float max_grad_norm +float max_grad_norm
+on_optimizer_step(context) +on_step_begin(context)
} }
class GradientCheckpointingCallback { class GradientCheckpointingCallback {
@ -505,7 +512,7 @@ classDiagram
class ValidationCallback { class ValidationCallback {
+_run_validation(context) +_run_validation(context)
+on_optimizer_step(context) +on_step_end(context)
} }
class CallbackFactory { class CallbackFactory {
@ -740,58 +747,56 @@ classDiagram
+str model +str model
+List[AnthropicMessage] messages +List[AnthropicMessage] messages
+Optional[str] system +Optional[str] system
+Optional[float] temperature +float temperature
+Optional[float] top_p +float top_p
+Optional[int] top_k +int top_k
+int max_tokens +int max_tokens
+Optional[bool] stream +bool stream
+Optional[List[str]] stop_sequences +Optional[List[str]] stop_sequences
} }
class ResponseBuilder {
<<abstract>>
+prepare(request, engine) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class OpenAIResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class AnthropicResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class ProtocolHandler { class ProtocolHandler {
<<abstract>>
+request +request
+engine +engine
+builder: ResponseBuilder +build_prompt() str
+create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+handle() Union[StreamingResponse, Dict] +handle() Union[StreamingResponse, Dict]
-_handle_stream(agen, ctx, stops) StreamingResponse }
-_handle_non_stream(agen, ctx, stops) Dict
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
} }
class StopChecker { class StopChecker {
+has_sequences (property) bool
+check(text) Optional[str] +check(text) Optional[str]
+trim(text, matched) str
} }
class GenContext { class StreamContext {
+str resp_id +str resp_id
+int created +int created
+str model +str model
+int prompt_tokens +int prompt_tokens
+int completion_tokens +int completion_tokens
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
} }
class app { class app {
@ -871,11 +876,6 @@ classDiagram
+unwrap_model(model) nn.Module +unwrap_model(model) nn.Module
} }
class FSDPExecutor {
+_prepare_model(model) nn.Module
+unwrap_model(model) nn.Module
}
class ExecutorFactory { class ExecutorFactory {
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
@ -911,14 +911,12 @@ classDiagram
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- ValidationCallback
BaseDataset <|-- SEQDataset BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset BaseDataset <|-- GRPODataset
Store <|-- H5Store BaseStorage <|-- H5Storage
Store <|-- JSONStore BaseStorage <|-- JSONStorage
Store <|-- MmapStore
BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy BaseSamplingStrategy <|-- TopPStrategy
@ -938,19 +936,20 @@ classDiagram
BaseFactory <|-- StrategyFactory BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory BaseFactory <|-- CallbackFactory
BaseFactory <|-- StoreFactory BaseFactory <|-- StorageFactory
BaseFactory <|-- ExecutorFactory BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory BaseFactory <|-- ConfigFactory
BaseExecutor <|-- NoneExecutor BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor BaseExecutor <|-- DDPExecutor
BaseExecutor <|-- FSDPExecutor ProtocolHandler <|-- OpenAIHandler
ResponseBuilder <|-- OpenAIResponseBuilder ProtocolHandler <|-- AnthropicHandler
ResponseBuilder <|-- AnthropicResponseBuilder
%% --- Composition (strong ownership, part destroyed with whole) --- %% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool KVCache *-- PagePool
KVCache *-- Storage KVCache *-- Storage
KVCache *-- TaskTable KVCache *-- TaskTable
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor InferenceScheduler *-- Executor
@ -964,6 +963,7 @@ classDiagram
DecoderBlock *-- RMSNorm DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry BaseFactory *-- Registry
BaseExecutor *-- GradientState BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState AccumOptimizer o-- GradientState
@ -971,9 +971,6 @@ classDiagram
%% --- Aggregation (weak ownership) --- %% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig AutoModel o-- BaseModelConfig
AutoTokenizer o-- ChatTemplate
PagePool o-- Allocator
PagePool o-- PrefixCache
Trainer o-- TrainCallback Trainer o-- TrainCallback
TrainContext o-- BaseStrategy TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler TrainContext o-- BaseScheduler
@ -981,7 +978,7 @@ classDiagram
TrainContext o-- BaseExecutor TrainContext o-- BaseExecutor
KvcacheView o-- Storage KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- Store BaseDataset o-- BaseStorage
%% --- Dependency (uses temporarily) --- %% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects TrainConfig ..> BaseStrategy : selects
@ -995,14 +992,12 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses DecoderBlock ..> FFNFactory : uses
StoreFactory ..> H5Store : creates StorageFactory ..> H5Storage : creates
StoreFactory ..> JSONStore : creates StorageFactory ..> JSONStorage : creates
StoreFactory ..> MmapStore : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates ExecutorFactory ..> DDPExecutor : creates
ExecutorFactory ..> FSDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates TrainContextBuilder ..> TrainContext : creates
@ -1014,10 +1009,10 @@ classDiagram
KVCache ..> KvcacheView : binds KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates InferenceEngine ..> GenerateResult : creates
OpenAIResponseBuilder ..> ChatCompletionRequest : receives OpenAIHandler ..> ChatCompletionRequest : receives
AnthropicResponseBuilder ..> MessagesRequest : receives AnthropicHandler ..> MessagesRequest : receives
ProtocolHandler ..> StopChecker : creates ProtocolHandler ..> StopChecker : creates
ProtocolHandler ..> GenContext : creates ProtocolHandler ..> StreamContext : creates
%% --- Association (general usage) --- %% --- Association (general usage) ---
Trainer --> TrainConfig Trainer --> TrainConfig
@ -1030,6 +1025,8 @@ classDiagram
Executor --> AutoModel Executor --> AutoModel
Executor --> AutoTokenizer Executor --> AutoTokenizer
TaskManager --> AutoTokenizer TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
``` ```
@ -1039,13 +1036,13 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | | **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, StoreMmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization | | **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessageMessagesRequest, app | Inference service | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, StopChecker, StreamContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration | | **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers | | **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
@ -1053,17 +1050,17 @@ classDiagram
| Pattern | Classes | Purpose | | Pattern | Classes | Purpose |
|---------|---------|---------| |---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation | | **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks | | **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
| **Builder** | `TrainContextBuilder` | Chain-building training context | | **Builder** | `TrainContextBuilder` | Chain-building training context |
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag | | **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support | | **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1072,10 +1069,10 @@ classDiagram
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs` 1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed)
5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data` 7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops

View File

@ -5,22 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview ## Overview
``` ```
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
``` ```
## Data Preparation ## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups. Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
``` ```
StoreFactory.create("h5") → H5Store StorageFactory.create("h5") → H5Storage
StoreFactory.create("json") → JSONStore StorageFactory.create("json") → JSONStorage
StoreFactory.create("bin") → MmapStore
``` ```
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively. Both support shared memory via `.share_memory_()`.
## Data Keys by Training Type ## Data Keys by Training Type
@ -34,14 +33,14 @@ H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) us
## Dataset Architecture ## Dataset Architecture
``` ```
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer) DatasetFactory.load(train_type, path, window_size, stride)
→ StoreFactory.create(detect_format(path)) → StorageFactory.create(detect_format(path))
Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] MultiSegmentFetcher(BaseSegmentFetcher per key)
→ BaseDataset.__getitem__(idx) → BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx) → sliding window [begin, end) via get_index(idx)
``` ```
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`). `window_size` = max input length, `stride` = step between consecutive samples.
## Sampler ## Sampler

View File

@ -46,22 +46,20 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage. `sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Strategy Pattern) ## Protocol Handlers (Template Method)
```python ```python
class ProtocolHandler: # concrete orchestrator class ProtocolHandler(ABC):
def handle(self, request): def handle(self):
prompt, ctx, stops = builder.prepare(request, engine) ctx = StreamContext(...)
agen = engine.generate_async(prompt, ...) agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops) if stream: self._handle_stream(agen, ctx)
else: self._handle_non_stream(agen, ctx, stops) else: self._handle_non_stream(agen, ctx)
``` ```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`. Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
`OpenAIResponseBuilder``/v1/chat/completions`, `AnthropicResponseBuilder``/v1/messages`. `OpenAIHandler``/v1/chat/completions`, `AnthropicHandler``/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
## Engine & GenerateResult ## Engine & GenerateResult
@ -118,7 +116,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| Param | Type | Default | Description | | Param | Type | Default | Description |
|-------|------|---------|-------------| |-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) | | `messages` | List[dict] | required | Chat messages (role, content) |
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) | | `temperature` | float | 1.0 | Sampling temperature (0.02.0) |
| `top_p` | float | 1.0 | Nucleus threshold | | `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count | | `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length | | `max_tokens` | int | None | Max generation length |

View File

@ -53,7 +53,7 @@
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 | | `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none | | `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
| `--device_type` | Device type | cuda | | `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn | | `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |

View File

@ -82,7 +82,8 @@ on_train_begin
on_optimizer_step on_optimizer_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
scheduler.step()
scheduler.step() # called every iteration
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
@ -189,7 +190,7 @@ context = (
``` ```
- Loads checkpoint weights if provided - Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` - Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers - Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume - Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)` - Builds strategy via `StrategyFactory.create(train_type, ...)`

View File

@ -17,8 +17,8 @@ def required(**kw):
@dataclass @dataclass
class TrainConfig(BaseConfig): class TrainConfig(BaseConfig):
# basic setting # basic setting
model_fn: Callable[[], nn.Module] = field( model: nn.Module = field(
default=None, metadata=required(help="Model factory for training.") default=None, metadata=required(help="Model for training.")
) )
strategy: str = field(default=None, metadata=required(help="Training strategy.")) strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field( dataset: Dataset = field(

View File

@ -4,17 +4,15 @@ from astrai.dataset.dataset import (
) )
from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import ( from astrai.dataset.storage import (
H5Store, BaseSegmentFetcher,
JSONStore, BaseStorage,
MmapStore, H5Storage,
Store, JSONStorage,
StoreFactory, MultiSegmentFetcher,
StorageFactory,
detect_format, detect_format,
json_to_bin,
load_bin,
load_h5, load_h5,
load_json, load_json,
save_bin,
save_h5, save_h5,
save_json, save_json,
) )
@ -22,18 +20,16 @@ from astrai.dataset.storage import (
__all__ = [ __all__ = [
"BaseDataset", "BaseDataset",
"DatasetFactory", "DatasetFactory",
"Store", "BaseSegmentFetcher",
"StoreFactory", "MultiSegmentFetcher",
"H5Store", "BaseStorage",
"JSONStore", "H5Storage",
"MmapStore", "JSONStorage",
"StorageFactory",
"detect_format", "detect_format",
"save_h5", "save_h5",
"load_h5", "load_h5",
"save_json", "save_json",
"load_json", "load_json",
"save_bin",
"load_bin",
"json_to_bin",
"ResumableDistributedSampler", "ResumableDistributedSampler",
] ]

View File

@ -8,8 +8,8 @@ from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
Store, BaseStorage,
StoreFactory, StorageFactory,
detect_format, detect_format,
) )
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
super().__init__() super().__init__()
self.window_size = window_size self.window_size = window_size
self.stride = stride self.stride = stride
self.storage: Optional[Store] = None self.storage: Optional[BaseStorage] = None
@property @property
def required_keys(self) -> List[str]: def required_keys(self) -> List[str]:
@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC):
""" """
if storage_type is None: if storage_type is None:
storage_type = detect_format(load_path) storage_type = detect_format(load_path)
self.storage = StoreFactory.create(storage_type) self.storage = StorageFactory.create(storage_type)
self._load_path = load_path self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer) self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys() self._validate_keys()
@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
""" """
@classmethod @classmethod
def _validate_component(cls, dataset_cls: type): def _validate_component(cls, dataset_cls: type) -> None:
"""Validate that the dataset class inherits from BaseDataset.""" """Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset): if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")

View File

@ -1,20 +1,7 @@
"""Storage backends for different data formats. """Storage backends for different data formats.
Layers: Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin) a uniform interface for data access and length observation via fetchers.
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
""" """
import bisect import bisect
@ -25,7 +12,6 @@ from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import h5py import h5py
import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
@ -118,38 +104,6 @@ def load_json(
return tensor_group return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
meta = {}
for key, tensors in tensor_group.items():
cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
save_json(meta, os.path.join(file_path, "meta.json"))
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
meta = load_json(os.path.join(file_path, "meta.json"))
segments: Dict[str, List[Tensor]] = {}
for key, info in meta.items():
arr = np.memmap(
os.path.join(file_path, f"{key}.bin"),
dtype=info["dtype"],
mode="r",
shape=tuple(info["shape"]),
)
segments[key] = [torch.from_numpy(arr)]
return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
merged = {}
for key, tensors in segments.items():
merged[key] = [torch.cat(tensors, dim=0)]
save_bin(bin_path, merged)
def detect_format(load_path: str) -> str: def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory. """Auto-detect storage format from files in the directory.
@ -157,7 +111,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path load_path: Directory or file path
Returns: Returns:
Format string ("h5", "bin", or "json") Format string ("h5" or "json")
Raises: Raises:
FileNotFoundError: If no supported data files are found FileNotFoundError: If no supported data files are found
@ -174,118 +128,166 @@ def detect_format(load_path: str) -> str:
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files: if h5_files:
return "h5" return "h5"
bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists():
return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl")) json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files: if json_files:
return "json" return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}") raise FileNotFoundError(f"No supported data files found at {load_path}")
class Store(ABC): class BaseSegmentFetcher:
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``. """Fetches data segments across multiple tensor segments.
Each key maps to one or more tensor segments (no forced concatenation). Maintains cumulative lengths for efficient range queries across
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum multiple discontinuous segments.
total element count across all keys. """
Subclasses fill ``self._data`` and ``self._cum`` during ``load()`` def __init__(self, segments: List[Tensor]):
via ``_normalize()``. self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
""" """
def __init__(self): def __init__(self):
self._data: Dict[str, List[Tensor]] = {} self._fetcher: Optional[MultiSegmentFetcher] = None
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
@abstractmethod @abstractmethod
def load(self, path: str, tokenizer=None) -> None: def load(self, load_path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property @property
def keys(self) -> List[str]: def keys(self) -> List[str]:
return list(self._data.keys()) """Return the data keys available in this storage."""
if self._fetcher is None:
def __len__(self) -> int: return []
return self._length return self._fetcher.multi_keys
def fetch(
self,
begin: int,
end: int,
keys: Union[str, List[str]],
):
if not self._data:
raise RuntimeError("Store not loaded")
if not (0 <= begin < self._length and 0 <= end <= self._length):
raise ValueError(
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
)
if isinstance(keys, str):
return self._fetch_key(keys, begin, end)
return {k: self._fetch_key(k, begin, end) for k in keys}
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
"""Fetch slice [begin, end) across potentially multiple segments."""
segments = self._data[key]
cum = self._cum[key]
seg_start = bisect.bisect_right(cum, begin)
seg_end = bisect.bisect_left(cum, end)
results = []
for i in range(seg_start, seg_end + 1):
prev = cum[i - 1] if i > 0 else 0
s = max(begin - prev, 0)
e = min(end - prev, segments[i].shape[0])
results.append(segments[i][s:e])
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
"""Register segments and pre-compute cumulative lengths.
Does NOT concatenate segments are kept as-is to avoid OOM on
large datasets. Sets ``self._length`` to the minimum total
element count across all keys.
"""
for key, tensors in raw.items():
self._data[key] = tensors
cum = []
total = 0
for t in tensors:
total += t.shape[0]
cum.append(total)
self._cum[key] = cum
self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
class StoreFactory(BaseFactory["Store"]): class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating Store instances by type name. """Factory for creating storage backends by type name.
Example:: Example:
@StorageFactory.register("custom")
@StoreFactory.register("custom") class CustomStorage(BaseStorage):
class CustomStore(Store):
... ...
storage = StorageFactory.create("custom")
""" """
@classmethod @classmethod
def _validate_component(cls, store_cls: type): def _validate_component(cls, storage_cls: type) -> None:
if not issubclass(store_cls, Store): if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{store_cls.__name__} must inherit from Store") raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@StoreFactory.register("h5") @StorageFactory.register("h5")
class H5Store(Store): class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data).""" """HDF5-based storage backend (pre-tokenized data)."""
def load(self, path: str, tokenizer=None): def load(self, load_path: str, tokenizer=None) -> None:
self._normalize(load_h5(path)) segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
@StoreFactory.register("json") @StorageFactory.register("json")
class JSONStore(Store): class JSONStorage(BaseStorage):
"""JSON-based storage backend. """JSON-based storage backend.
Supports two modes: Supports two modes:
@ -294,28 +296,6 @@ class JSONStore(Store):
callable (str -> List[int]) at load time. callable (str -> List[int]) at load time.
""" """
def load(self, path: str, tokenizer=None): def load(self, load_path: str, tokenizer=None) -> None:
self._normalize(load_json(path, tokenizer=tokenizer)) segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
No per-process memory duplication all DataLoader workers share the
same OS page-cache pages.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
"""
def load(self, path: str, tokenizer=None):
self._mmap_refs = []
raw = load_bin(path)
self._normalize(raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)

View File

@ -23,7 +23,7 @@ class Registry:
component_cls: Type, component_cls: Type,
category: Optional[str] = None, category: Optional[str] = None,
priority: int = 0, priority: int = 0,
): ) -> None:
"""Register a component class with optional category and priority.""" """Register a component class with optional category and priority."""
if name in self._entries: if name in self._entries:
raise ValueError(f"Component '{name}' is already registered") raise ValueError(f"Component '{name}' is already registered")
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
return component_cls(*args, **kwargs) return component_cls(*args, **kwargs)
@classmethod @classmethod
def _validate_component(cls, component_cls: Type[T]): def _validate_component(cls, component_cls: Type[T]) -> None:
"""Validate that the component class is valid for this factory. """Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation. Override this method in subclasses to add custom validation.

View File

@ -42,7 +42,7 @@ class Allocator:
return idx return idx
return -1 return -1
def free(self, idx: int, keep_cached: bool = False): def free(self, idx: int, keep_cached: bool = False) -> None:
with self._lock: with self._lock:
self._refs[idx] -= 1 self._refs[idx] -= 1
if self._refs[idx] == 0: if self._refs[idx] == 0:
@ -51,7 +51,7 @@ class Allocator:
else: else:
self._free_mask |= 1 << idx self._free_mask |= 1 << idx
def inc_ref(self, idx: int): def inc_ref(self, idx: int) -> None:
with self._lock: with self._lock:
self._refs[idx] += 1 self._refs[idx] += 1
self._lru.pop(idx, None) self._lru.pop(idx, None)
@ -60,7 +60,7 @@ class Allocator:
with self._lock: with self._lock:
return self._refs[idx] return self._refs[idx]
def touch(self, idx: int): def touch(self, idx: int) -> None:
with self._lock: with self._lock:
self._lru.move_to_end(idx) self._lru.move_to_end(idx)
@ -74,7 +74,7 @@ class PrefixCache:
self._hash_to_page: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
def evict(self, idx: int): def evict(self, idx: int) -> None:
with self._lock: with self._lock:
h = self._page_to_hash.pop(idx, None) h = self._page_to_hash.pop(idx, None)
if h is not None: if h is not None:
@ -96,7 +96,9 @@ class PrefixCache:
hits.append(p) hits.append(p)
return hits return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int): def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
with self._lock: with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size) h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None) old_h = self._page_to_hash.pop(page_idx, None)
@ -125,13 +127,13 @@ class PagePool:
def alloc(self) -> int: def alloc(self) -> int:
return self._alloc.alloc() return self._alloc.alloc()
def free(self, idx: int): def free(self, idx: int) -> None:
keep = self._prefix.has_page(idx) keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep) self._alloc.free(idx, keep_cached=keep)
if not keep: if not keep:
self._prefix.evict(idx) self._prefix.evict(idx)
def inc_ref(self, idx: int): def inc_ref(self, idx: int) -> None:
self._alloc.inc_ref(idx) self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]: def lookup(self, token_ids: List[int]) -> List[int]:
@ -140,7 +142,9 @@ class PagePool:
self._alloc.touch(p) self._alloc.touch(p)
return hits return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int): def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
self._prefix.record(page_idx, token_ids, logical_page_idx) self._prefix.record(page_idx, token_ids, logical_page_idx)
@ -153,7 +157,7 @@ class TaskTable:
self._cached: Dict[str, int] = {} self._cached: Dict[str, int] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int): def set(self, task_id: str, page_table: List[int], cached: int) -> None:
with self._lock: with self._lock:
self._pages[task_id] = page_table self._pages[task_id] = page_table
self._cached[task_id] = cached self._cached[task_id] = cached
@ -216,7 +220,7 @@ class Storage:
start_pos: int, start_pos: int,
k: Tensor, k: Tensor,
v: Tensor, v: Tensor,
): ) -> None:
seq_len = k.size(1) seq_len = k.size(1)
if seq_len == 0: if seq_len == 0:
return return
@ -282,7 +286,7 @@ class KvcacheView:
self._page_table = page_table self._page_table = page_table
self._total_len = total_len self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor): def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
start_pos = self._total_len - k.size(1) start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v) self._storage.write(layer_id, self._page_table, start_pos, k, v)
@ -335,7 +339,7 @@ class KVCache:
self._table.set(task_id, hits + new_pages, cached) self._table.set(task_id, hits + new_pages, cached)
return True return True
def task_free(self, task_id: str): def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id) page_table, _ = self._table.pop(task_id)
for idx in page_table: for idx in page_table:
self._pool.free(idx) self._pool.free(idx)
@ -355,7 +359,7 @@ class KVCache:
def task_record_hashes( def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0 self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
): ) -> None:
page_table = self._table.get(task_id) page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages): for i in range(start_logical_page, full_pages):

View File

@ -29,7 +29,9 @@ class Executor:
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0): def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
if start_pos >= prompt_len: if start_pos >= prompt_len:
return return

View File

@ -75,14 +75,14 @@ class InferenceScheduler:
def add_task(self, prompt: str, **kwargs) -> str: def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs) return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str): def remove_task(self, task_id: str) -> None:
for task in self._task_mgr.remove_task(task_id): for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats() return self._task_mgr.get_stats()
def _run_generation_loop(self): def _run_generation_loop(self) -> None:
stop_ids = self._task_mgr.tokenizer.stop_ids stop_ids = self._task_mgr.tokenizer.stop_ids
try: try:
while self._running: while self._running:
@ -186,14 +186,14 @@ class InferenceScheduler:
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
raise raise
def start(self): def start(self) -> None:
if not self._running: if not self._running:
self._running = True self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True) t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start() t.start()
self._loop_thread = t self._loop_thread = t
def stop(self): def stop(self) -> None:
self._running = False self._running = False
self._task_mgr.wake() self._task_mgr.wake()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):

View File

@ -172,12 +172,12 @@ class TaskManager:
to_add.append(self.waiting_queue.popleft()) to_add.append(self.waiting_queue.popleft())
return to_add return to_add
def activate(self, task: Task): def activate(self, task: Task) -> None:
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
with self._lock: with self._lock:
self.active_tasks.append(task) self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]): def return_to_waiting(self, tasks: List[Task]) -> None:
with self._lock: with self._lock:
for task in reversed(tasks): for task in reversed(tasks):
self.waiting_queue.appendleft(task) self.waiting_queue.appendleft(task)
@ -185,7 +185,7 @@ class TaskManager:
def has_work(self) -> bool: def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue) return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0): def wait_for_tasks(self, timeout: float = 1.0) -> None:
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=timeout) self._task_event.wait(timeout=timeout)
@ -197,10 +197,10 @@ class TaskManager:
with self._lock: with self._lock:
return list(self.waiting_queue) return list(self.waiting_queue)
def clear_queues(self): def clear_queues(self) -> None:
with self._lock: with self._lock:
self.waiting_queue.clear() self.waiting_queue.clear()
self.active_tasks.clear() self.active_tasks.clear()
def wake(self): def wake(self) -> None:
self._task_event.set() self._task_event.set()

View File

@ -48,7 +48,7 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool: def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0): def wait_completion(self, timeout: float = 300.0) -> None:
with self._cond: with self._cond:
if not self._cond.wait_for( if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout lambda: self._completed >= self._total, timeout=timeout
@ -281,7 +281,7 @@ class InferenceEngine:
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self): def shutdown(self) -> None:
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -15,11 +15,7 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod
@contextmanager @contextmanager
def _disable_random_init(enable: bool = True): def _disable_random_init(enable: bool = True):
if not enable: init_functions = [
yield
return
names = (
"xavier_normal_", "xavier_normal_",
"xavier_uniform_", "xavier_uniform_",
"kaiming_normal_", "kaiming_normal_",
@ -29,15 +25,18 @@ def _disable_random_init(enable: bool = True):
"constant_", "constant_",
"normal_", "normal_",
"uniform_", "uniform_",
) ]
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)} original_funcs = {}
for n in orig: for name in init_functions:
setattr(nn.init, n, lambda *a, **kw: None) if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try: try:
yield yield
finally: finally:
for n, fn in orig.items(): if enable:
setattr(nn.init, n, fn) for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
class AutoModel(BaseFactory["AutoModel"], nn.Module): class AutoModel(BaseFactory["AutoModel"], nn.Module):
@ -83,7 +82,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, Path], save_directory: Union[str, Path],
): ) -> None:
save_model( save_model(
config=self.config.to_dict(), config=self.config.to_dict(),
state_dict=self.state_dict(), state_dict=self.state_dict(),

View File

@ -68,6 +68,9 @@ class EmbeddingEncoder(AutoModel):
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
rotary_emb = self.rotary_embedding(x, position_ids) rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False) attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Mapping, Optional from typing import Any, Mapping, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -136,7 +136,7 @@ class AutoRegressiveLM(AutoModel):
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None, paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None, position_ids: Optional[Tensor] = None,
) -> Dict[str, Tensor]: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)

View File

@ -203,45 +203,9 @@ class DDPExecutor(BaseExecutor):
@ExecutorFactory.register("fsdp") @ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor): class FSDPExecutor(BaseExecutor):
def __init__( def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps) super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = { self._fsdp_kwargs = fsdp_kwargs
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._original_model: Optional[nn.Module] = None self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module: def _prepare_model(self, model: nn.Module) -> nn.Module:

View File

@ -1,9 +1,8 @@
import io
import json import json
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple from typing import Any, Dict
import safetensors.torch as st import safetensors.torch as st
import torch import torch
@ -12,11 +11,11 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank from astrai.parallel.setup import get_rank
_META_FILE = "meta.json" _META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors" _WEIGHTS_FILE = "model.safetensors"
_MODEL_CONFIG_FILE = "config.json"
def save_safetensors(state_dict: dict, path: str | Path): def save_safetensors(state_dict: dict, path: str | Path) -> None:
st.save_file(state_dict, str(path)) st.save_file(state_dict, str(path))
@ -24,7 +23,7 @@ def load_safetensors(path: str | Path) -> dict:
return st.load_file(str(path)) return st.load_file(str(path))
def save_json(data: dict, path: str | Path): def save_json(data: dict, path: str | Path) -> None:
with open(str(path), "w") as f: with open(str(path), "w") as f:
json.dump(data, f, indent=2) json.dump(data, f, indent=2)
@ -34,92 +33,13 @@ def load_json(path: str | Path) -> dict:
return json.load(f) return json.load(f)
def save_torch(obj: Any, path: str | Path): def save_torch(obj: Any, path: str | Path) -> None:
torch.save(obj, str(path)) torch.save(obj, str(path))
def load_torch(path: str | Path, broadcast: bool = False) -> Any: def load_torch(path: str | Path) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False) return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
def _get_meta(save_path: Path) -> dict:
meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
return meta
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return load_safetensors(save_path / _WEIGHTS_FILE)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
specs: List[Tuple[str, List[int], str]] = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
else:
state_dict = {}
specs = []
specs_list = [specs]
dist.broadcast_object_list(specs_list, src=0)
specs = specs_list[0]
for key, shape, dtype_name in specs:
dtype = getattr(torch, dtype_name)
if rank != 0:
tensor = torch.empty(shape, dtype=dtype, device="cpu")
else:
tensor = state_dict[key].contiguous().cpu()
dist.broadcast(tensor, src=0)
if rank != 0:
state_dict[key] = tensor
return state_dict
@dataclass @dataclass
class Checkpoint: class Checkpoint:
@ -129,7 +49,7 @@ class Checkpoint:
extra: Dict[str, Any] = field(default_factory=dict) extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str): def save(self, save_dir: str) -> None:
save_path = Path(save_dir) save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)
@ -148,16 +68,24 @@ class Checkpoint:
save_torch(value, save_path / f"{key}.pt") save_torch(value, save_path / f"{key}.pt")
@classmethod @classmethod
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": def load(cls, save_dir: str) -> "Checkpoint":
save_path = Path(save_dir) save_path = Path(save_dir)
meta = _get_meta(save_path) meta = {}
state_dict = _load_state_dict(save_path, broadcast=broadcast) if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
extra = {} extra = {}
for f in sorted(save_path.iterdir()): for f in save_path.iterdir():
if f.suffix == ".pt": if f.suffix == ".pt":
extra[f.stem] = load_torch(f, broadcast=broadcast) extra[f.stem] = load_torch(f)
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,
@ -165,3 +93,18 @@ class Checkpoint:
iteration=meta.get("iteration", 0), iteration=meta.get("iteration", 0),
extra=extra, extra=extra,
) )
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _MODEL_CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)

View File

@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
""" """
@classmethod @classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]): def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
"""Validate that the scheduler class inherits from BaseScheduler.""" """Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler): if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

View File

@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
""" """
@classmethod @classmethod
def _validate_component(cls, strategy_cls: type): def _validate_component(cls, strategy_cls: type) -> None:
"""Validate that the strategy class inherits from BaseStrategy.""" """Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy): if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device, get_rank from astrai.parallel.setup import get_current_device
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import ( from astrai.trainer.metric_util import (
ctx_get_grad_max, ctx_get_grad_max,
@ -139,27 +139,27 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
# All ranks gather state_dict — collective for FSDP, local for DDP save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
state_dict = ( state_dict = (
self.state_dict_fn(context.model) self.state_dict_fn(context.model)
if self.state_dict_fn if self.state_dict_fn
else context.model.state_dict() else context.model.state_dict()
) )
self.last_ckpt_iter = context.iteration
if get_rank() == 0:
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
extra = self.save_extra_fn(context) extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
state_dict=state_dict, state_dict=state_dict,
@ -168,7 +168,13 @@ class CheckpointCallback(TrainCallback):
extra=extra, extra=extra,
meta=context.config.to_dict(), meta=context.config.to_dict(),
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval: if context.iteration - self.last_ckpt_iter >= self.interval:
@ -190,6 +196,12 @@ class CheckpointCallback(TrainCallback):
extra[name] = obj.state_dict() extra[name] = obj.state_dict()
return extra return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar") @CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback): class ProgressBarCallback(TrainCallback):

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self from typing import Optional, Self
import torch.nn as nn import torch.nn as nn
@ -11,7 +10,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_model_weights from astrai.serialization import Checkpoint
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -43,10 +42,10 @@ class TrainContextBuilder:
config: TrainConfig, config: TrainConfig,
): ):
self.config = config self.config = config
self._resume_dir: Optional[str] = None self._checkpoint: Optional[Checkpoint] = None
def with_resume_dir(self, resume_dir: Optional[str]) -> Self: def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._resume_dir = resume_dir self._checkpoint = checkpoint
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
@ -59,40 +58,36 @@ class TrainContextBuilder:
**cfg.executor_kwargs, **cfg.executor_kwargs,
) )
model = cfg.model_fn()
model = model.to(device=device)
context = TrainContext( context = TrainContext(
model=model, model=cfg.model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
config=cfg, config=cfg,
executor=executor, executor=executor,
) )
if self._resume_dir is not None: context.model = context.model.to(device=device)
resume_path = Path(self._resume_dir)
if (resume_path / "meta.json").exists(): if self._checkpoint is not None:
checkpoint = Checkpoint.load(self._resume_dir) context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
state_dict = checkpoint.state_dict context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
if self._checkpoint.state_dict:
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else: else:
checkpoint = None context.checkpoint = Checkpoint(
state_dict = load_model_weights(self._resume_dir) state_dict=context.model.state_dict(),
model.load_state_dict(state_dict, strict=False) )
if checkpoint is not None:
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
context.iteration = max(checkpoint.iteration, cfg.start_batch)
context.checkpoint = checkpoint
if cfg.lora is not None: if cfg.lora is not None:
inject_lora( inject_lora(
model, context.model,
r=cfg.lora.r, r=cfg.lora.r,
alpha=cfg.lora.alpha, alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules), target_modules=set(cfg.lora.target_modules),
) )
context.optimizer = cfg.optimizer_fn(model) context.optimizer = cfg.optimizer_fn(context.model)
context.scheduler = cfg.scheduler_fn(context.optimizer) context.scheduler = cfg.scheduler_fn(context.optimizer)
sampler_offset = context.iteration * cfg.batch_per_device sampler_offset = context.iteration * cfg.batch_per_device
@ -130,21 +125,13 @@ class TrainContextBuilder:
context.model, context.optimizer, context.dataloader, context.scheduler = ( context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare( executor.prepare(
model, context.model,
context.optimizer, context.optimizer,
context.dataloader, context.dataloader,
context.scheduler, context.scheduler,
) )
) )
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create( context.strategy = StrategyFactory.create(
model=context.model, model=context.model,
train_type=cfg.strategy, train_type=cfg.strategy,

View File

@ -3,6 +3,7 @@ from typing import List, Optional
from astrai.config import TrainConfig from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import ( from astrai.trainer.train_callback import (
CallbackFactory, CallbackFactory,
TrainCallback, TrainCallback,
@ -53,9 +54,9 @@ class Trainer:
if method: if method:
method(context) method(context)
def _trainer_loop(self, resume_dir: Optional[str] = None): def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
context = ( context = (
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build() TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
) )
executor = context.executor executor = context.executor
self._call_callbacks("on_train_begin", context) self._call_callbacks("on_train_begin", context)
@ -89,13 +90,13 @@ class Trainer:
self._call_callbacks("on_epoch_end", context) self._call_callbacks("on_epoch_end", context)
except Exception as e: except Exception as e:
logger.error("Training failed: %s", str(e), exc_info=True) logger.error(f"Training failed: {str(e)}", exc_info=True)
self._call_callbacks("on_error", context) self._call_callbacks("on_error", context)
raise raise
finally: finally:
self._call_callbacks("on_train_end", context) self._call_callbacks("on_train_end", context)
def train(self, resume_dir: Optional[str] = None): def train(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config cfg = self.train_config
spawn_parallel_fn( spawn_parallel_fn(
self._trainer_loop, self._trainer_loop,
@ -105,5 +106,5 @@ class Trainer:
master_port=cfg.master_port, master_port=cfg.master_port,
device_type=cfg.device_type, device_type=cfg.device_type,
start_method=cfg.start_method, start_method=cfg.start_method,
resume_dir=resume_dir, checkpoint=checkpoint,
) )

View File

@ -1,279 +0,0 @@
"""MMLU evaluation via log-likelihood ranking."""
import argparse
import csv
import json
import os
import shutil
import urllib.request
import zipfile
import torch
import torch.nn.functional as F
import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
MMLU_SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def _download_and_extract(url: str, data_dir: str):
zip_path = os.path.join(data_dir, "mmlu.zip")
os.makedirs(data_dir, exist_ok=True)
print(f"Downloading MMLU data from {url}...")
urllib.request.urlretrieve(url, zip_path)
print("Extracting...")
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(data_dir)
os.remove(zip_path)
def download_mmlu(data_dir: str):
_download_and_extract(MMLU_URL, data_dir)
src = os.path.join(data_dir, "test-master", "data")
if os.path.exists(src):
for item in os.listdir(src):
os.rename(os.path.join(src, item), os.path.join(data_dir, item))
shutil.rmtree(os.path.join(data_dir, "test-master"))
print(f"MMLU data saved to {data_dir}")
def _strip_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix) :].strip()
return text
def load_csv(path: str) -> list[dict]:
data = []
with open(path, "r", encoding="utf-8") as f:
for row in csv.reader(f):
if len(row) < 6:
continue
if row[0].strip().lower() == "question":
continue
data.append(
{
"question": row[0].strip(),
"A": _strip_prefix(row[1].strip(), "A)"),
"B": _strip_prefix(row[2].strip(), "B)"),
"C": _strip_prefix(row[3].strip(), "C)"),
"D": _strip_prefix(row[4].strip(), "D)"),
"answer": row[5].strip(),
}
)
return data
def build_prompt(
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
prompt = ""
if n_shot > 0 and dev_data:
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
for item in dev_data[:n_shot]:
prompt += f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {item[k]}\n"
prompt += f"Answer: {item['answer']}\n\n"
prompt += f"Question: {question}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {choices[k]}\n"
prompt += "Answer:"
return prompt
def choice_logprob(
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
choice_text = f" {choice_letter}"
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
input_ids = context_ids + choice_ids
max_len = model.config.max_len
if len(input_ids) > max_len:
overflow = len(input_ids) - max_len
input_ids = input_ids[overflow:]
ctx_len = len(input_ids) - len(choice_ids)
else:
ctx_len = len(context_ids)
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits = model(input_tensor)["logits"][0]
score = 0.0
for i, tid in enumerate(choice_ids):
pos = ctx_len - 1 + i
if pos >= len(logits):
break
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
return score
def evaluate_subject(
model,
tokenizer,
subject: str,
test_data: list[dict],
dev_data: list[dict] | None,
device: str,
n_shot: int,
) -> tuple[float, int, int]:
correct = 0
total = 0
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or [])
context_ids = tokenizer.encode(prompt)
scores = {
c: choice_logprob(model, tokenizer, context_ids, c, device)
for c in ("A", "B", "C", "D")
}
if max(scores, key=scores.get) == item["answer"]:
correct += 1
total += 1
return correct / total, correct, total
def main():
parser = argparse.ArgumentParser(description="MMLU evaluation")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
)
parser.add_argument("--download", action="store_true", help="Download MMLU data")
parser.add_argument(
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
)
parser.add_argument(
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
)
parser.add_argument("--output", type=str, help="Output JSON path")
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16" if torch.cuda.is_available() else "float32",
help="Torch dtype",
)
args = parser.parse_args()
if args.download or not os.path.exists(args.data_dir):
download_mmlu(args.data_dir)
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
device = args.device
dtype = getattr(torch, args.dtype)
model.to(device=device, dtype=dtype)
subjects = args.subjects or MMLU_SUBJECTS
results = {}
total_correct = 0
total_questions = 0
for subject in subjects:
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
test_path = os.path.join(
args.data_dir, args.split, f"{subject}_{args.split}.csv"
)
if not os.path.exists(test_path):
print(f" Skipping {subject}: test file not found")
continue
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
test_data = load_csv(test_path)
acc, corr, tot = evaluate_subject(
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
)
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
total_correct += corr
total_questions += tot
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
overall = total_correct / total_questions if total_questions else 0
print(f"\n{'=' * 70}")
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
results["_overall"] = {
"accuracy": round(overall, 4),
"correct": total_correct,
"total": total_questions,
}
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {args.output}")
if __name__ == "__main__":
main()

View File

@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
def process_file( def process_file(
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
): ):
# Load model and tokenizer # Load model and tokenizer
model = AutoModel.from_pretrained(param_path) model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(param_path) tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f: with open(input_file, "r", encoding="utf-8") as f:
@ -44,8 +44,8 @@ def process_file(
for seq in batch_encoded: for seq in batch_encoded:
pad_len = max_len - len(seq) pad_len = max_len - len(seq)
padded_seq = seq + [tokenizer.pad_id] * pad_len padded_seq = [tokenizer.pad_id] * pad_len + seq
mask = [True] * len(seq) + [False] * pad_len mask = [False] * pad_len + [True] * len(seq)
padded_ids.append(padded_seq) padded_ids.append(padded_seq)
masks.append(mask) masks.append(mask)
@ -88,7 +88,7 @@ def process_file(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.") parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument( parser.add_argument(
"--param_path", type=str, required=True, help="Path to the model directory." "--model_dir", type=str, required=True, help="Path to the model directory."
) )
parser.add_argument( parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file." "--input_file", type=str, required=True, help="Path to the input file."

View File

@ -18,7 +18,7 @@ def main():
"--reload", action="store_true", help="Enable auto-reload for development" "--reload", action="store_true", help="Enable auto-reload for development"
) )
parser.add_argument( parser.add_argument(
"--param_path", "--param-path",
type=Path, type=Path,
default=None, default=None,
help="Path to model parameters (default: project_root/params)", help="Path to model parameters (default: project_root/params)",

View File

@ -8,6 +8,7 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM from astrai.model import AutoRegressiveLM
from astrai.serialization import Checkpoint
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
@ -146,8 +147,8 @@ def parse_args() -> argparse.Namespace:
"--parallel_mode", "--parallel_mode",
type=str, type=str,
default="none", default="none",
choices=["none", "ddp", "fsdp"], choices=["none", "ddp"],
help="Parallel training strategy (none, ddp, fsdp).", help="Parallel training strategy.",
) )
parser.add_argument( parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use." "--device_type", type=str, default="cuda", help="Device type to use."
@ -165,10 +166,6 @@ def parse_args() -> argparse.Namespace:
return args return args
def create_model(config):
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
def create_optimizer(model, **kwargs) -> optim.Optimizer: def create_optimizer(model, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs) return optim.AdamW(model.parameters(), fused=True, **kwargs)
@ -231,8 +228,6 @@ def train(
): ):
assert train_type in ["seq", "sft", "dpo", "grpo"] assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config # Load config
config_path = os.path.join(param_path, "config.json") config_path = os.path.join(param_path, "config.json")
@ -241,6 +236,15 @@ def train(
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
checkpoint = Checkpoint.load(param_path)
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
model.load_state_dict(checkpoint.state_dict, strict=False)
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
# (model weights already loaded into model above)
checkpoint.state_dict = {}
strategy_kwargs = { strategy_kwargs = {
"beta": dpo_beta, "beta": dpo_beta,
"label_smoothing": label_smoothing, "label_smoothing": label_smoothing,
@ -255,7 +259,6 @@ def train(
"broadcast_buffers": False, "broadcast_buffers": False,
} }
model_fn = partial(create_model, config)
dataset = DatasetFactory.load( dataset = DatasetFactory.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
@ -287,7 +290,7 @@ def train(
) )
train_config = TrainConfig( train_config = TrainConfig(
model_fn=model_fn, model=model,
strategy=train_type, strategy=train_type,
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
@ -312,7 +315,7 @@ def train(
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
trainer.train(resume_dir=param_path) trainer.train(checkpoint=checkpoint)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -7,8 +7,10 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
H5Store, BaseSegmentFetcher,
StoreFactory, H5Storage,
MultiSegmentFetcher,
StorageFactory,
detect_format, detect_format,
load_json, load_json,
save_h5, save_h5,
@ -316,48 +318,37 @@ def test_unloaded_dataset_len():
assert len(dataset) == 0 assert len(dataset) == 0
def test_store_unloaded_len(): def test_base_segment_fetcher_empty():
"""Unloaded Store has __len__ == 0""" """BaseSegmentFetcher with empty segments list"""
store = H5Store() fetcher = BaseSegmentFetcher([])
assert len(store) == 0 assert len(fetcher) == 0
assert store.keys == [] with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_store_fetch_begin_equals_end(base_test_env): def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""Store.fetch with begin == end returns empty tensor""" """fetch_data with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]} dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy) save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32) dataset = DatasetFactory.load("seq", test_dir, window_size=32)
result = dataset.storage.fetch(10, 10, "sequence") fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
assert result.numel() == 0 assert result.numel() == 0
def test_store_empty_data_len(base_test_env): def test_multi_segment_fetcher_empty_dict():
"""Store loaded with empty data has __len__ == 0""" """MultiSegmentFetcher with empty dict has __len__ == 0"""
import os fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "empty_store")
os.makedirs(data_dir, exist_ok=True)
with open(os.path.join(data_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3]]}, f)
store = StoreFactory.create("json")
store.load(data_dir)
assert len(store) > 0
empty_store = H5Store()
assert len(empty_store) == 0
def test_store_fetch_before_load(): def test_storage_fetch_before_load():
"""Store.fetch before load raises RuntimeError""" """BaseStorage.fetch before load raises RuntimeError"""
store = H5Store() storage = H5Storage()
with pytest.raises(RuntimeError, match="not loaded"): with pytest.raises(RuntimeError, match="not loaded"):
store.fetch(0, 10, "sequence") storage.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path(): def test_detect_format_nonexistent_path():
@ -376,10 +367,10 @@ def test_detect_format_unsupported_file(base_test_env):
detect_format(path) detect_format(path)
def test_create_store_invalid_type(): def test_create_storage_invalid_type():
"""StoreFactory.create raises ValueError for unknown type""" """StorageFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"): with pytest.raises(ValueError, match="Unknown component"):
StoreFactory.create("parquet") StorageFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env): def test_json_pretokenized_without_tokenizer(base_test_env):
@ -416,23 +407,14 @@ def test_load_json_skips_config_file(base_test_env):
assert len(result["sequence"]) == 1 assert len(result["sequence"]) == 1
def test_store_multi_segment_concat(base_test_env): def test_base_segment_fetcher_multi_segment():
"""Multi-segment H5 data is concatenated into single tensor at load time""" """fetch_data across multiple segment boundaries"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "multi_seg")
os.makedirs(data_dir, exist_ok=True)
segs = [ segs = [
torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]), torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]), torch.tensor([8, 9]),
] ]
save_h5(data_dir, "data", {"sequence": segs}) fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
store = StoreFactory.create("h5") result = fetcher.fetch_data(2, 7)
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7] assert result.tolist() == [3, 4, 5, 6, 7]

View File

@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
def create_train_config( def create_train_config(
model_fn, model: torch.nn.Module,
dataset: Dataset, dataset: Dataset,
test_dir: str, test_dir: str,
device: str, device: str,
@ -43,7 +43,7 @@ def create_train_config(
"""Factory function to create common TrainConfig for tests. """Factory function to create common TrainConfig for tests.
Args: Args:
model_fn: Model factory (callable returning nn.Module) model: The model to train
dataset: Training dataset dataset: Training dataset
test_dir: Checkpoint directory test_dir: Checkpoint directory
device: Device type ("cuda" or "cpu") device: Device type ("cuda" or "cpu")
@ -70,7 +70,7 @@ def create_train_config(
return TrainConfig( return TrainConfig(
strategy=strategy, strategy=strategy,
model_fn=model_fn, model=model,
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,

View File

@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
) )
train_config = TrainConfig( train_config = TrainConfig(
model_fn=lambda: base_test_env["model"], model=base_test_env["model"],
strategy="seq", strategy="seq",
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
) )
train_config = TrainConfig( train_config = TrainConfig(
model_fn=lambda: base_test_env["model"], model=base_test_env["model"],
strategy="seq", strategy="seq",
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,

View File

@ -4,6 +4,7 @@ import numpy as np
import torch import torch
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
from astrai.serialization import Checkpoint
from astrai.trainer.schedule import SchedulerFactory from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer from astrai.trainer.trainer import Trainer
@ -23,7 +24,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
strategy="seq", strategy="seq",
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
model_fn=lambda: base_test_env["model"], model=base_test_env["model"],
dataset=early_stopping_dataset, dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"), log_dir=os.path.join(base_test_env["test_dir"], "logs"),
@ -38,20 +39,17 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
trainer = Trainer(train_config) trainer = Trainer(train_config)
# Should handle early stopping gracefully # Should handle early stopping gracefully
checkpoint = None
try: try:
trainer.train() checkpoint = trainer.train()
except Exception: except Exception:
# Handle any exceptions
pass pass
# Resume from latest checkpoint
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2") load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
trainer = Trainer(train_config) checkpoint = Checkpoint.load(load_dir)
trainer.train(resume_dir=load_dir) trainer.train(checkpoint)
# Verify checkpoint was saved at expected iteration
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10") load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
import json checkpoint = Checkpoint.load(load_dir)
assert checkpoint.iteration == 10
with open(os.path.join(load_dir, "meta.json")) as f:
meta = json.load(f)
assert meta["iteration"] == 10

View File

@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
for batch_per_device in batch_sizes: for batch_per_device in batch_sizes:
train_config = train_config_factory( train_config = train_config_factory(
model_fn=lambda: base_test_env["model"], model=base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
for grad_accum_steps in grad_accum_steps_list: for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory( train_config = train_config_factory(
model_fn=lambda: base_test_env["model"], model=base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
for config in small_batch_configs: for config in small_batch_configs:
train_config = train_config_factory( train_config = train_config_factory(
model_fn=lambda: base_test_env["model"], model=base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],