Compare commits

...

9 Commits

Author SHA1 Message Date
ViperEkura 0a708fff24 docs : 更新架构文档与 storage 注释,同步 Store 重构
- architecture.md: 类图/关系线全部更新 (BaseStorage→Store, StorageFactory→StoreFactory, 新增 MmapStore)
- architecture.md: 移除 BaseSegmentFetcher/MultiSegmentFetcher 类图与关系
- dataflow.md: 管线加入 .bin 格式, Store._data + _cum 架构
- storage.py: module docstring 改用缩进式注释风格
2026-05-28 14:36:18 +08:00
ViperEkura 6e150ea6d0 refactor : Storage 层重构为 Store,移除 Fetcher 中间层,支持多段数据与显式长度
- 合并 BaseStorage + MultiSegmentFetcher + BaseSegmentFetcher 三层为 Store ABC
- Store._data 直接持有 Dict[str, List[Tensor]],不做强制拼接避免 OOM
- _fetch_key 统一用 bisect 跨段切片,单段多段同一路径
- _length 显式存储(min total across keys),__len__ 返回 O(1)
- MmapStore/H5Store/JSONStore 统一走 _normalize() 注册分段并预计算累积长度
- 所有 I/O 函数 (save_h5/load_h5/json_to_bin 等) 保持不变
2026-05-28 14:23:49 +08:00
ViperEkura cb8dcb97ea refactor : 移除 -> None 返回值标注,拆分 FSDP 参数,新增 mmap 数据集存储
- 删除所有 def 函数 -> None 返回值类型标注
- FSDPExecutor 参数从 **kwargs 拆为显式声明,None 值自动过滤
- 新增 MmapStorage (bin) 存储后端,基于 numpy.memmap 零拷贝加载
- 新增 save_bin/load_bin/json_to_bin 工具函数
- detect_format 支持 bin 格式自动检测
2026-05-28 13:57:06 +08:00
ViperEkura 2d5dc93b3d fix : 修正类型标注与统一 CLI 参数命名
- AutoRegressiveLM.forward 返回类型标注 -> Dict[str, Tensor]
- EmbeddingEncoder 移除冗余 position_ids 自动创建
- CLI 脚本模型目录参数统一为 --param_path
2026-05-27 20:49:44 +08:00
ViperEkura 4145d35e3c refactor: 检查点加载重构,路径替代对象传递
- model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串
- Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递
- TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch
- CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁
- serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast
- optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行
- pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物
2026-05-27 20:15:29 +08:00
ViperEkura 34c6c45bd6 feat: 初步实现 MMLU 评测脚本
- 支持 few-shot (log-likelihood ranking) 与 zero-shot
- 自动下载 Hendrycks MMLU 数据集
- --device / --dtype 可配置,默认 GPU bf16
2026-05-26 20:23:31 +08:00
ViperEkura e9def84ce7 fix : perplexity.py left padding 导致 batch>1 时 PPL 计算错误 2026-05-26 19:59:57 +08:00
ViperEkura 836e02a166 docs: 同步 architecture/inference/training 文档至实际代码,CLI 补充 fsdp 选项
- 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式
- 修正训练循环 scheduler.step() 在 sync_gradients 块内
- 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联
- --parallel_mode CLI choices 加入 fsdp
- nprocs > 1 且 parallel_mode=none 时 raise error
2026-05-26 19:37:00 +08:00
ViperEkura b558e61f63 refactor: 简化 _disable_random_init,scheduler 移入同步块
- _disable_random_init: enable=False 提前返回,dict 推导替代空字典
- scheduler.step() 移入 sync_gradients 守卫内
2026-05-26 17:05:25 +08:00
34 changed files with 885 additions and 475 deletions

View File

@ -22,7 +22,8 @@ classDiagram
+int n_layers
+float norm_eps
+int dim_ffn
+bool tie_weight
+Optional[bool] tie_weight
+Optional[dict] rope_scaling
+int max_len
+float rope_theta
+str attn_type
@ -52,6 +53,7 @@ classDiagram
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[dict] rope_scaling
+Optional[str] pooling_type
+Optional[bool] normalize_embeddings
}
@ -80,6 +82,7 @@ classDiagram
+str log_dir
+int log_interval
+List[str] metrics
+Optional[LoRAConfig] lora
+int random_seed
+int num_workers
+Optional[int] prefetch_factor
@ -104,7 +107,7 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+Optional[BaseStorage] storage
+Optional[Store] storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
@ -126,38 +129,29 @@ classDiagram
+__getitem__(index) Dict
}
class BaseSegmentFetcher {
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
class Store {
+Dict[str, List[Tensor]] _data
+Dict[str, List[int]] _cum
+int _length
+keys (property)
+load(load_path, tokenizer)
+load(path, tokenizer)
+fetch(begin, end, keys)
+__len__()
-_fetch_key(key, begin, end) Tensor
-_normalize(raw)
}
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
class H5Store {
+load(path, tokenizer)
}
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
class JSONStore {
+load(path, tokenizer)
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
class MmapStore {
+List _mmap_refs
+load(path, tokenizer)
}
class ResumableDistributedSampler {
@ -165,10 +159,10 @@ classDiagram
+int iter
}
class StorageFactory {
class StoreFactory {
+Registry _registry
+register(name) decorator
+create(storage_type) BaseStorage
+create(storage_type) Store
}
class DatasetFactory {
@ -457,16 +451,15 @@ classDiagram
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_optimizer_step(context)
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_step_begin(context)
+on_optimizer_step(context)
}
class GradientCheckpointingCallback {
@ -512,7 +505,7 @@ classDiagram
class ValidationCallback {
+_run_validation(context)
+on_step_end(context)
+on_optimizer_step(context)
}
class CallbackFactory {
@ -747,56 +740,58 @@ classDiagram
+str model
+List[AnthropicMessage] messages
+Optional[str] system
+float temperature
+float top_p
+int top_k
+Optional[float] temperature
+Optional[float] top_p
+Optional[int] top_k
+int max_tokens
+bool stream
+Optional[bool] stream
+Optional[List[str]] stop_sequences
}
class ProtocolHandler {
class ResponseBuilder {
<<abstract>>
+prepare(request, engine) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class OpenAIResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class AnthropicResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class ProtocolHandler {
+request
+engine
+build_prompt() str
+create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+builder: ResponseBuilder
+handle() Union[StreamingResponse, Dict]
}
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
-_handle_stream(agen, ctx, stops) StreamingResponse
-_handle_non_stream(agen, ctx, stops) Dict
}
class StopChecker {
+has_sequences (property) bool
+check(text) Optional[str]
+trim(text, matched) str
}
class StreamContext {
class GenContext {
+str resp_id
+int created
+str model
+int prompt_tokens
+int completion_tokens
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
}
class app {
@ -876,6 +871,11 @@ classDiagram
+unwrap_model(model) nn.Module
}
class FSDPExecutor {
+_prepare_model(model) nn.Module
+unwrap_model(model) nn.Module
}
class ExecutorFactory {
+Registry _registry
+register(name) decorator
@ -911,12 +911,14 @@ classDiagram
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- ValidationCallback
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
Store <|-- H5Store
Store <|-- JSONStore
Store <|-- MmapStore
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
@ -936,20 +938,19 @@ classDiagram
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- StoreFactory
BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory
BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
BaseExecutor <|-- FSDPExecutor
ResponseBuilder <|-- OpenAIResponseBuilder
ResponseBuilder <|-- AnthropicResponseBuilder
%% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
@ -963,7 +964,6 @@ classDiagram
DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState
@ -971,6 +971,9 @@ classDiagram
%% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig
AutoTokenizer o-- ChatTemplate
PagePool o-- Allocator
PagePool o-- PrefixCache
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
@ -978,7 +981,7 @@ classDiagram
TrainContext o-- BaseExecutor
KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage
BaseDataset o-- Store
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects
@ -992,12 +995,14 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
StorageFactory ..> H5Storage : creates
StorageFactory ..> JSONStorage : creates
StoreFactory ..> H5Store : creates
StoreFactory ..> JSONStore : creates
StoreFactory ..> MmapStore : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates
ExecutorFactory ..> FSDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates
@ -1009,10 +1014,10 @@ classDiagram
KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates
OpenAIHandler ..> ChatCompletionRequest : receives
AnthropicHandler ..> MessagesRequest : receives
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
AnthropicResponseBuilder ..> MessagesRequest : receives
ProtocolHandler ..> StopChecker : creates
ProtocolHandler ..> StreamContext : creates
ProtocolHandler ..> GenContext : creates
%% --- Association (general usage) ---
Trainer --> TrainConfig
@ -1025,8 +1030,6 @@ classDiagram
Executor --> AutoModel
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
```
@ -1036,13 +1039,13 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, StopChecker, StreamContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
@ -1050,17 +1053,17 @@ classDiagram
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
| **Builder** | `TrainContextBuilder` | Chain-building training context |
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1069,10 +1072,10 @@ classDiagram
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed)
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops

View File

@ -5,21 +5,22 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview
```
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
```
## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
StorageFactory.create("h5") → H5Storage
StorageFactory.create("json") → JSONStorage
StoreFactory.create("h5") → H5Store
StoreFactory.create("json") → JSONStore
StoreFactory.create("bin") → MmapStore
```
Both support shared memory via `.share_memory_()`.
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
## Data Keys by Training Type
@ -33,14 +34,14 @@ Both support shared memory via `.share_memory_()`.
## Dataset Architecture
```
DatasetFactory.load(train_type, path, window_size, stride)
→ StorageFactory.create(detect_format(path))
MultiSegmentFetcher(BaseSegmentFetcher per key)
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
→ StoreFactory.create(detect_format(path))
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx)
```
`window_size` = max input length, `stride` = step between consecutive samples.
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
## Sampler

View File

@ -46,20 +46,22 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Template Method)
## Protocol Handlers (Strategy Pattern)
```python
class ProtocolHandler(ABC):
def handle(self):
ctx = StreamContext(...)
class ProtocolHandler: # concrete orchestrator
def handle(self, request):
prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx)
else: self._handle_non_stream(agen, ctx)
if stream: self._handle_stream(agen, ctx, stops)
else: self._handle_non_stream(agen, ctx, stops)
```
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`.
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
## Engine & GenerateResult
@ -116,7 +118,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) |
| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) |
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length |

View File

@ -53,7 +53,7 @@
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
| `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |

View File

@ -82,8 +82,7 @@ on_train_begin
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
scheduler.step() # called every iteration
scheduler.step()
on_epoch_end
on_train_end
```
@ -190,7 +189,7 @@ context = (
```
- Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)`

View File

@ -17,8 +17,8 @@ def required(**kw):
@dataclass
class TrainConfig(BaseConfig):
# basic setting
model: nn.Module = field(
default=None, metadata=required(help="Model for training.")
model_fn: Callable[[], nn.Module] = field(
default=None, metadata=required(help="Model factory for training.")
)
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(

View File

@ -4,15 +4,17 @@ from astrai.dataset.dataset import (
)
from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
BaseSegmentFetcher,
BaseStorage,
H5Storage,
JSONStorage,
MultiSegmentFetcher,
StorageFactory,
H5Store,
JSONStore,
MmapStore,
Store,
StoreFactory,
detect_format,
json_to_bin,
load_bin,
load_h5,
load_json,
save_bin,
save_h5,
save_json,
)
@ -20,16 +22,18 @@ from astrai.dataset.storage import (
__all__ = [
"BaseDataset",
"DatasetFactory",
"BaseSegmentFetcher",
"MultiSegmentFetcher",
"BaseStorage",
"H5Storage",
"JSONStorage",
"StorageFactory",
"Store",
"StoreFactory",
"H5Store",
"JSONStore",
"MmapStore",
"detect_format",
"save_h5",
"load_h5",
"save_json",
"load_json",
"save_bin",
"load_bin",
"json_to_bin",
"ResumableDistributedSampler",
]

View File

@ -8,8 +8,8 @@ from torch import Tensor
from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
StorageFactory,
Store,
StoreFactory,
detect_format,
)
from astrai.factory import BaseFactory
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
super().__init__()
self.window_size = window_size
self.stride = stride
self.storage: Optional[BaseStorage] = None
self.storage: Optional[Store] = None
@property
def required_keys(self) -> List[str]:
@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC):
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = StorageFactory.create(storage_type)
self.storage = StoreFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys()
@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
"""
@classmethod
def _validate_component(cls, dataset_cls: type) -> None:
def _validate_component(cls, dataset_cls: type):
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")

View File

@ -1,7 +1,20 @@
"""Storage backends for different data formats.
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
a uniform interface for data access and length observation via fetchers.
Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
return Dict[str, List[Tensor]] — format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments — no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- Multi-segment: segments kept as-is, no forced concatenation — safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
"""
import bisect
@ -12,6 +25,7 @@ from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import h5py
import numpy as np
import torch
from torch import Tensor
@ -104,6 +118,38 @@ def load_json(
return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    """Dump a tensor group into a directory of raw ``<key>.bin`` files.

    The directory ``file_path`` is created if it does not exist. For every
    key, the segments are concatenated along dim 0, moved to CPU and written
    as raw bytes to ``<file_path>/<key>.bin``. The concatenated shape and the
    dtype name (the part of ``str(tensor.dtype)`` after the last dot) are
    collected per key and handed to ``save_json`` together with the
    ``meta.json`` path.

    Args:
        file_path: Target directory for the ``.bin`` files and ``meta.json``.
        tensor_group: Mapping from key to the list of tensor segments.
    """
    os.makedirs(file_path, exist_ok=True)

    meta = {}
    for key, tensors in tensor_group.items():
        merged = torch.cat(tensors, dim=0)
        dtype_name = str(merged.dtype).split(".")[-1]
        meta[key] = {"shape": list(merged.shape), "dtype": dtype_name}

        bin_file = os.path.join(file_path, f"{key}.bin")
        merged.cpu().numpy().tofile(bin_file)

    # NOTE(review): save_json is called as (object, path) here — confirm this
    # matches its signature, which is defined elsewhere in this module.
    meta_file = os.path.join(file_path, "meta.json")
    save_json(meta, meta_file)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Open a directory written in the ``.bin`` + ``meta.json`` layout.

    ``meta.json`` is read through ``load_json``; each of its entries is
    expected to provide ``"dtype"`` and ``"shape"``. For every key the file
    ``<file_path>/<key>.bin`` is opened as a read-only ``np.memmap`` with that
    dtype and shape and wrapped with ``torch.from_numpy``.

    Args:
        file_path: Directory containing ``meta.json`` and the ``.bin`` files.

    Returns:
        Mapping from key to a single-element list holding the mapped tensor.
    """
    # NOTE(review): presumably load_json returns the plain metadata mapping
    # for this file — confirm, since it is defined elsewhere in this module.
    meta = load_json(os.path.join(file_path, "meta.json"))

    def _map_key(key, info):
        # mode="r" opens the file read-only; the tensor shares the mapping.
        mapped = np.memmap(
            os.path.join(file_path, f"{key}.bin"),
            dtype=info["dtype"],
            mode="r",
            shape=tuple(info["shape"]),
        )
        return torch.from_numpy(mapped)

    return {key: [_map_key(key, info)] for key, info in meta.items()}
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert a JSON dataset into the ``.bin`` + ``meta.json`` layout.

    The data is loaded with ``load_json`` (shared memory disabled) and passed
    straight to ``save_bin``.

    Args:
        json_path: Source path handed to ``load_json``.
        bin_path: Target directory handed to ``save_bin``.
        tokenizer: Forwarded unchanged to ``load_json``.
    """
    segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    # save_bin already concatenates every key's segments along dim 0, so
    # merging them here first would only create an extra full copy of the
    # dataset in memory without changing the files written.
    save_bin(bin_path, segments)
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
@ -111,7 +157,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path
Returns:
Format string ("h5" or "json")
Format string ("h5", "bin", or "json")
Raises:
FileNotFoundError: If no supported data files are found
@ -128,166 +174,118 @@ def detect_format(load_path: str) -> str:
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files:
return "h5"
bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists():
return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
class Store(ABC):
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
Each key maps to one or more tensor segments (no forced concatenation).
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
total element count across all keys.
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
via ``_normalize()``.
"""
def __init__(self):
self._fetcher: Optional[MultiSegmentFetcher] = None
self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
@abstractmethod
def load(self, load_path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
def load(self, path: str, tokenizer=None) -> None:
raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property
def keys(self) -> List[str]:
"""Return the data keys available in this storage."""
if self._fetcher is None:
return []
return self._fetcher.multi_keys
return list(self._data.keys())
def __len__(self) -> int:
return self._length
def fetch(
    self,
    begin: int,
    end: int,
    keys: Union[str, List[str]],
):
    """Return the range ``[begin, end)`` for one key or for several keys.

    Args:
        begin: Start index; must satisfy ``0 <= begin < len(self)``.
        end: End index; must satisfy ``0 <= end <= len(self)``.
        keys: A single key name, or an iterable of key names.

    Returns:
        The result of ``_fetch_key`` when ``keys`` is a string, otherwise a
        dict mapping every requested key to its ``_fetch_key`` result.

    Raises:
        RuntimeError: If no data has been registered yet.
        ValueError: If ``begin`` or ``end`` is outside the allowed range.
    """
    if not self._data:
        raise RuntimeError("Store not loaded")

    limit = self._length
    begin_ok = 0 <= begin < limit
    end_ok = 0 <= end <= limit
    if not (begin_ok and end_ok):
        raise ValueError(
            f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
        )

    # A bare string selects one key and yields its value directly.
    if isinstance(keys, str):
        return self._fetch_key(keys, begin, end)

    fetched = {}
    for name in keys:
        fetched[name] = self._fetch_key(name, begin, end)
    return fetched
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
    """Fetch slice ``[begin, end)`` across potentially multiple segments.

    ``self._cum[key]`` holds the running total length after each segment of
    ``self._data[key]``; it is bisected to find the first and last segment
    touched by the range, and the partial slices are concatenated.

    Args:
        key: Name of the data key to read.
        begin: Global start index (inclusive).
        end: Global stop index (exclusive).

    Returns:
        The requested elements. A single-segment hit returns the slice of
        that segment directly; a multi-segment hit is concatenated with
        ``torch.cat``. An empty or inverted range yields a zero-length slice.
    """
    segments = self._data[key]
    cum = self._cum[key]

    if begin >= end and segments:
        # An empty range that sits exactly on a segment boundary selects no
        # segment at all, which used to reach ``torch.cat([])`` and raise.
        # Return a zero-length slice so every empty range behaves the same.
        return segments[0][:0]

    seg_start = bisect.bisect_right(cum, begin)
    seg_end = bisect.bisect_left(cum, end)

    pieces = []
    for i in range(seg_start, seg_end + 1):
        # Global offset at which segment ``i`` starts.
        prev = cum[i - 1] if i > 0 else 0
        local_begin = max(begin - prev, 0)
        local_end = min(end - prev, segments[i].shape[0])
        pieces.append(segments[i][local_begin:local_end])
    return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
    """Register segments and pre-compute cumulative lengths.

    Does NOT concatenate: segments are kept as-is to avoid OOM on large
    datasets. Sets ``self._length`` to the minimum total element count
    across all keys.

    Args:
        raw: Mapping from key to its list of segment tensors. The first
            dimension of each tensor is counted as its length.
    """
    for key, tensors in raw.items():
        self._data[key] = tensors
        running = 0
        cum = []
        for segment in tensors:
            running += segment.shape[0]
            cum.append(running)
        self._cum[key] = cum

    # A key registered with no segments has an empty cumulative list; treat
    # its total as 0 instead of failing on ``cum[-1]``.
    totals = [cum[-1] if cum else 0 for cum in self._cum.values()]
    self._length = min(totals) if totals else 0
class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating storage backends by type name.
class StoreFactory(BaseFactory["Store"]):
"""Factory for creating Store instances by type name.
Example:
@StorageFactory.register("custom")
class CustomStorage(BaseStorage):
Example::
@StoreFactory.register("custom")
class CustomStore(Store):
...
storage = StorageFactory.create("custom")
"""
@classmethod
def _validate_component(cls, storage_cls: type) -> None:
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
def _validate_component(cls, store_cls: type):
if not issubclass(store_cls, Store):
raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StorageFactory.register("h5")
class H5Storage(BaseStorage):
@StoreFactory.register("h5")
class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
def load(self, path: str, tokenizer=None):
self._normalize(load_h5(path))
@StorageFactory.register("json")
class JSONStorage(BaseStorage):
@StoreFactory.register("json")
class JSONStore(Store):
"""JSON-based storage backend.
Supports two modes:
@ -296,6 +294,28 @@ class JSONStorage(BaseStorage):
callable (str -> List[int]) at load time.
"""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
def load(self, path: str, tokenizer=None):
self._normalize(load_json(path, tokenizer=tokenizer))
@StoreFactory.register("bin")
class MmapStore(Store):
    """Memory-mapped binary storage backend.

    Each key is a single .bin file backed by ``np.memmap(mode="r")``.
    No per-process memory duplication: all DataLoader workers share the
    same OS page-cache pages.

    Format on disk::

        data_root/
            meta.json   # {key: {shape, dtype}, ...}
            <key>.bin   # raw numpy array, one per key
    """

    def load(self, path: str, tokenizer=None):
        """Load the dataset at ``path`` via ``load_bin`` and register it.

        ``tokenizer`` is accepted for signature compatibility and is not
        used here.
        """
        # Extra list of references to every registered tensor; presumably
        # kept so the underlying mappings stay alive (confirm with load_bin).
        self._mmap_refs = []
        self._normalize(load_bin(path))
        self._mmap_refs.extend(
            tensor for segments in self._data.values() for tensor in segments
        )

View File

@ -23,7 +23,7 @@ class Registry:
component_cls: Type,
category: Optional[str] = None,
priority: int = 0,
) -> None:
):
"""Register a component class with optional category and priority."""
if name in self._entries:
raise ValueError(f"Component '{name}' is already registered")
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
return component_cls(*args, **kwargs)
@classmethod
def _validate_component(cls, component_cls: Type[T]) -> None:
def _validate_component(cls, component_cls: Type[T]):
"""Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation.

View File

@ -42,7 +42,7 @@ class Allocator:
return idx
return -1
def free(self, idx: int, keep_cached: bool = False) -> None:
def free(self, idx: int, keep_cached: bool = False):
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
@ -51,7 +51,7 @@ class Allocator:
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None:
def inc_ref(self, idx: int):
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
@ -60,7 +60,7 @@ class Allocator:
with self._lock:
return self._refs[idx]
def touch(self, idx: int) -> None:
def touch(self, idx: int):
with self._lock:
self._lru.move_to_end(idx)
@ -74,7 +74,7 @@ class PrefixCache:
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int) -> None:
def evict(self, idx: int):
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
@ -96,9 +96,7 @@ class PrefixCache:
hits.append(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
@ -127,13 +125,13 @@ class PagePool:
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int) -> None:
def free(self, idx: int):
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int) -> None:
def inc_ref(self, idx: int):
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
@ -142,9 +140,7 @@ class PagePool:
self._alloc.touch(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self._prefix.record(page_idx, token_ids, logical_page_idx)
@ -157,7 +153,7 @@ class TaskTable:
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
def set(self, task_id: str, page_table: List[int], cached: int):
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
@ -220,7 +216,7 @@ class Storage:
start_pos: int,
k: Tensor,
v: Tensor,
) -> None:
):
seq_len = k.size(1)
if seq_len == 0:
return
@ -286,7 +282,7 @@ class KvcacheView:
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
def write(self, layer_id: int, k: Tensor, v: Tensor):
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
@ -339,7 +335,7 @@ class KVCache:
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
def task_free(self, task_id: str):
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
@ -359,7 +355,7 @@ class KVCache:
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
):
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):

View File

@ -29,9 +29,7 @@ class Executor:
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
if start_pos >= prompt_len:
return

View File

@ -75,14 +75,14 @@ class InferenceScheduler:
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str) -> None:
def remove_task(self, task_id: str):
for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None:
def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
@ -186,14 +186,14 @@ class InferenceScheduler:
self._task_mgr.clear_queues()
raise
def start(self) -> None:
def start(self):
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self) -> None:
def stop(self):
self._running = False
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):

View File

@ -172,12 +172,12 @@ class TaskManager:
to_add.append(self.waiting_queue.popleft())
return to_add
def activate(self, task: Task) -> None:
def activate(self, task: Task):
task.status = TaskStatus.RUNNING
with self._lock:
self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]) -> None:
def return_to_waiting(self, tasks: List[Task]):
with self._lock:
for task in reversed(tasks):
self.waiting_queue.appendleft(task)
@ -185,7 +185,7 @@ class TaskManager:
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0) -> None:
def wait_for_tasks(self, timeout: float = 1.0):
self._task_event.clear()
self._task_event.wait(timeout=timeout)
@ -197,10 +197,10 @@ class TaskManager:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self) -> None:
def clear_queues(self):
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self) -> None:
def wake(self):
self._task_event.set()

View File

@ -48,7 +48,7 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0) -> None:
def wait_completion(self, timeout: float = 300.0):
with self._cond:
if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout
@ -281,7 +281,7 @@ class InferenceEngine:
def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats()
def shutdown(self) -> None:
def shutdown(self):
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -15,7 +15,11 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod
@contextmanager
def _disable_random_init(enable: bool = True):
init_functions = [
if not enable:
yield
return
names = (
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True):
"constant_",
"normal_",
"uniform_",
]
original_funcs = {}
for name in init_functions:
if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
)
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
for n in orig:
setattr(nn.init, n, lambda *a, **kw: None)
try:
yield
finally:
if enable:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
for n, fn in orig.items():
setattr(nn.init, n, fn)
class AutoModel(BaseFactory["AutoModel"], nn.Module):
@ -82,7 +83,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
def save_pretrained(
self,
save_directory: Union[str, Path],
) -> None:
):
save_model(
config=self.config.to_dict(),
state_dict=self.state_dict(),

View File

@ -68,9 +68,6 @@ class EmbeddingEncoder(AutoModel):
x = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

View File

@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional
from typing import Any, Dict, Mapping, Optional
import torch
import torch.nn as nn
@ -136,7 +136,7 @@ class AutoRegressiveLM(AutoModel):
input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None,
) -> Tensor:
) -> Dict[str, Tensor]:
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)

View File

@ -203,9 +203,45 @@ class DDPExecutor(BaseExecutor):
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
def __init__(
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = fsdp_kwargs
self._fsdp_kwargs = {
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module:

View File

@ -1,8 +1,9 @@
import io
import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, List, Tuple
import safetensors.torch as st
import torch
@ -11,11 +12,11 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank
_META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors"
_MODEL_CONFIG_FILE = "config.json"
def save_safetensors(state_dict: dict, path: str | Path) -> None:
def save_safetensors(state_dict: dict, path: str | Path):
st.save_file(state_dict, str(path))
@ -23,7 +24,7 @@ def load_safetensors(path: str | Path) -> dict:
return st.load_file(str(path))
def save_json(data: dict, path: str | Path) -> None:
def save_json(data: dict, path: str | Path):
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
@ -33,13 +34,92 @@ def load_json(path: str | Path) -> dict:
return json.load(f)
def save_torch(obj: Any, path: str | Path) -> None:
def save_torch(obj: Any, path: str | Path):
torch.save(obj, str(path))
def load_torch(path: str | Path) -> Any:
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str):
    """Write the model config (JSON) and weights (safetensors) to a directory.

    The directory, including missing parents, is created when absent.
    """
    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    config_file = target_dir / _CONFIG_FILE
    weights_file = target_dir / _WEIGHTS_FILE
    save_json(config, config_file)
    save_safetensors(state_dict, weights_file)
def load_model_config(save_directory: str) -> dict:
    """Read and return the model config JSON stored in ``save_directory``."""
    config_file = Path(save_directory) / _CONFIG_FILE
    return load_json(config_file)
def load_model_weights(save_directory: str) -> dict:
    """Read and return the safetensors weights stored in ``save_directory``."""
    weights_file = Path(save_directory) / _WEIGHTS_FILE
    return load_safetensors(weights_file)
def _get_meta(save_path: Path) -> dict:
    """Return the checkpoint metadata, read on rank 0 and shared with peers.

    Rank 0 reads the metadata file; every other rank starts with an empty
    dict. When ``torch.distributed`` is initialized, rank 0's value is
    broadcast so all ranks return the same metadata.
    """
    meta = load_json(save_path / _META_FILE) if get_rank() == 0 else {}
    if not dist.is_initialized():
        return meta

    # broadcast_object_list fills the list in place on receiving ranks.
    holder = [meta]
    dist.broadcast_object_list(holder, src=0)
    return holder[0]
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
    """Load the safetensors weights, optionally distributing them from rank 0.

    Without ``broadcast`` (or when ``torch.distributed`` is not initialized)
    every caller reads the weights file itself. Otherwise only rank 0 reads
    it: the (name, shape, dtype) spec of every tensor is broadcast first, then
    each tensor is broadcast in sorted-name order and stored by the receivers.

    Returns:
        Mapping from parameter name to CPU tensor.
    """
    weights_file = save_path / _WEIGHTS_FILE
    if not (broadcast and dist.is_initialized()):
        return load_safetensors(weights_file)

    is_source = get_rank() == 0
    if is_source:
        state_dict = load_safetensors(weights_file)
        specs = [
            (
                name,
                list(state_dict[name].shape),
                # "torch.float32" -> "float32", so it can be looked up again.
                str(state_dict[name].dtype).split(".")[-1],
            )
            for name in sorted(state_dict)
        ]
    else:
        state_dict = {}
        specs = []

    box = [specs]
    dist.broadcast_object_list(box, src=0)

    # All ranks walk the same spec list so the collectives stay in lockstep.
    for name, shape, dtype_name in box[0]:
        dtype = getattr(torch, dtype_name)
        if is_source:
            tensor = state_dict[name].contiguous().cpu()
        else:
            tensor = torch.empty(shape, dtype=dtype, device="cpu")
        dist.broadcast(tensor, src=0)
        if not is_source:
            state_dict[name] = tensor
    return state_dict
@dataclass
class Checkpoint:
@ -49,7 +129,7 @@ class Checkpoint:
extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str) -> None:
def save(self, save_dir: str):
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
@ -68,24 +148,16 @@ class Checkpoint:
save_torch(value, save_path / f"{key}.pt")
@classmethod
def load(cls, save_dir: str) -> "Checkpoint":
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir)
meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
meta = _get_meta(save_path)
state_dict = _load_state_dict(save_path, broadcast=broadcast)
extra = {}
for f in save_path.iterdir():
for f in sorted(save_path.iterdir()):
if f.suffix == ".pt":
extra[f.stem] = load_torch(f)
extra[f.stem] = load_torch(f, broadcast=broadcast)
return cls(
state_dict=state_dict,
@ -93,18 +165,3 @@ class Checkpoint:
iteration=meta.get("iteration", 0),
extra=extra,
)
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _MODEL_CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)

View File

@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

View File

@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
"""
@classmethod
def _validate_component(cls, strategy_cls: type) -> None:
def _validate_component(cls, strategy_cls: type):
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device
from astrai.parallel.setup import get_current_device, get_rank
from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import (
ctx_get_grad_max,
@ -139,27 +139,27 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
# All ranks gather state_dict — collective for FSDP, local for DDP
state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
self.last_ckpt_iter = context.iteration
if get_rank() == 0:
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint(
state_dict=state_dict,
@ -168,13 +168,7 @@ class CheckpointCallback(TrainCallback):
extra=extra,
meta=context.config.to_dict(),
)
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
@ -196,12 +190,6 @@ class CheckpointCallback(TrainCallback):
extra[name] = obj.state_dict()
return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self
import torch.nn as nn
@ -10,7 +11,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint
from astrai.serialization import Checkpoint, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -42,10 +43,10 @@ class TrainContextBuilder:
config: TrainConfig,
):
self.config = config
self._checkpoint: Optional[Checkpoint] = None
self._resume_dir: Optional[str] = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
self._resume_dir = resume_dir
return self
def build(self) -> TrainContext:
@ -58,36 +59,40 @@ class TrainContextBuilder:
**cfg.executor_kwargs,
)
model = cfg.model_fn()
model = model.to(device=device)
context = TrainContext(
model=cfg.model,
model=model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
executor=executor,
)
context.model = context.model.to(device=device)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
if self._checkpoint.state_dict:
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
if self._resume_dir is not None:
resume_path = Path(self._resume_dir)
if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict
else:
context.checkpoint = Checkpoint(
state_dict=context.model.state_dict(),
)
checkpoint = None
state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False)
if checkpoint is not None:
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
context.iteration = max(checkpoint.iteration, cfg.start_batch)
context.checkpoint = checkpoint
if cfg.lora is not None:
inject_lora(
context.model,
model,
r=cfg.lora.r,
alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules),
)
context.optimizer = cfg.optimizer_fn(context.model)
context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
sampler_offset = context.iteration * cfg.batch_per_device
@ -125,13 +130,21 @@ class TrainContextBuilder:
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
context.model,
model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create(
model=context.model,
train_type=cfg.strategy,

View File

@ -3,7 +3,6 @@ from typing import List, Optional
from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import (
CallbackFactory,
TrainCallback,
@ -54,9 +53,9 @@ class Trainer:
if method:
method(context)
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
def _trainer_loop(self, resume_dir: Optional[str] = None):
context = (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
)
executor = context.executor
self._call_callbacks("on_train_begin", context)
@ -90,13 +89,13 @@ class Trainer:
self._call_callbacks("on_epoch_end", context)
except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True)
logger.error("Training failed: %s", str(e), exc_info=True)
self._call_callbacks("on_error", context)
raise
finally:
self._call_callbacks("on_train_end", context)
def train(self, checkpoint: Optional[Checkpoint] = None):
def train(self, resume_dir: Optional[str] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
@ -106,5 +105,5 @@ class Trainer:
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
checkpoint=checkpoint,
resume_dir=resume_dir,
)

View File

@ -0,0 +1,279 @@
"""MMLU evaluation via log-likelihood ranking."""
import argparse
import csv
import json
import os
import shutil
import urllib.request
import zipfile
import torch
import torch.nn.functional as F
import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
# Zip archive of the master branch of the hendrycks/test GitHub repository;
# download_mmlu() expects it to unpack into a "test-master/data" directory.
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
# Subject identifiers evaluated by default. main() uses each name to build
# the per-subject CSV paths "<split>/<subject>_<split>.csv" and
# "dev/<subject>_dev.csv" under the data directory.
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]
def _download_and_extract(url: str, data_dir: str):
    """Download a zip archive from ``url`` and unpack it into ``data_dir``.

    ``data_dir`` is created when missing. The archive is stored temporarily
    as ``<data_dir>/mmlu.zip`` and is always removed afterwards, including
    when the download or the extraction fails.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or ``zipfile.ZipFile`` raise
        (for example ``zipfile.BadZipFile`` for a corrupt archive).
    """
    zip_path = os.path.join(data_dir, "mmlu.zip")
    os.makedirs(data_dir, exist_ok=True)
    print(f"Downloading MMLU data from {url}...")
    try:
        urllib.request.urlretrieve(url, zip_path)
        print("Extracting...")
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(data_dir)
    finally:
        # Do not leave a partial or corrupt archive behind on failure.
        if os.path.exists(zip_path):
            os.remove(zip_path)
def download_mmlu(data_dir: str):
    """Download the MMLU archive and flatten its ``data`` folder into ``data_dir``.

    The archive is expected to unpack into ``<data_dir>/test-master/data``.
    Every entry of that folder is moved directly under ``data_dir``, replacing
    any entry of the same name left by an earlier download, and the extracted
    ``test-master`` tree is then deleted. When the expected folder is missing
    a warning is printed and nothing is moved.
    """
    _download_and_extract(MMLU_URL, data_dir)
    extracted_root = os.path.join(data_dir, "test-master")
    src = os.path.join(extracted_root, "data")
    if not os.path.isdir(src):
        # Reporting success here would hide the fact that no data arrived.
        print(f"Warning: expected MMLU data directory not found: {src}")
        return

    for item in os.listdir(src):
        dest = os.path.join(data_dir, item)
        # os.rename cannot replace a non-empty directory, so clear stale
        # results of a previous download first.
        if os.path.isdir(dest) and not os.path.islink(dest):
            shutil.rmtree(dest)
        elif os.path.lexists(dest):
            os.remove(dest)
        os.rename(os.path.join(src, item), dest)
    shutil.rmtree(extracted_root)
    print(f"MMLU data saved to {data_dir}")
def _strip_prefix(text: str, prefix: str) -> str:
    """Drop a leading ``prefix`` from ``text`` and trim the remainder.

    Text that does not start with ``prefix`` is returned unchanged (and is
    not trimmed).
    """
    if not text.startswith(prefix):
        return text
    remainder = text[len(prefix) :]
    return remainder.strip()
def load_csv(path: str) -> list[dict]:
    """Parse an MMLU CSV file into a list of question records.

    Each kept row yields a dict with the keys ``question``, ``A``, ``B``,
    ``C``, ``D`` and ``answer``, taken from the first six columns. Rows with
    fewer than six columns and header rows (first cell equal to "question",
    case-insensitively) are skipped. A leading "A)".."D)" label on a choice
    is removed.

    Args:
        path: Path of the UTF-8 encoded CSV file.

    Returns:
        The parsed records in file order.
    """
    data = []
    # newline="" hands line-ending handling to the csv module, which is
    # required for quoted fields that contain embedded newlines.
    with open(path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 6:
                continue
            if row[0].strip().lower() == "question":
                continue
            data.append(
                {
                    "question": row[0].strip(),
                    "A": _strip_prefix(row[1].strip(), "A)"),
                    "B": _strip_prefix(row[2].strip(), "B)"),
                    "C": _strip_prefix(row[3].strip(), "C)"),
                    "D": _strip_prefix(row[4].strip(), "D)"),
                    "answer": row[5].strip(),
                }
            )
    return data
def build_prompt(
    question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
    """Assemble the multiple-choice prompt for one question.

    When ``n_shot`` is positive and ``dev_data`` is non-empty, the prompt
    starts with a subject header followed by the first ``n_shot`` dev items,
    each rendered with its answer. The target question and its four choices
    follow, and the prompt ends with an open "Answer:".

    Args:
        question: Text of the question to be answered.
        choices: Mapping that holds the choice texts under "A".."D".
        subject: Subject name inserted into the few-shot header.
        n_shot: Number of worked examples to prepend.
        dev_data: Records providing the worked examples.
    """
    letters = ("A", "B", "C", "D")
    parts = []

    if n_shot > 0 and dev_data:
        parts.append(
            f"The following are multiple choice questions (with answers) about {subject}.\n\n"
        )
        for shot in dev_data[:n_shot]:
            parts.append(f"Question: {shot['question']}\n")
            parts.extend(f"{letter}. {shot[letter]}\n" for letter in letters)
            parts.append(f"Answer: {shot['answer']}\n\n")

    parts.append(f"Question: {question}\n")
    parts.extend(f"{letter}. {choices[letter]}\n" for letter in letters)
    parts.append("Answer:")
    return "".join(parts)
def choice_logprob(
    model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
    """Sum the log-probabilities the model assigns to one answer letter.

    The letter is encoded with a leading space and appended to the context.
    If the combined sequence exceeds ``model.config.max_len`` it is truncated
    from the left, so the choice tokens are always kept. The logit at
    position ``p`` is used to score the token at position ``p + 1``.

    Args:
        model: Callable returning a mapping whose "logits" entry is indexed
            as ``[batch][position]``.
        tokenizer: Object providing ``encode(text, add_special_tokens=...)``.
        context_ids: Token ids of the prompt.
        choice_letter: Answer letter to score.
        device: Device on which the input tensor is created.

    Returns:
        Total log-probability of the choice tokens.
    """
    choice_ids = tokenizer.encode(f" {choice_letter}", add_special_tokens=False)
    input_ids = context_ids + choice_ids

    ctx_len = len(context_ids)
    excess = len(input_ids) - model.config.max_len
    if excess > 0:
        # Drop the oldest context tokens; the choice tokens stay at the end.
        input_ids = input_ids[excess:]
        ctx_len = len(input_ids) - len(choice_ids)

    batch = torch.tensor([input_ids], device=device, dtype=torch.long)
    score = 0.0
    with torch.inference_mode():
        logits = model(batch)["logits"][0]
        for offset, token_id in enumerate(choice_ids):
            pos = ctx_len - 1 + offset
            if pos >= len(logits):
                break
            log_probs = F.log_softmax(logits[pos], dim=-1)
            score += log_probs[token_id].item()
    return score
def evaluate_subject(
model,
tokenizer,
subject: str,
test_data: list[dict],
dev_data: list[dict] | None,
device: str,
n_shot: int,
) -> tuple[float, int, int]:
correct = 0
total = 0
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or [])
context_ids = tokenizer.encode(prompt)
scores = {
c: choice_logprob(model, tokenizer, context_ids, c, device)
for c in ("A", "B", "C", "D")
}
if max(scores, key=scores.get) == item["answer"]:
correct += 1
total += 1
return correct / total, correct, total
def main():
    """Command-line entry point: evaluate a model on MMLU and report accuracy.

    Parses the CLI options, downloads the data when requested or when the
    data directory is missing, loads the model and tokenizer from
    ``--param_path``, evaluates every selected subject and prints per-subject
    and overall accuracy. With ``--output`` the results are also written as
    JSON.
    """
    parser = argparse.ArgumentParser(description="MMLU evaluation")
    parser.add_argument(
        "--param_path", type=str, default="./params", help="Model directory"
    )
    parser.add_argument(
        "--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
    )
    parser.add_argument("--download", action="store_true", help="Download MMLU data")
    parser.add_argument(
        "--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
    )
    parser.add_argument(
        "--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
    )
    parser.add_argument("--output", type=str, help="Output JSON path")
    parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16" if torch.cuda.is_available() else "float32",
        help="Torch dtype",
    )
    args = parser.parse_args()

    if args.download or not os.path.exists(args.data_dir):
        download_mmlu(args.data_dir)

    model = AutoModel.from_pretrained(args.param_path)
    tokenizer = AutoTokenizer.from_pretrained(args.param_path)
    device = args.device
    dtype = getattr(torch, args.dtype)
    model.to(device=device, dtype=dtype)
    # Scoring must not depend on training-only layers such as dropout.
    model.eval()

    subjects = args.subjects or MMLU_SUBJECTS
    results = {}
    total_correct = 0
    total_questions = 0
    for subject in subjects:
        dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
        test_path = os.path.join(
            args.data_dir, args.split, f"{subject}_{args.split}.csv"
        )
        if not os.path.exists(test_path):
            print(f" Skipping {subject}: test file not found")
            continue
        # Few-shot examples are optional; without a dev file none are used.
        dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
        test_data = load_csv(test_path)
        acc, corr, tot = evaluate_subject(
            model, tokenizer, subject, test_data, dev_data, device, args.n_shot
        )
        results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
        total_correct += corr
        total_questions += tot
        print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")

    overall = total_correct / total_questions if total_questions else 0
    print(f"\n{'=' * 70}")
    print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
    results["_overall"] = {
        "accuracy": round(overall, 4),
        "correct": total_correct,
        "total": total_questions,
    }
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {args.output}")
if __name__ == "__main__":
main()

View File

@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
def process_file(
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
):
# Load model and tokenizer
model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(param_path)
model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f:
@ -44,8 +44,8 @@ def process_file(
for seq in batch_encoded:
pad_len = max_len - len(seq)
padded_seq = [tokenizer.pad_id] * pad_len + seq
mask = [False] * pad_len + [True] * len(seq)
padded_seq = seq + [tokenizer.pad_id] * pad_len
mask = [True] * len(seq) + [False] * pad_len
padded_ids.append(padded_seq)
masks.append(mask)
@ -88,7 +88,7 @@ def process_file(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument(
"--model_dir", type=str, required=True, help="Path to the model directory."
"--param_path", type=str, required=True, help="Path to the model directory."
)
parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file."

View File

@ -18,7 +18,7 @@ def main():
"--reload", action="store_true", help="Enable auto-reload for development"
)
parser.add_argument(
"--param-path",
"--param_path",
type=Path,
default=None,
help="Path to model parameters (default: project_root/params)",

View File

@ -8,7 +8,6 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.serialization import Checkpoint
from astrai.trainer import SchedulerFactory, Trainer
@ -147,8 +146,8 @@ def parse_args() -> argparse.Namespace:
"--parallel_mode",
type=str,
default="none",
choices=["none", "ddp"],
help="Parallel training strategy.",
choices=["none", "ddp", "fsdp"],
help="Parallel training strategy (none, ddp, fsdp).",
)
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
@ -166,6 +165,10 @@ def parse_args() -> argparse.Namespace:
return args
def create_model(config):
    """Build an ``AutoRegressiveLM`` from ``config`` and cast it to bfloat16.

    Kept as a module-level function so it can be bound with
    ``functools.partial`` and handed to ``TrainConfig`` as ``model_fn``.
    """
    model = AutoRegressiveLM(config)
    return model.to(dtype=torch.bfloat16)
def create_optimizer(model, **kwargs) -> optim.Optimizer:
    """Build an AdamW optimizer over all parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are optimized.
        **kwargs: Extra keyword arguments forwarded to ``optim.AdamW``.
            ``fused`` defaults to ``True`` when not supplied.

    Returns:
        The configured ``optim.AdamW`` instance.
    """
    # Default to the fused implementation, but let the caller override it.
    # Hard-coding ``fused=True`` next to ``**kwargs`` raised a duplicate-keyword
    # TypeError whenever a caller passed ``fused`` explicitly, so there was no
    # way to opt out.
    kwargs.setdefault("fused", True)
    return optim.AdamW(model.parameters(), **kwargs)
@ -228,6 +231,8 @@ def train(
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config
config_path = os.path.join(param_path, "config.json")
@ -236,15 +241,6 @@ def train(
if window_size is None:
window_size = config.max_len
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
checkpoint = Checkpoint.load(param_path)
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
model.load_state_dict(checkpoint.state_dict, strict=False)
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
# (model weights already loaded into model above)
checkpoint.state_dict = {}
strategy_kwargs = {
"beta": dpo_beta,
"label_smoothing": label_smoothing,
@ -259,6 +255,7 @@ def train(
"broadcast_buffers": False,
}
model_fn = partial(create_model, config)
dataset = DatasetFactory.load(
train_type=train_type,
load_path=data_root_path,
@ -290,7 +287,7 @@ def train(
)
train_config = TrainConfig(
model=model,
model_fn=model_fn,
strategy=train_type,
dataset=dataset,
optimizer_fn=optimizer_fn,
@ -315,7 +312,7 @@ def train(
)
trainer = Trainer(train_config)
trainer.train(checkpoint=checkpoint)
trainer.train(resume_dir=param_path)
if __name__ == "__main__":

View File

@ -7,10 +7,8 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import (
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
StorageFactory,
H5Store,
StoreFactory,
detect_format,
load_json,
save_h5,
@ -318,37 +316,48 @@ def test_unloaded_dataset_len():
assert len(dataset) == 0
def test_base_segment_fetcher_empty():
"""BaseSegmentFetcher with empty segments list"""
fetcher = BaseSegmentFetcher([])
assert len(fetcher) == 0
with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_store_unloaded_len():
    """A freshly constructed H5Store (nothing loaded) is empty and has no keys."""
    unloaded = H5Store()
    assert unloaded.keys == []
    assert len(unloaded) == 0
def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor"""
def test_store_fetch_begin_equals_end(base_test_env):
"""Store.fetch with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
result = dataset.storage.fetch(10, 10, "sequence")
assert result.numel() == 0
def test_multi_segment_fetcher_empty_dict():
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
def test_store_empty_data_len(base_test_env):
    """Length of a loaded JSON store vs. a never-loaded H5Store.

    Writes one small ``sequence`` entry to ``data.json``, loads it through the
    "json" store and checks the length is positive; then checks that an
    H5Store that was never loaded reports length 0.
    """
    import os

    store_dir = os.path.join(base_test_env["test_dir"], "empty_store")
    os.makedirs(store_dir, exist_ok=True)

    payload = {"sequence": [[1, 2, 3]]}
    json_path = os.path.join(store_dir, "data.json")
    with open(json_path, "w") as f:
        json.dump(payload, f)

    loaded = StoreFactory.create("json")
    loaded.load(store_dir)
    assert len(loaded) > 0

    # A store that never had load() called on it must report zero length.
    assert len(H5Store()) == 0
def test_storage_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError"""
storage = H5Storage()
def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError"""
store = H5Store()
with pytest.raises(RuntimeError, match="not loaded"):
storage.fetch(0, 10, "sequence")
store.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
@ -367,10 +376,10 @@ def test_detect_format_unsupported_file(base_test_env):
detect_format(path)
def test_create_storage_invalid_type():
"""StorageFactory.create raises ValueError for unknown type"""
def test_create_store_invalid_type():
"""StoreFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StorageFactory.create("parquet")
StoreFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
@ -407,14 +416,23 @@ def test_load_json_skips_config_file(base_test_env):
assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment():
"""fetch_data across multiple segment boundaries"""
def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "multi_seg")
os.makedirs(data_dir, exist_ok=True)
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7)
save_h5(data_dir, "data", {"sequence": segs})
store = StoreFactory.create("h5")
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7]

View File

@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
def create_train_config(
model: torch.nn.Module,
model_fn,
dataset: Dataset,
test_dir: str,
device: str,
@ -43,7 +43,7 @@ def create_train_config(
"""Factory function to create common TrainConfig for tests.
Args:
model: The model to train
model_fn: Model factory (callable returning nn.Module)
dataset: Training dataset
test_dir: Checkpoint directory
device: Device type ("cuda" or "cpu")
@ -70,7 +70,7 @@ def create_train_config(
return TrainConfig(
strategy=strategy,
model=model,
model_fn=model_fn,
dataset=dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,

View File

@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
)
train_config = TrainConfig(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
)
train_config = TrainConfig(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,

View File

@ -4,7 +4,6 @@ import numpy as np
import torch
from astrai.config.train_config import TrainConfig
from astrai.serialization import Checkpoint
from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer
@ -24,7 +23,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
strategy="seq",
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
@ -39,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
trainer = Trainer(train_config)
# Should handle early stopping gracefully
checkpoint = None
try:
checkpoint = trainer.train()
trainer.train()
except Exception:
# Handle any exceptions
pass
# Resume from latest checkpoint
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
checkpoint = Checkpoint.load(load_dir)
trainer.train(checkpoint)
trainer = Trainer(train_config)
trainer.train(resume_dir=load_dir)
# Verify checkpoint was saved at expected iteration
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
checkpoint = Checkpoint.load(load_dir)
assert checkpoint.iteration == 10
import json
with open(os.path.join(load_dir, "meta.json")) as f:
meta = json.load(f)
assert meta["iteration"] == 10

View File

@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
for batch_per_device in batch_sizes:
train_config = train_config_factory(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
for config in small_batch_configs:
train_config = train_config_factory(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],