Compare commits

..

9 Commits

Author SHA1 Message Date
ViperEkura 0a708fff24 docs : 更新架构文档与 storage 注释,同步 Store 重构
- architecture.md: 类图/关系线全部更新 (BaseStorage→Store, StorageFactory→StoreFactory, 新增 MmapStore)
- architecture.md: 移除 BaseSegmentFetcher/MultiSegmentFetcher 类图与关系
- dataflow.md: 管线加入 .bin 格式, Store._data + _cum 架构
- storage.py: module docstring 改用缩进式注释风格
2026-05-28 14:36:18 +08:00
ViperEkura 6e150ea6d0 refactor : Storage 层重构为 Store,移除 Fetcher 中间层,支持多段数据与显式长度
- 合并 BaseStorage + MultiSegmentFetcher + BaseSegmentFetcher 三层为 Store ABC
- Store._data 直接持有 Dict[str, List[Tensor]],不做强制拼接避免 OOM
- _fetch_key 统一用 bisect 跨段切片,单段多段同一路径
- _length 显式存储(min total across keys),__len__ 返回 O(1)
- MmapStore/H5Store/JSONStore 统一走 _normalize() 注册分段并预计算累积长度
- 所有 I/O 函数 (save_h5/load_h5/json_to_bin 等) 保持不变
2026-05-28 14:23:49 +08:00
ViperEkura cb8dcb97ea refactor : 移除 -> None 返回值标注,拆分 FSDP 参数,新增 mmap 数据集存储
- 删除所有 def 函数 -> None 返回值类型标注
- FSDPExecutor 参数从 **kwargs 拆为显式声明,None 值自动过滤
- 新增 MmapStorage (bin) 存储后端,基于 numpy.memmap 零拷贝加载
- 新增 save_bin/load_bin/json_to_bin 工具函数
- detect_format 支持 bin 格式自动检测
2026-05-28 13:57:06 +08:00
ViperEkura 2d5dc93b3d fix : 修正类型标注与统一 CLI 参数命名
- AutoRegressiveLM.forward 返回类型标注 -> Dict[str, Tensor]
- EmbeddingEncoder 移除冗余 position_ids 自动创建
- CLI 脚本模型目录参数统一为 --param_path
2026-05-27 20:49:44 +08:00
ViperEkura 4145d35e3c refactor: 检查点加载重构,路径替代对象传递
- model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串
- Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递
- TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch
- CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁
- serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast
- optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行
- pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物
2026-05-27 20:15:29 +08:00
ViperEkura 34c6c45bd6 feat: 初步实现 MMLU 评测脚本
- 支持 few-shot (log-likelihood ranking) 与 zero-shot
- 自动下载 Hendrycks MMLU 数据集
- --device / --dtype 可配置,默认 GPU bf16
2026-05-26 20:23:31 +08:00
ViperEkura e9def84ce7 fix : perplexity.py left padding 导致 batch>1 时 PPL 计算错误 2026-05-26 19:59:57 +08:00
ViperEkura 836e02a166 docs: 同步 architecture/inference/training 文档至实际代码,CLI 补充 fsdp 选项
- 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式
- 修正训练循环 scheduler.step() 在 sync_gradients 块内
- 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联
- --parallel_mode CLI choices 加入 fsdp
- nprocs > 1 且 parallel_mode=none 时 raise error
2026-05-26 19:37:00 +08:00
ViperEkura b558e61f63 refactor: 简化 _disable_random_init,scheduler 移入同步块
- _disable_random_init: enable=False 提前返回,dict 推导替代空字典
- scheduler.step() 移入 sync_gradients 守卫内
2026-05-26 17:05:25 +08:00
34 changed files with 885 additions and 475 deletions

View File

@ -22,7 +22,8 @@ classDiagram
+int n_layers +int n_layers
+float norm_eps +float norm_eps
+int dim_ffn +int dim_ffn
+bool tie_weight +Optional[bool] tie_weight
+Optional[dict] rope_scaling
+int max_len +int max_len
+float rope_theta +float rope_theta
+str attn_type +str attn_type
@ -52,6 +53,7 @@ classDiagram
+int n_kv_heads +int n_kv_heads
+bool use_qk_norm +bool use_qk_norm
+bool use_gated_attention +bool use_gated_attention
+Optional[dict] rope_scaling
+Optional[str] pooling_type +Optional[str] pooling_type
+Optional[bool] normalize_embeddings +Optional[bool] normalize_embeddings
} }
@ -80,6 +82,7 @@ classDiagram
+str log_dir +str log_dir
+int log_interval +int log_interval
+List[str] metrics +List[str] metrics
+Optional[LoRAConfig] lora
+int random_seed +int random_seed
+int num_workers +int num_workers
+Optional[int] prefetch_factor +Optional[int] prefetch_factor
@ -104,7 +107,7 @@ classDiagram
class BaseDataset { class BaseDataset {
+int window_size +int window_size
+int stride +int stride
+Optional[BaseStorage] storage +Optional[Store] storage
+load(load_path, storage_type, tokenizer) +load(load_path, storage_type, tokenizer)
+__getitem__(index) +__getitem__(index)
+__len__() +__len__()
@ -126,38 +129,29 @@ classDiagram
+__getitem__(index) Dict +__getitem__(index) Dict
} }
class BaseSegmentFetcher { class Store {
+List[Tensor] segments +Dict[str, List[Tensor]] _data
+List[int] cum_lengths +Dict[str, List[int]] _cum
+int total_length +int _length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
+keys (property) +keys (property)
+load(load_path, tokenizer) +load(path, tokenizer)
+fetch(begin, end, keys) +fetch(begin, end, keys)
+__len__() +__len__()
-_fetch_key(key, begin, end) Tensor
-_normalize(raw)
} }
class H5Storage { class H5Store {
+load(load_path, tokenizer) +load(path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
} }
class JSONStorage { class JSONStore {
+load(load_path, tokenizer) +load(path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
} }
class MultiSegmentFetcher { class MmapStore {
+Dict multi_fetchers +List _mmap_refs
+List multi_keys +load(path, tokenizer)
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
} }
class ResumableDistributedSampler { class ResumableDistributedSampler {
@ -165,10 +159,10 @@ classDiagram
+int iter +int iter
} }
class StorageFactory { class StoreFactory {
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
+create(storage_type) BaseStorage +create(storage_type) Store
} }
class DatasetFactory { class DatasetFactory {
@ -457,16 +451,15 @@ classDiagram
+on_train_end(context) +on_train_end(context)
+on_epoch_begin(context) +on_epoch_begin(context)
+on_epoch_end(context) +on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context) +on_batch_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_optimizer_step(context)
+on_error(context) +on_error(context)
} }
class GradientClippingCallback { class GradientClippingCallback {
+float max_grad_norm +float max_grad_norm
+on_step_begin(context) +on_optimizer_step(context)
} }
class GradientCheckpointingCallback { class GradientCheckpointingCallback {
@ -512,7 +505,7 @@ classDiagram
class ValidationCallback { class ValidationCallback {
+_run_validation(context) +_run_validation(context)
+on_step_end(context) +on_optimizer_step(context)
} }
class CallbackFactory { class CallbackFactory {
@ -747,56 +740,58 @@ classDiagram
+str model +str model
+List[AnthropicMessage] messages +List[AnthropicMessage] messages
+Optional[str] system +Optional[str] system
+float temperature +Optional[float] temperature
+float top_p +Optional[float] top_p
+int top_k +Optional[int] top_k
+int max_tokens +int max_tokens
+bool stream +Optional[bool] stream
+Optional[List[str]] stop_sequences +Optional[List[str]] stop_sequences
} }
class ProtocolHandler { class ResponseBuilder {
<<abstract>> <<abstract>>
+prepare(request, engine) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class OpenAIResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class AnthropicResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class ProtocolHandler {
+request +request
+engine +engine
+build_prompt() str +builder: ResponseBuilder
+create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+handle() Union[StreamingResponse, Dict] +handle() Union[StreamingResponse, Dict]
} -_handle_stream(agen, ctx, stops) StreamingResponse
-_handle_non_stream(agen, ctx, stops) Dict
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
} }
class StopChecker { class StopChecker {
+has_sequences (property) bool
+check(text) Optional[str] +check(text) Optional[str]
+trim(text, matched) str
} }
class StreamContext { class GenContext {
+str resp_id +str resp_id
+int created +int created
+str model +str model
+int prompt_tokens +int prompt_tokens
+int completion_tokens +int completion_tokens
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
} }
class app { class app {
@ -876,6 +871,11 @@ classDiagram
+unwrap_model(model) nn.Module +unwrap_model(model) nn.Module
} }
class FSDPExecutor {
+_prepare_model(model) nn.Module
+unwrap_model(model) nn.Module
}
class ExecutorFactory { class ExecutorFactory {
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
@ -911,12 +911,14 @@ classDiagram
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- ValidationCallback
BaseDataset <|-- SEQDataset BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset BaseDataset <|-- GRPODataset
BaseStorage <|-- H5Storage Store <|-- H5Store
BaseStorage <|-- JSONStorage Store <|-- JSONStore
Store <|-- MmapStore
BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy BaseSamplingStrategy <|-- TopPStrategy
@ -936,20 +938,19 @@ classDiagram
BaseFactory <|-- StrategyFactory BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory BaseFactory <|-- StoreFactory
BaseFactory <|-- ExecutorFactory BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory BaseFactory <|-- ConfigFactory
BaseExecutor <|-- NoneExecutor BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor BaseExecutor <|-- DDPExecutor
ProtocolHandler <|-- OpenAIHandler BaseExecutor <|-- FSDPExecutor
ProtocolHandler <|-- AnthropicHandler ResponseBuilder <|-- OpenAIResponseBuilder
ResponseBuilder <|-- AnthropicResponseBuilder
%% --- Composition (strong ownership, part destroyed with whole) --- %% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool KVCache *-- PagePool
KVCache *-- Storage KVCache *-- Storage
KVCache *-- TaskTable KVCache *-- TaskTable
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor InferenceScheduler *-- Executor
@ -963,7 +964,6 @@ classDiagram
DecoderBlock *-- RMSNorm DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry BaseFactory *-- Registry
BaseExecutor *-- GradientState BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState AccumOptimizer o-- GradientState
@ -971,6 +971,9 @@ classDiagram
%% --- Aggregation (weak ownership) --- %% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig AutoModel o-- BaseModelConfig
AutoTokenizer o-- ChatTemplate
PagePool o-- Allocator
PagePool o-- PrefixCache
Trainer o-- TrainCallback Trainer o-- TrainCallback
TrainContext o-- BaseStrategy TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler TrainContext o-- BaseScheduler
@ -978,7 +981,7 @@ classDiagram
TrainContext o-- BaseExecutor TrainContext o-- BaseExecutor
KvcacheView o-- Storage KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage BaseDataset o-- Store
%% --- Dependency (uses temporarily) --- %% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects TrainConfig ..> BaseStrategy : selects
@ -992,12 +995,14 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses DecoderBlock ..> FFNFactory : uses
StorageFactory ..> H5Storage : creates StoreFactory ..> H5Store : creates
StorageFactory ..> JSONStorage : creates StoreFactory ..> JSONStore : creates
StoreFactory ..> MmapStore : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates ExecutorFactory ..> DDPExecutor : creates
ExecutorFactory ..> FSDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates TrainContextBuilder ..> TrainContext : creates
@ -1009,10 +1014,10 @@ classDiagram
KVCache ..> KvcacheView : binds KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates InferenceEngine ..> GenerateResult : creates
OpenAIHandler ..> ChatCompletionRequest : receives OpenAIResponseBuilder ..> ChatCompletionRequest : receives
AnthropicHandler ..> MessagesRequest : receives AnthropicResponseBuilder ..> MessagesRequest : receives
ProtocolHandler ..> StopChecker : creates ProtocolHandler ..> StopChecker : creates
ProtocolHandler ..> StreamContext : creates ProtocolHandler ..> GenContext : creates
%% --- Association (general usage) --- %% --- Association (general usage) ---
Trainer --> TrainConfig Trainer --> TrainConfig
@ -1025,8 +1030,6 @@ classDiagram
Executor --> AutoModel Executor --> AutoModel
Executor --> AutoTokenizer Executor --> AutoTokenizer
TaskManager --> AutoTokenizer TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
``` ```
@ -1036,13 +1039,13 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | | **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.dataset** | BaseDatasetGRPODataset, StoreMmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization | | **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, StopChecker, StreamContext, ChatMessageMessagesRequest, app | Inference service | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration | | **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers | | **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
@ -1050,17 +1053,17 @@ classDiagram
| Pattern | Classes | Purpose | | Pattern | Classes | Purpose |
|---------|---------|---------| |---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation | | **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks | | **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
| **Builder** | `TrainContextBuilder` | Chain-building training context | | **Builder** | `TrainContextBuilder` | Chain-building training context |
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag | | **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access | | **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1069,10 +1072,10 @@ classDiagram
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs` 1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed) 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` 7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops

View File

@ -5,21 +5,22 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview ## Overview
``` ```
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
``` ```
## Data Preparation ## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups. Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
``` ```
StorageFactory.create("h5") → H5Storage StoreFactory.create("h5") → H5Store
StorageFactory.create("json") → JSONStorage StoreFactory.create("json") → JSONStore
StoreFactory.create("bin") → MmapStore
``` ```
Both support shared memory via `.share_memory_()`. H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
## Data Keys by Training Type ## Data Keys by Training Type
@ -33,14 +34,14 @@ Both support shared memory via `.share_memory_()`.
## Dataset Architecture ## Dataset Architecture
``` ```
DatasetFactory.load(train_type, path, window_size, stride) DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
→ StorageFactory.create(detect_format(path)) → StoreFactory.create(detect_format(path))
MultiSegmentFetcher(BaseSegmentFetcher per key) Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx) → BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx) → sliding window [begin, end) via get_index(idx)
``` ```
`window_size` = max input length, `stride` = step between consecutive samples. `window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
## Sampler ## Sampler

View File

@ -46,20 +46,22 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage. `sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Template Method) ## Protocol Handlers (Strategy Pattern)
```python ```python
class ProtocolHandler(ABC): class ProtocolHandler: # concrete orchestrator
def handle(self): def handle(self, request):
ctx = StreamContext(...) prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...) agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx) if stream: self._handle_stream(agen, ctx, stops)
else: self._handle_non_stream(agen, ctx) else: self._handle_non_stream(agen, ctx, stops)
``` ```
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`. `ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
`OpenAIHandler``/v1/chat/completions`, `AnthropicHandler``/v1/messages`. `OpenAIResponseBuilder``/v1/chat/completions`, `AnthropicResponseBuilder``/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
## Engine & GenerateResult ## Engine & GenerateResult
@ -116,7 +118,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| Param | Type | Default | Description | | Param | Type | Default | Description |
|-------|------|---------|-------------| |-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) | | `messages` | List[dict] | required | Chat messages (role, content) |
| `temperature` | float | 1.0 | Sampling temperature (0.02.0) | | `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold | | `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count | | `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length | | `max_tokens` | int | None | Max generation length |

View File

@ -53,7 +53,7 @@
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 | | `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none | | `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
| `--device_type` | Device type | cuda | | `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn | | `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |

View File

@ -82,8 +82,7 @@ on_train_begin
on_optimizer_step on_optimizer_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
scheduler.step()
scheduler.step() # called every iteration
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
@ -190,7 +189,7 @@ context = (
``` ```
- Loads checkpoint weights if provided - Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)` - Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers - Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume - Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)` - Builds strategy via `StrategyFactory.create(train_type, ...)`

View File

@ -17,8 +17,8 @@ def required(**kw):
@dataclass @dataclass
class TrainConfig(BaseConfig): class TrainConfig(BaseConfig):
# basic setting # basic setting
model: nn.Module = field( model_fn: Callable[[], nn.Module] = field(
default=None, metadata=required(help="Model for training.") default=None, metadata=required(help="Model factory for training.")
) )
strategy: str = field(default=None, metadata=required(help="Training strategy.")) strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field( dataset: Dataset = field(

View File

@ -4,15 +4,17 @@ from astrai.dataset.dataset import (
) )
from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import ( from astrai.dataset.storage import (
BaseSegmentFetcher, H5Store,
BaseStorage, JSONStore,
H5Storage, MmapStore,
JSONStorage, Store,
MultiSegmentFetcher, StoreFactory,
StorageFactory,
detect_format, detect_format,
json_to_bin,
load_bin,
load_h5, load_h5,
load_json, load_json,
save_bin,
save_h5, save_h5,
save_json, save_json,
) )
@ -20,16 +22,18 @@ from astrai.dataset.storage import (
__all__ = [ __all__ = [
"BaseDataset", "BaseDataset",
"DatasetFactory", "DatasetFactory",
"BaseSegmentFetcher", "Store",
"MultiSegmentFetcher", "StoreFactory",
"BaseStorage", "H5Store",
"H5Storage", "JSONStore",
"JSONStorage", "MmapStore",
"StorageFactory",
"detect_format", "detect_format",
"save_h5", "save_h5",
"load_h5", "load_h5",
"save_json", "save_json",
"load_json", "load_json",
"save_bin",
"load_bin",
"json_to_bin",
"ResumableDistributedSampler", "ResumableDistributedSampler",
] ]

View File

@ -8,8 +8,8 @@ from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
BaseStorage, Store,
StorageFactory, StoreFactory,
detect_format, detect_format,
) )
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
super().__init__() super().__init__()
self.window_size = window_size self.window_size = window_size
self.stride = stride self.stride = stride
self.storage: Optional[BaseStorage] = None self.storage: Optional[Store] = None
@property @property
def required_keys(self) -> List[str]: def required_keys(self) -> List[str]:
@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC):
""" """
if storage_type is None: if storage_type is None:
storage_type = detect_format(load_path) storage_type = detect_format(load_path)
self.storage = StorageFactory.create(storage_type) self.storage = StoreFactory.create(storage_type)
self._load_path = load_path self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer) self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys() self._validate_keys()
@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
""" """
@classmethod @classmethod
def _validate_component(cls, dataset_cls: type) -> None: def _validate_component(cls, dataset_cls: type):
"""Validate that the dataset class inherits from BaseDataset.""" """Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset): if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")

View File

@ -1,7 +1,20 @@
"""Storage backends for different data formats. """Storage backends for different data formats.
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides Layers:
a uniform interface for data access and length observation via fetchers. - I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
""" """
import bisect import bisect
@ -12,6 +25,7 @@ from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import h5py import h5py
import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
@ -104,6 +118,38 @@ def load_json(
return tensor_group return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
meta = {}
for key, tensors in tensor_group.items():
cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
save_json(meta, os.path.join(file_path, "meta.json"))
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
meta = load_json(os.path.join(file_path, "meta.json"))
segments: Dict[str, List[Tensor]] = {}
for key, info in meta.items():
arr = np.memmap(
os.path.join(file_path, f"{key}.bin"),
dtype=info["dtype"],
mode="r",
shape=tuple(info["shape"]),
)
segments[key] = [torch.from_numpy(arr)]
return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
merged = {}
for key, tensors in segments.items():
merged[key] = [torch.cat(tensors, dim=0)]
save_bin(bin_path, merged)
def detect_format(load_path: str) -> str: def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory. """Auto-detect storage format from files in the directory.
@ -111,7 +157,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path load_path: Directory or file path
Returns: Returns:
Format string ("h5" or "json") Format string ("h5", "bin", or "json")
Raises: Raises:
FileNotFoundError: If no supported data files are found FileNotFoundError: If no supported data files are found
@ -128,166 +174,118 @@ def detect_format(load_path: str) -> str:
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files: if h5_files:
return "h5" return "h5"
bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists():
return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl")) json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files: if json_files:
return "json" return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}") raise FileNotFoundError(f"No supported data files found at {load_path}")
class BaseSegmentFetcher: class Store(ABC):
"""Fetches data segments across multiple tensor segments. """String keys -> segmented tensors with ``fetch(begin, end, keys)``.
Maintains cumulative lengths for efficient range queries across Each key maps to one or more tensor segments (no forced concatenation).
multiple discontinuous segments. ``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
""" total element count across all keys.
def __init__(self, segments: List[Tensor]): Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
self.segments = segments via ``_normalize()``.
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
""" """
def __init__(self): def __init__(self):
self._fetcher: Optional[MultiSegmentFetcher] = None self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
@abstractmethod @abstractmethod
def load(self, load_path: str, tokenizer=None) -> None: def load(self, path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property @property
def keys(self) -> List[str]: def keys(self) -> List[str]:
"""Return the data keys available in this storage.""" return list(self._data.keys())
if self._fetcher is None:
return [] def __len__(self) -> int:
return self._fetcher.multi_keys return self._length
def fetch(
self,
begin: int,
end: int,
keys: Union[str, List[str]],
):
if not self._data:
raise RuntimeError("Store not loaded")
if not (0 <= begin < self._length and 0 <= end <= self._length):
raise ValueError(
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
)
if isinstance(keys, str):
return self._fetch_key(keys, begin, end)
return {k: self._fetch_key(k, begin, end) for k in keys}
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
"""Fetch slice [begin, end) across potentially multiple segments."""
segments = self._data[key]
cum = self._cum[key]
seg_start = bisect.bisect_right(cum, begin)
seg_end = bisect.bisect_left(cum, end)
results = []
for i in range(seg_start, seg_end + 1):
prev = cum[i - 1] if i > 0 else 0
s = max(begin - prev, 0)
e = min(end - prev, segments[i].shape[0])
results.append(segments[i][s:e])
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
"""Register segments and pre-compute cumulative lengths.
Does NOT concatenate segments are kept as-is to avoid OOM on
large datasets. Sets ``self._length`` to the minimum total
element count across all keys.
"""
for key, tensors in raw.items():
self._data[key] = tensors
cum = []
total = 0
for t in tensors:
total += t.shape[0]
cum.append(total)
self._cum[key] = cum
self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
class StorageFactory(BaseFactory["BaseStorage"]): class StoreFactory(BaseFactory["Store"]):
"""Factory for creating storage backends by type name. """Factory for creating Store instances by type name.
Example: Example::
@StorageFactory.register("custom")
class CustomStorage(BaseStorage): @StoreFactory.register("custom")
class CustomStore(Store):
... ...
storage = StorageFactory.create("custom")
""" """
@classmethod @classmethod
def _validate_component(cls, storage_cls: type) -> None: def _validate_component(cls, store_cls: type):
if not issubclass(storage_cls, BaseStorage): if not issubclass(store_cls, Store):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage") raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StorageFactory.register("h5") @StoreFactory.register("h5")
class H5Storage(BaseStorage): class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data).""" """HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None) -> None: def load(self, path: str, tokenizer=None):
segments = load_h5(load_path) self._normalize(load_h5(path))
self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("json") @StoreFactory.register("json")
class JSONStorage(BaseStorage): class JSONStore(Store):
"""JSON-based storage backend. """JSON-based storage backend.
Supports two modes: Supports two modes:
@ -296,6 +294,28 @@ class JSONStorage(BaseStorage):
callable (str -> List[int]) at load time. callable (str -> List[int]) at load time.
""" """
def load(self, load_path: str, tokenizer=None) -> None: def load(self, path: str, tokenizer=None):
segments = load_json(load_path, tokenizer=tokenizer) self._normalize(load_json(path, tokenizer=tokenizer))
self._fetcher = MultiSegmentFetcher(segments)
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
No per-process memory duplication all DataLoader workers share the
same OS page-cache pages.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
"""
def load(self, path: str, tokenizer=None):
self._mmap_refs = []
raw = load_bin(path)
self._normalize(raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)

View File

@ -23,7 +23,7 @@ class Registry:
component_cls: Type, component_cls: Type,
category: Optional[str] = None, category: Optional[str] = None,
priority: int = 0, priority: int = 0,
) -> None: ):
"""Register a component class with optional category and priority.""" """Register a component class with optional category and priority."""
if name in self._entries: if name in self._entries:
raise ValueError(f"Component '{name}' is already registered") raise ValueError(f"Component '{name}' is already registered")
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
return component_cls(*args, **kwargs) return component_cls(*args, **kwargs)
@classmethod @classmethod
def _validate_component(cls, component_cls: Type[T]) -> None: def _validate_component(cls, component_cls: Type[T]):
"""Validate that the component class is valid for this factory. """Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation. Override this method in subclasses to add custom validation.

View File

@ -42,7 +42,7 @@ class Allocator:
return idx return idx
return -1 return -1
def free(self, idx: int, keep_cached: bool = False) -> None: def free(self, idx: int, keep_cached: bool = False):
with self._lock: with self._lock:
self._refs[idx] -= 1 self._refs[idx] -= 1
if self._refs[idx] == 0: if self._refs[idx] == 0:
@ -51,7 +51,7 @@ class Allocator:
else: else:
self._free_mask |= 1 << idx self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None: def inc_ref(self, idx: int):
with self._lock: with self._lock:
self._refs[idx] += 1 self._refs[idx] += 1
self._lru.pop(idx, None) self._lru.pop(idx, None)
@ -60,7 +60,7 @@ class Allocator:
with self._lock: with self._lock:
return self._refs[idx] return self._refs[idx]
def touch(self, idx: int) -> None: def touch(self, idx: int):
with self._lock: with self._lock:
self._lru.move_to_end(idx) self._lru.move_to_end(idx)
@ -74,7 +74,7 @@ class PrefixCache:
self._hash_to_page: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
def evict(self, idx: int) -> None: def evict(self, idx: int):
with self._lock: with self._lock:
h = self._page_to_hash.pop(idx, None) h = self._page_to_hash.pop(idx, None)
if h is not None: if h is not None:
@ -96,9 +96,7 @@ class PrefixCache:
hits.append(p) hits.append(p)
return hits return hits
def record( def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
with self._lock: with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size) h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None) old_h = self._page_to_hash.pop(page_idx, None)
@ -127,13 +125,13 @@ class PagePool:
def alloc(self) -> int: def alloc(self) -> int:
return self._alloc.alloc() return self._alloc.alloc()
def free(self, idx: int) -> None: def free(self, idx: int):
keep = self._prefix.has_page(idx) keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep) self._alloc.free(idx, keep_cached=keep)
if not keep: if not keep:
self._prefix.evict(idx) self._prefix.evict(idx)
def inc_ref(self, idx: int) -> None: def inc_ref(self, idx: int):
self._alloc.inc_ref(idx) self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]: def lookup(self, token_ids: List[int]) -> List[int]:
@ -142,9 +140,7 @@ class PagePool:
self._alloc.touch(p) self._alloc.touch(p)
return hits return hits
def record( def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
self._prefix.record(page_idx, token_ids, logical_page_idx) self._prefix.record(page_idx, token_ids, logical_page_idx)
@ -157,7 +153,7 @@ class TaskTable:
self._cached: Dict[str, int] = {} self._cached: Dict[str, int] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None: def set(self, task_id: str, page_table: List[int], cached: int):
with self._lock: with self._lock:
self._pages[task_id] = page_table self._pages[task_id] = page_table
self._cached[task_id] = cached self._cached[task_id] = cached
@ -220,7 +216,7 @@ class Storage:
start_pos: int, start_pos: int,
k: Tensor, k: Tensor,
v: Tensor, v: Tensor,
) -> None: ):
seq_len = k.size(1) seq_len = k.size(1)
if seq_len == 0: if seq_len == 0:
return return
@ -286,7 +282,7 @@ class KvcacheView:
self._page_table = page_table self._page_table = page_table
self._total_len = total_len self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None: def write(self, layer_id: int, k: Tensor, v: Tensor):
start_pos = self._total_len - k.size(1) start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v) self._storage.write(layer_id, self._page_table, start_pos, k, v)
@ -339,7 +335,7 @@ class KVCache:
self._table.set(task_id, hits + new_pages, cached) self._table.set(task_id, hits + new_pages, cached)
return True return True
def task_free(self, task_id: str) -> None: def task_free(self, task_id: str):
page_table, _ = self._table.pop(task_id) page_table, _ = self._table.pop(task_id)
for idx in page_table: for idx in page_table:
self._pool.free(idx) self._pool.free(idx)
@ -359,7 +355,7 @@ class KVCache:
def task_record_hashes( def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0 self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None: ):
page_table = self._table.get(task_id) page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages): for i in range(start_logical_page, full_pages):

View File

@ -29,9 +29,7 @@ class Executor:
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill( def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
if start_pos >= prompt_len: if start_pos >= prompt_len:
return return

View File

@ -75,14 +75,14 @@ class InferenceScheduler:
def add_task(self, prompt: str, **kwargs) -> str: def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs) return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str) -> None: def remove_task(self, task_id: str):
for task in self._task_mgr.remove_task(task_id): for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats() return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None: def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids stop_ids = self._task_mgr.tokenizer.stop_ids
try: try:
while self._running: while self._running:
@ -186,14 +186,14 @@ class InferenceScheduler:
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
raise raise
def start(self) -> None: def start(self):
if not self._running: if not self._running:
self._running = True self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True) t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start() t.start()
self._loop_thread = t self._loop_thread = t
def stop(self) -> None: def stop(self):
self._running = False self._running = False
self._task_mgr.wake() self._task_mgr.wake()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):

View File

@ -172,12 +172,12 @@ class TaskManager:
to_add.append(self.waiting_queue.popleft()) to_add.append(self.waiting_queue.popleft())
return to_add return to_add
def activate(self, task: Task) -> None: def activate(self, task: Task):
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
with self._lock: with self._lock:
self.active_tasks.append(task) self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]) -> None: def return_to_waiting(self, tasks: List[Task]):
with self._lock: with self._lock:
for task in reversed(tasks): for task in reversed(tasks):
self.waiting_queue.appendleft(task) self.waiting_queue.appendleft(task)
@ -185,7 +185,7 @@ class TaskManager:
def has_work(self) -> bool: def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue) return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0) -> None: def wait_for_tasks(self, timeout: float = 1.0):
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=timeout) self._task_event.wait(timeout=timeout)
@ -197,10 +197,10 @@ class TaskManager:
with self._lock: with self._lock:
return list(self.waiting_queue) return list(self.waiting_queue)
def clear_queues(self) -> None: def clear_queues(self):
with self._lock: with self._lock:
self.waiting_queue.clear() self.waiting_queue.clear()
self.active_tasks.clear() self.active_tasks.clear()
def wake(self) -> None: def wake(self):
self._task_event.set() self._task_event.set()

View File

@ -48,7 +48,7 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool: def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0) -> None: def wait_completion(self, timeout: float = 300.0):
with self._cond: with self._cond:
if not self._cond.wait_for( if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout lambda: self._completed >= self._total, timeout=timeout
@ -281,7 +281,7 @@ class InferenceEngine:
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self) -> None: def shutdown(self):
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -15,7 +15,11 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod
@contextmanager @contextmanager
def _disable_random_init(enable: bool = True): def _disable_random_init(enable: bool = True):
init_functions = [ if not enable:
yield
return
names = (
"xavier_normal_", "xavier_normal_",
"xavier_uniform_", "xavier_uniform_",
"kaiming_normal_", "kaiming_normal_",
@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True):
"constant_", "constant_",
"normal_", "normal_",
"uniform_", "uniform_",
] )
original_funcs = {} orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
for name in init_functions: for n in orig:
if enable and hasattr(nn.init, name): setattr(nn.init, n, lambda *a, **kw: None)
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try: try:
yield yield
finally: finally:
if enable: for n, fn in orig.items():
for name, orig_func in original_funcs.items(): setattr(nn.init, n, fn)
setattr(nn.init, name, orig_func)
class AutoModel(BaseFactory["AutoModel"], nn.Module): class AutoModel(BaseFactory["AutoModel"], nn.Module):
@ -82,7 +83,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, Path], save_directory: Union[str, Path],
) -> None: ):
save_model( save_model(
config=self.config.to_dict(), config=self.config.to_dict(),
state_dict=self.state_dict(), state_dict=self.state_dict(),

View File

@ -68,9 +68,6 @@ class EmbeddingEncoder(AutoModel):
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
rotary_emb = self.rotary_embedding(x, position_ids) rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False) attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

View File

@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional from typing import Any, Dict, Mapping, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -136,7 +136,7 @@ class AutoRegressiveLM(AutoModel):
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None, paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None, position_ids: Optional[Tensor] = None,
) -> Tensor: ) -> Dict[str, Tensor]:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)

View File

@ -203,9 +203,45 @@ class DDPExecutor(BaseExecutor):
@ExecutorFactory.register("fsdp") @ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor): class FSDPExecutor(BaseExecutor):
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs): def __init__(
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps) super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = fsdp_kwargs self._fsdp_kwargs = {
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._original_model: Optional[nn.Module] = None self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module: def _prepare_model(self, model: nn.Module) -> nn.Module:

View File

@ -1,8 +1,9 @@
import io
import json import json
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any, Dict, List, Tuple
import safetensors.torch as st import safetensors.torch as st
import torch import torch
@ -11,11 +12,11 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank from astrai.parallel.setup import get_rank
_META_FILE = "meta.json" _META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors" _WEIGHTS_FILE = "model.safetensors"
_MODEL_CONFIG_FILE = "config.json"
def save_safetensors(state_dict: dict, path: str | Path) -> None: def save_safetensors(state_dict: dict, path: str | Path):
st.save_file(state_dict, str(path)) st.save_file(state_dict, str(path))
@ -23,7 +24,7 @@ def load_safetensors(path: str | Path) -> dict:
return st.load_file(str(path)) return st.load_file(str(path))
def save_json(data: dict, path: str | Path) -> None: def save_json(data: dict, path: str | Path):
with open(str(path), "w") as f: with open(str(path), "w") as f:
json.dump(data, f, indent=2) json.dump(data, f, indent=2)
@ -33,13 +34,92 @@ def load_json(path: str | Path) -> dict:
return json.load(f) return json.load(f)
def save_torch(obj: Any, path: str | Path) -> None: def save_torch(obj: Any, path: str | Path):
torch.save(obj, str(path)) torch.save(obj, str(path))
def load_torch(path: str | Path) -> Any: def load_torch(path: str | Path, broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False) return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
def _get_meta(save_path: Path) -> dict:
meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
return meta
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return load_safetensors(save_path / _WEIGHTS_FILE)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
specs: List[Tuple[str, List[int], str]] = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
else:
state_dict = {}
specs = []
specs_list = [specs]
dist.broadcast_object_list(specs_list, src=0)
specs = specs_list[0]
for key, shape, dtype_name in specs:
dtype = getattr(torch, dtype_name)
if rank != 0:
tensor = torch.empty(shape, dtype=dtype, device="cpu")
else:
tensor = state_dict[key].contiguous().cpu()
dist.broadcast(tensor, src=0)
if rank != 0:
state_dict[key] = tensor
return state_dict
@dataclass @dataclass
class Checkpoint: class Checkpoint:
@ -49,7 +129,7 @@ class Checkpoint:
extra: Dict[str, Any] = field(default_factory=dict) extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str) -> None: def save(self, save_dir: str):
save_path = Path(save_dir) save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)
@ -68,24 +148,16 @@ class Checkpoint:
save_torch(value, save_path / f"{key}.pt") save_torch(value, save_path / f"{key}.pt")
@classmethod @classmethod
def load(cls, save_dir: str) -> "Checkpoint": def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir) save_path = Path(save_dir)
meta = {} meta = _get_meta(save_path)
if get_rank() == 0: state_dict = _load_state_dict(save_path, broadcast=broadcast)
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
extra = {} extra = {}
for f in save_path.iterdir(): for f in sorted(save_path.iterdir()):
if f.suffix == ".pt": if f.suffix == ".pt":
extra[f.stem] = load_torch(f) extra[f.stem] = load_torch(f, broadcast=broadcast)
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,
@ -93,18 +165,3 @@ class Checkpoint:
iteration=meta.get("iteration", 0), iteration=meta.get("iteration", 0),
extra=extra, extra=extra,
) )
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _MODEL_CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)

View File

@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
""" """
@classmethod @classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None: def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
"""Validate that the scheduler class inherits from BaseScheduler.""" """Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler): if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

View File

@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
""" """
@classmethod @classmethod
def _validate_component(cls, strategy_cls: type) -> None: def _validate_component(cls, strategy_cls: type):
"""Validate that the strategy class inherits from BaseStrategy.""" """Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy): if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device from astrai.parallel.setup import get_current_device, get_rank
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import ( from astrai.trainer.metric_util import (
ctx_get_grad_max, ctx_get_grad_max,
@ -139,27 +139,27 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join( # All ranks gather state_dict — collective for FSDP, local for DDP
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
state_dict = ( state_dict = (
self.state_dict_fn(context.model) self.state_dict_fn(context.model)
if self.state_dict_fn if self.state_dict_fn
else context.model.state_dict() else context.model.state_dict()
) )
self.last_ckpt_iter = context.iteration
if get_rank() == 0:
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
extra = self.save_extra_fn(context) extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
state_dict=state_dict, state_dict=state_dict,
@ -168,13 +168,7 @@ class CheckpointCallback(TrainCallback):
extra=extra, extra=extra,
meta=context.config.to_dict(), meta=context.config.to_dict(),
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval: if context.iteration - self.last_ckpt_iter >= self.interval:
@ -196,12 +190,6 @@ class CheckpointCallback(TrainCallback):
extra[name] = obj.state_dict() extra[name] = obj.state_dict()
return extra return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar") @CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback): class ProgressBarCallback(TrainCallback):

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self from typing import Optional, Self
import torch.nn as nn import torch.nn as nn
@ -10,7 +11,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -42,10 +43,10 @@ class TrainContextBuilder:
config: TrainConfig, config: TrainConfig,
): ):
self.config = config self.config = config
self._checkpoint: Optional[Checkpoint] = None self._resume_dir: Optional[str] = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
self._checkpoint = checkpoint self._resume_dir = resume_dir
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
@ -58,36 +59,40 @@ class TrainContextBuilder:
**cfg.executor_kwargs, **cfg.executor_kwargs,
) )
model = cfg.model_fn()
model = model.to(device=device)
context = TrainContext( context = TrainContext(
model=cfg.model, model=model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
config=cfg, config=cfg,
executor=executor, executor=executor,
) )
context.model = context.model.to(device=device) if self._resume_dir is not None:
resume_path = Path(self._resume_dir)
if self._checkpoint is not None: if (resume_path / "meta.json").exists():
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch) checkpoint = Checkpoint.load(self._resume_dir)
context.iteration = max(self._checkpoint.iteration, cfg.start_batch) state_dict = checkpoint.state_dict
if self._checkpoint.state_dict:
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else: else:
context.checkpoint = Checkpoint( checkpoint = None
state_dict=context.model.state_dict(), state_dict = load_model_weights(self._resume_dir)
) model.load_state_dict(state_dict, strict=False)
if checkpoint is not None:
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
context.iteration = max(checkpoint.iteration, cfg.start_batch)
context.checkpoint = checkpoint
if cfg.lora is not None: if cfg.lora is not None:
inject_lora( inject_lora(
context.model, model,
r=cfg.lora.r, r=cfg.lora.r,
alpha=cfg.lora.alpha, alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules), target_modules=set(cfg.lora.target_modules),
) )
context.optimizer = cfg.optimizer_fn(context.model) context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer) context.scheduler = cfg.scheduler_fn(context.optimizer)
sampler_offset = context.iteration * cfg.batch_per_device sampler_offset = context.iteration * cfg.batch_per_device
@ -125,13 +130,21 @@ class TrainContextBuilder:
context.model, context.optimizer, context.dataloader, context.scheduler = ( context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare( executor.prepare(
context.model, model,
context.optimizer, context.optimizer,
context.dataloader, context.dataloader,
context.scheduler, context.scheduler,
) )
) )
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create( context.strategy = StrategyFactory.create(
model=context.model, model=context.model,
train_type=cfg.strategy, train_type=cfg.strategy,

View File

@ -3,7 +3,6 @@ from typing import List, Optional
from astrai.config import TrainConfig from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import ( from astrai.trainer.train_callback import (
CallbackFactory, CallbackFactory,
TrainCallback, TrainCallback,
@ -54,9 +53,9 @@ class Trainer:
if method: if method:
method(context) method(context)
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None): def _trainer_loop(self, resume_dir: Optional[str] = None):
context = ( context = (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build() TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
) )
executor = context.executor executor = context.executor
self._call_callbacks("on_train_begin", context) self._call_callbacks("on_train_begin", context)
@ -90,13 +89,13 @@ class Trainer:
self._call_callbacks("on_epoch_end", context) self._call_callbacks("on_epoch_end", context)
except Exception as e: except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True) logger.error("Training failed: %s", str(e), exc_info=True)
self._call_callbacks("on_error", context) self._call_callbacks("on_error", context)
raise raise
finally: finally:
self._call_callbacks("on_train_end", context) self._call_callbacks("on_train_end", context)
def train(self, checkpoint: Optional[Checkpoint] = None): def train(self, resume_dir: Optional[str] = None):
cfg = self.train_config cfg = self.train_config
spawn_parallel_fn( spawn_parallel_fn(
self._trainer_loop, self._trainer_loop,
@ -106,5 +105,5 @@ class Trainer:
master_port=cfg.master_port, master_port=cfg.master_port,
device_type=cfg.device_type, device_type=cfg.device_type,
start_method=cfg.start_method, start_method=cfg.start_method,
checkpoint=checkpoint, resume_dir=resume_dir,
) )

View File

@ -0,0 +1,279 @@
"""MMLU evaluation via log-likelihood ranking."""
import argparse
import csv
import json
import os
import shutil
import urllib.request
import zipfile
import torch
import torch.nn.functional as F
import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
MMLU_SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def _download_and_extract(url: str, data_dir: str):
zip_path = os.path.join(data_dir, "mmlu.zip")
os.makedirs(data_dir, exist_ok=True)
print(f"Downloading MMLU data from {url}...")
urllib.request.urlretrieve(url, zip_path)
print("Extracting...")
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(data_dir)
os.remove(zip_path)
def download_mmlu(data_dir: str):
_download_and_extract(MMLU_URL, data_dir)
src = os.path.join(data_dir, "test-master", "data")
if os.path.exists(src):
for item in os.listdir(src):
os.rename(os.path.join(src, item), os.path.join(data_dir, item))
shutil.rmtree(os.path.join(data_dir, "test-master"))
print(f"MMLU data saved to {data_dir}")
def _strip_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix) :].strip()
return text
def load_csv(path: str) -> list[dict]:
data = []
with open(path, "r", encoding="utf-8") as f:
for row in csv.reader(f):
if len(row) < 6:
continue
if row[0].strip().lower() == "question":
continue
data.append(
{
"question": row[0].strip(),
"A": _strip_prefix(row[1].strip(), "A)"),
"B": _strip_prefix(row[2].strip(), "B)"),
"C": _strip_prefix(row[3].strip(), "C)"),
"D": _strip_prefix(row[4].strip(), "D)"),
"answer": row[5].strip(),
}
)
return data
def build_prompt(
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
prompt = ""
if n_shot > 0 and dev_data:
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
for item in dev_data[:n_shot]:
prompt += f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {item[k]}\n"
prompt += f"Answer: {item['answer']}\n\n"
prompt += f"Question: {question}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {choices[k]}\n"
prompt += "Answer:"
return prompt
def choice_logprob(
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
choice_text = f" {choice_letter}"
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
input_ids = context_ids + choice_ids
max_len = model.config.max_len
if len(input_ids) > max_len:
overflow = len(input_ids) - max_len
input_ids = input_ids[overflow:]
ctx_len = len(input_ids) - len(choice_ids)
else:
ctx_len = len(context_ids)
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits = model(input_tensor)["logits"][0]
score = 0.0
for i, tid in enumerate(choice_ids):
pos = ctx_len - 1 + i
if pos >= len(logits):
break
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
return score
def evaluate_subject(
model,
tokenizer,
subject: str,
test_data: list[dict],
dev_data: list[dict] | None,
device: str,
n_shot: int,
) -> tuple[float, int, int]:
correct = 0
total = 0
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or [])
context_ids = tokenizer.encode(prompt)
scores = {
c: choice_logprob(model, tokenizer, context_ids, c, device)
for c in ("A", "B", "C", "D")
}
if max(scores, key=scores.get) == item["answer"]:
correct += 1
total += 1
return correct / total, correct, total
def main():
parser = argparse.ArgumentParser(description="MMLU evaluation")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
)
parser.add_argument("--download", action="store_true", help="Download MMLU data")
parser.add_argument(
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
)
parser.add_argument(
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
)
parser.add_argument("--output", type=str, help="Output JSON path")
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16" if torch.cuda.is_available() else "float32",
help="Torch dtype",
)
args = parser.parse_args()
if args.download or not os.path.exists(args.data_dir):
download_mmlu(args.data_dir)
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
device = args.device
dtype = getattr(torch, args.dtype)
model.to(device=device, dtype=dtype)
subjects = args.subjects or MMLU_SUBJECTS
results = {}
total_correct = 0
total_questions = 0
for subject in subjects:
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
test_path = os.path.join(
args.data_dir, args.split, f"{subject}_{args.split}.csv"
)
if not os.path.exists(test_path):
print(f" Skipping {subject}: test file not found")
continue
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
test_data = load_csv(test_path)
acc, corr, tot = evaluate_subject(
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
)
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
total_correct += corr
total_questions += tot
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
overall = total_correct / total_questions if total_questions else 0
print(f"\n{'=' * 70}")
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
results["_overall"] = {
"accuracy": round(overall, 4),
"correct": total_correct,
"total": total_questions,
}
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {args.output}")
if __name__ == "__main__":
main()

View File

@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
def process_file( def process_file(
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
): ):
# Load model and tokenizer # Load model and tokenizer
model = AutoModel.from_pretrained(model_dir) model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(model_dir) tokenizer = AutoTokenizer.from_pretrained(param_path)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f: with open(input_file, "r", encoding="utf-8") as f:
@ -44,8 +44,8 @@ def process_file(
for seq in batch_encoded: for seq in batch_encoded:
pad_len = max_len - len(seq) pad_len = max_len - len(seq)
padded_seq = [tokenizer.pad_id] * pad_len + seq padded_seq = seq + [tokenizer.pad_id] * pad_len
mask = [False] * pad_len + [True] * len(seq) mask = [True] * len(seq) + [False] * pad_len
padded_ids.append(padded_seq) padded_ids.append(padded_seq)
masks.append(mask) masks.append(mask)
@ -88,7 +88,7 @@ def process_file(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.") parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument( parser.add_argument(
"--model_dir", type=str, required=True, help="Path to the model directory." "--param_path", type=str, required=True, help="Path to the model directory."
) )
parser.add_argument( parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file." "--input_file", type=str, required=True, help="Path to the input file."

View File

@ -18,7 +18,7 @@ def main():
"--reload", action="store_true", help="Enable auto-reload for development" "--reload", action="store_true", help="Enable auto-reload for development"
) )
parser.add_argument( parser.add_argument(
"--param-path", "--param_path",
type=Path, type=Path,
default=None, default=None,
help="Path to model parameters (default: project_root/params)", help="Path to model parameters (default: project_root/params)",

View File

@ -8,7 +8,6 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM from astrai.model import AutoRegressiveLM
from astrai.serialization import Checkpoint
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
@ -147,8 +146,8 @@ def parse_args() -> argparse.Namespace:
"--parallel_mode", "--parallel_mode",
type=str, type=str,
default="none", default="none",
choices=["none", "ddp"], choices=["none", "ddp", "fsdp"],
help="Parallel training strategy.", help="Parallel training strategy (none, ddp, fsdp).",
) )
parser.add_argument( parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use." "--device_type", type=str, default="cuda", help="Device type to use."
@ -166,6 +165,10 @@ def parse_args() -> argparse.Namespace:
return args return args
def create_model(config):
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
def create_optimizer(model, **kwargs) -> optim.Optimizer: def create_optimizer(model, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs) return optim.AdamW(model.parameters(), fused=True, **kwargs)
@ -228,6 +231,8 @@ def train(
): ):
assert train_type in ["seq", "sft", "dpo", "grpo"] assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config # Load config
config_path = os.path.join(param_path, "config.json") config_path = os.path.join(param_path, "config.json")
@ -236,15 +241,6 @@ def train(
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
checkpoint = Checkpoint.load(param_path)
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
model.load_state_dict(checkpoint.state_dict, strict=False)
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
# (model weights already loaded into model above)
checkpoint.state_dict = {}
strategy_kwargs = { strategy_kwargs = {
"beta": dpo_beta, "beta": dpo_beta,
"label_smoothing": label_smoothing, "label_smoothing": label_smoothing,
@ -259,6 +255,7 @@ def train(
"broadcast_buffers": False, "broadcast_buffers": False,
} }
model_fn = partial(create_model, config)
dataset = DatasetFactory.load( dataset = DatasetFactory.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
@ -290,7 +287,7 @@ def train(
) )
train_config = TrainConfig( train_config = TrainConfig(
model=model, model_fn=model_fn,
strategy=train_type, strategy=train_type,
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
@ -315,7 +312,7 @@ def train(
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
trainer.train(checkpoint=checkpoint) trainer.train(resume_dir=param_path)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -7,10 +7,8 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
BaseSegmentFetcher, H5Store,
H5Storage, StoreFactory,
MultiSegmentFetcher,
StorageFactory,
detect_format, detect_format,
load_json, load_json,
save_h5, save_h5,
@ -318,37 +316,48 @@ def test_unloaded_dataset_len():
assert len(dataset) == 0 assert len(dataset) == 0
def test_base_segment_fetcher_empty(): def test_store_unloaded_len():
"""BaseSegmentFetcher with empty segments list""" """Unloaded Store has __len__ == 0"""
fetcher = BaseSegmentFetcher([]) store = H5Store()
assert len(fetcher) == 0 assert len(store) == 0
with pytest.raises(ValueError, match="out of bounds"): assert store.keys == []
fetcher.fetch_data(0, 1)
def test_base_segment_fetcher_begin_equals_end(base_test_env): def test_store_fetch_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor""" """Store.fetch with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]} dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy) save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32) dataset = DatasetFactory.load("seq", test_dir, window_size=32)
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"] result = dataset.storage.fetch(10, 10, "sequence")
result = fetcher.fetch_data(10, 10)
assert result.numel() == 0 assert result.numel() == 0
def test_multi_segment_fetcher_empty_dict(): def test_store_empty_data_len(base_test_env):
"""MultiSegmentFetcher with empty dict has __len__ == 0""" """Store loaded with empty data has __len__ == 0"""
fetcher = MultiSegmentFetcher({}) import os
assert len(fetcher) == 0
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "empty_store")
os.makedirs(data_dir, exist_ok=True)
with open(os.path.join(data_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3]]}, f)
store = StoreFactory.create("json")
store.load(data_dir)
assert len(store) > 0
empty_store = H5Store()
assert len(empty_store) == 0
def test_storage_fetch_before_load(): def test_store_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError""" """Store.fetch before load raises RuntimeError"""
storage = H5Storage() store = H5Store()
with pytest.raises(RuntimeError, match="not loaded"): with pytest.raises(RuntimeError, match="not loaded"):
storage.fetch(0, 10, "sequence") store.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path(): def test_detect_format_nonexistent_path():
@ -367,10 +376,10 @@ def test_detect_format_unsupported_file(base_test_env):
detect_format(path) detect_format(path)
def test_create_storage_invalid_type(): def test_create_store_invalid_type():
"""StorageFactory.create raises ValueError for unknown type""" """StoreFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"): with pytest.raises(ValueError, match="Unknown component"):
StorageFactory.create("parquet") StoreFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env): def test_json_pretokenized_without_tokenizer(base_test_env):
@ -407,14 +416,23 @@ def test_load_json_skips_config_file(base_test_env):
assert len(result["sequence"]) == 1 assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment(): def test_store_multi_segment_concat(base_test_env):
"""fetch_data across multiple segment boundaries""" """Multi-segment H5 data is concatenated into single tensor at load time"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "multi_seg")
os.makedirs(data_dir, exist_ok=True)
segs = [ segs = [
torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]), torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]), torch.tensor([8, 9]),
] ]
fetcher = BaseSegmentFetcher(segs) save_h5(data_dir, "data", {"sequence": segs})
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7) store = StoreFactory.create("h5")
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7] assert result.tolist() == [3, 4, 5, 6, 7]

View File

@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
def create_train_config( def create_train_config(
model: torch.nn.Module, model_fn,
dataset: Dataset, dataset: Dataset,
test_dir: str, test_dir: str,
device: str, device: str,
@ -43,7 +43,7 @@ def create_train_config(
"""Factory function to create common TrainConfig for tests. """Factory function to create common TrainConfig for tests.
Args: Args:
model: The model to train model_fn: Model factory (callable returning nn.Module)
dataset: Training dataset dataset: Training dataset
test_dir: Checkpoint directory test_dir: Checkpoint directory
device: Device type ("cuda" or "cpu") device: Device type ("cuda" or "cpu")
@ -70,7 +70,7 @@ def create_train_config(
return TrainConfig( return TrainConfig(
strategy=strategy, strategy=strategy,
model=model, model_fn=model_fn,
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,

View File

@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
) )
train_config = TrainConfig( train_config = TrainConfig(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
strategy="seq", strategy="seq",
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
) )
train_config = TrainConfig( train_config = TrainConfig(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
strategy="seq", strategy="seq",
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,

View File

@ -4,7 +4,6 @@ import numpy as np
import torch import torch
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
from astrai.serialization import Checkpoint
from astrai.trainer.schedule import SchedulerFactory from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer from astrai.trainer.trainer import Trainer
@ -24,7 +23,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
strategy="seq", strategy="seq",
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
dataset=early_stopping_dataset, dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"), log_dir=os.path.join(base_test_env["test_dir"], "logs"),
@ -39,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
trainer = Trainer(train_config) trainer = Trainer(train_config)
# Should handle early stopping gracefully # Should handle early stopping gracefully
checkpoint = None
try: try:
checkpoint = trainer.train() trainer.train()
except Exception: except Exception:
# Handle any exceptions
pass pass
# Resume from latest checkpoint
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2") load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
checkpoint = Checkpoint.load(load_dir) trainer = Trainer(train_config)
trainer.train(checkpoint) trainer.train(resume_dir=load_dir)
# Verify checkpoint was saved at expected iteration
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10") load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
checkpoint = Checkpoint.load(load_dir) import json
assert checkpoint.iteration == 10
with open(os.path.join(load_dir, "meta.json")) as f:
meta = json.load(f)
assert meta["iteration"] == 10

View File

@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
for batch_per_device in batch_sizes: for batch_per_device in batch_sizes:
train_config = train_config_factory( train_config = train_config_factory(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
for grad_accum_steps in grad_accum_steps_list: for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory( train_config = train_config_factory(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
for config in small_batch_configs: for config in small_batch_configs:
train_config = train_config_factory( train_config = train_config_factory(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],