Compare commits

...

7 Commits

Author SHA1 Message Date
ViperEkura b37c3d000c docs : 同步文档与实际代码
- 移除 JSONStore 引用(该类不存在)
- 修正 Store.load() 和 DatasetFactory.load() 签名(无 tokenizer 参数)
- 修正 TrainContextBuilder.with_resume_dir() 命名
- 修正 Checkpoint config 字段和 meta.json 描述
- 修正 ProtocolHandler.handle() 异步签名
- 修正采样继承图(平行子类,非线性)
- 修正训练循环:回调移入 accumulate 块内
- 更新文档日期至 2026-05-28
2026-05-28 21:01:47 +08:00
ViperEkura 6031020e37 feat : load_json/load_safetensors 支持 broadcast,跨节点分布式加载
- load_json/load_safetensors/load_state_dict 新增 broadcast 参数
- broadcast=True 时 rank-0 读取后 broadcast_object_list 分发到所有 rank
- load_state_dict 改为逐张量 broadcast,避免大模型 pickle 内存瓶颈
- 删除 _get_meta/_get_config wrapper,Checkpoint.load 直接调用 load_json
- 参数注解 str | Path 统一为 Union[str, Path]
2026-05-28 20:44:58 +08:00
ViperEkura c424dfc293 feat : checkpoint 支持保存 config.json
- Checkpoint.save 写入独立的 config.json(模型架构参数)
- Checkpoint.load 读取 config.json,恢复时覆盖 context.model_config
- TrainContext 新增 model_config 字段,builder 从 resume_dir/config.json 加载
- BaseConfig.to_dict 支持 tuple 和嵌套 dataclass(如 LoRAConfig)
- 删除 _get_meta/_get_config wrapper,直接使用 load_json
2026-05-28 20:21:51 +08:00
ViperEkura 3a28e52e98 fix : start_epoch/start_batch 由用户参数决定,不再被 checkpoint 覆盖 2026-05-28 18:24:22 +08:00
ViperEkura e371908b54 fix : 保存 checkpoint 时 unwrap DDP/FSDP 避免 module. 前缀
- 移除 state_dict_fn 参数
- _save_checkpoint 中先 unwrap_model 再 state_dict()
2026-05-28 18:10:04 +08:00
ViperEkura 7c99da155c refactor: 删除数据流中的 JSONStore
- 移除 JSONStore 及相关函数,训练框架不再依赖 tokenizer
- Store 层只保留 H5Store 和 MmapStore 两种后端
2026-05-28 15:54:26 +08:00
ViperEkura 629e72385b fix : 修复存储层 bug,JSON 切换为 JSONL,补齐测试覆盖
- save_bin/load_bin: save_json/load_json 替换为直接 json.dump/json.load,修复致命 bug
- _normalize: 空 cum 列表 guard,防止 IndexError
- load_json: 改为仅支持 JSONL 逐行解析 (json.loads),移除 .json 支持
- detect_format: 只匹配 *.jsonl,不再匹配 *.json
- save_json: 输出扩展名改为 .jsonl
- GRPODataset.__getitem__: 补齐 .to(dtype=torch.long/bool) 与其他数据集一致
- load_bin: np.memmap mode='r+' 消除 PyTorch 不可写 tensor 警告
- 新增 16 个测试: bin roundtrip, mmap load, 空 key, JSONL 多行/文本, GRPO dtype/load, detect_format bin/jsonl, fetch multi-key/越界, json_to_bin 转换, DPO from JSONL, 显式 storage_type
2026-05-28 15:29:46 +08:00
12 changed files with 347 additions and 384 deletions

View File

@ -65,7 +65,7 @@ classDiagram
} }
class TrainConfig { class TrainConfig {
+nn.Module model +Callable[[], nn.Module] model_fn
+str strategy +str strategy
+Dataset dataset +Dataset dataset
+Callable optimizer_fn +Callable optimizer_fn
@ -108,7 +108,7 @@ classDiagram
+int window_size +int window_size
+int stride +int stride
+Optional[Store] storage +Optional[Store] storage
+load(load_path, storage_type, tokenizer) +load(load_path, storage_type)
+__getitem__(index) +__getitem__(index)
+__len__() +__len__()
} }
@ -134,7 +134,7 @@ classDiagram
+Dict[str, List[int]] _cum +Dict[str, List[int]] _cum
+int _length +int _length
+keys (property) +keys (property)
+load(path, tokenizer) +load(path)
+fetch(begin, end, keys) +fetch(begin, end, keys)
+__len__() +__len__()
-_fetch_key(key, begin, end) Tensor -_fetch_key(key, begin, end) Tensor
@ -142,16 +142,12 @@ classDiagram
} }
class H5Store { class H5Store {
+load(path, tokenizer) +load(path)
}
class JSONStore {
+load(path, tokenizer)
} }
class MmapStore { class MmapStore {
+List _mmap_refs +List _mmap_refs
+load(path, tokenizer) +load(path)
} }
class ResumableDistributedSampler { class ResumableDistributedSampler {
@ -169,7 +165,7 @@ classDiagram
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
+create(train_type, window_size, stride) BaseDataset +create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset +load(train_type, load_path, window_size, stride, storage_type) BaseDataset
} }
} }
@ -180,8 +176,9 @@ classDiagram
+int iteration +int iteration
+dict extra +dict extra
+dict meta +dict meta
+dict config
+save(save_dir) +save(save_dir)
+load(save_dir) Checkpoint +load(save_dir, broadcast) Checkpoint
} }
} }
@ -189,8 +186,8 @@ classDiagram
class AutoModel { class AutoModel {
+BaseModelConfig config +BaseModelConfig config
+Registry _registry +Registry _registry
+register(model_type) decorator +register(name) decorator
+get_component_class(model_type) Type +get_component_class(name) Type
+from_pretrained(path, disable_random_init, strict) nn.Module +from_pretrained(path, disable_random_init, strict) nn.Module
+save_pretrained(save_directory) +save_pretrained(save_directory)
+to(*args, **kwargs) Self +to(*args, **kwargs) Self
@ -204,7 +201,7 @@ classDiagram
+RMSNorm norm +RMSNorm norm
+Linear lm_head +Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor] +forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+load_state_dict(state_dict) +load_state_dict(state_dict, strict, assign)
+state_dict() +state_dict()
} }
@ -229,6 +226,7 @@ classDiagram
} }
class GQA { class GQA {
+int dim
+int n_heads +int n_heads
+int n_kv_heads +int n_kv_heads
+int head_dim +int head_dim
@ -243,6 +241,7 @@ classDiagram
} }
class MLA { class MLA {
+int dim
+int n_heads +int n_heads
+int n_kv_heads +int n_kv_heads
+int head_dim +int head_dim
@ -303,6 +302,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+Optional[Dict] rope_scaling
+forward(x, position_ids=None) Tensor +forward(x, position_ids=None) Tensor
} }
@ -315,10 +315,10 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+vocab_size int +vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int] +encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids) +__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]] +apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
+from_pretrained(path) AutoTokenizer +from_pretrained(path) AutoTokenizer
@ -326,7 +326,7 @@ classDiagram
} }
class ChatTemplate { class ChatTemplate {
+String template_str +str template_str
+render(messages, system_prompt, **extra_variables) str +render(messages, system_prompt, **extra_variables) str
+from_string(template) ChatTemplate +from_string(template) ChatTemplate
} }
@ -364,6 +364,7 @@ classDiagram
+SchedulerProtocol scheduler +SchedulerProtocol scheduler
+Checkpoint checkpoint +Checkpoint checkpoint
+TrainConfig config +TrainConfig config
+dict model_config
+BaseExecutor executor +BaseExecutor executor
+int epoch +int epoch
+int iteration +int iteration
@ -377,7 +378,7 @@ classDiagram
class TrainContextBuilder { class TrainContextBuilder {
+TrainConfig config +TrainConfig config
+with_checkpoint(checkpoint) TrainContextBuilder +with_resume_dir(resume_dir) TrainContextBuilder
+build() TrainContext +build() TrainContext
} }
@ -472,16 +473,12 @@ classDiagram
+str save_dir +str save_dir
+int interval +int interval
+bool weight_only +bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn +Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context) +_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_train_end(context) +on_train_end(context)
+on_error(context) +on_error(context)
+save_extra(context)$ +save_extra(context)$
+load_extra(extra, context)$
} }
class ProgressBarCallback { class ProgressBarCallback {
@ -518,7 +515,12 @@ classDiagram
+float lr +float lr
+float momentum +float momentum
+float weight_decay +float weight_decay
+bool nesterov
+int ns_steps +int ns_steps
+float adamw_lr
+tuple adamw_betas
+float adamw_eps
+float adamw_wd
+step(closure) Optional[float] +step(closure) Optional[float]
} }
} }
@ -539,6 +541,8 @@ classDiagram
+AutoModel model +AutoModel model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+KVCache page_cache +KVCache page_cache
+Optional[str] device
+Optional[torch.dtype] dtype
+execute_prefill(tasks, prompt_len, start_pos) +execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int] +execute_decode(tasks) List[int]
} }
@ -550,7 +554,9 @@ classDiagram
+bool _running +bool _running
+Thread _loop_thread +Thread _loop_thread
+int max_seq_len +int max_seq_len
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +str device
+torch.dtype dtype
+add_task(prompt, **kwargs) str
+remove_task(task_id) +remove_task(task_id)
+start() +start()
+stop() +stop()
@ -653,15 +659,19 @@ classDiagram
class TaskManager { class TaskManager {
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+Deque waiting_queue +Deque waiting_queue
+List active_tasks +List active_tasks
+add_task(prompt, **kwargs) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id) List[Task] +remove_task(task_id) List[Task]
+remove_finished_tasks(stop_ids) List[Task] +remove_finished_tasks(stop_ids) List[Task]
+pull_candidates(n) List[Task] +pull_candidates(n) List[Task]
+activate(task) +activate(task)
+return_to_waiting(tasks) +return_to_waiting(tasks)
+get_active_tasks() List[Task] +get_active_tasks() List[Task]
+get_stats() Dict
} }
class GenerationRequest { class GenerationRequest {
@ -917,7 +927,6 @@ classDiagram
BaseDataset <|-- DPODataset BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset BaseDataset <|-- GRPODataset
Store <|-- H5Store Store <|-- H5Store
Store <|-- JSONStore
Store <|-- MmapStore Store <|-- MmapStore
BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopKStrategy
@ -996,7 +1005,6 @@ classDiagram
DecoderBlock ..> AttnFactory : uses DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses DecoderBlock ..> FFNFactory : uses
StoreFactory ..> H5Store : creates StoreFactory ..> H5Store : creates
StoreFactory ..> JSONStore : creates
StoreFactory ..> MmapStore : creates StoreFactory ..> MmapStore : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates ConfigFactory ..> EncoderConfig : creates
@ -1063,7 +1071,7 @@ classDiagram
| **Context** | `TrainContext` | Unified training state bag | | **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support | | **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1075,10 +1083,10 @@ classDiagram
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data` 7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers 11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
> Document Update Time: 2026-05-24 > Document Update Time: 2026-05-28

View File

@ -5,22 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview ## Overview
``` ```
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference
``` ```
## Data Preparation ## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups. Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
``` ```
StoreFactory.create("h5") → H5Store StoreFactory.create("h5") → H5Store
StoreFactory.create("json") → JSONStore
StoreFactory.create("bin") → MmapStore StoreFactory.create("bin") → MmapStore
``` ```
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively. H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
## Data Keys by Training Type ## Data Keys by Training Type
@ -34,7 +33,7 @@ H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) us
## Dataset Architecture ## Dataset Architecture
``` ```
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer) DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
→ StoreFactory.create(detect_format(path)) → StoreFactory.create(detect_format(path))
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] → Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx) → BaseDataset.__getitem__(idx)
@ -55,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokeniz
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-17 > Document Update Time: 2026-05-28

View File

@ -16,12 +16,12 @@ Six classes working together:
``` ```
KVCache (facade) KVCache (facade)
├── Allocator bitmask-based page allocator + ref-count + LRU eviction ├── PagePool orchestrates page allocation + prefix matching
├── PrefixCache hash-based prefix matching (page_hash via rolling hash) │ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
├── PagePool orchestrates Allocator + PrefixCache │ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
├── TaskTable maps task_id → page_table + cached token count ├── TaskTable maps task_id → page_table + cached token count
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim) ├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView bundles Storage + page_table + total_len for attention layers └── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
``` ```
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`. `KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
@ -40,7 +40,10 @@ KVCache (facade)
## Sampling (Strategy Pattern) ## Sampling (Strategy Pattern)
``` ```
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy BaseSamplingStrategy (ABC)
├── TemperatureStrategy
├── TopKStrategy
└── TopPStrategy
``` ```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
@ -50,11 +53,12 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
```python ```python
class ProtocolHandler: # concrete orchestrator class ProtocolHandler: # concrete orchestrator
def handle(self, request): def __init__(self, request, engine, builder): ...
async def handle(self):
prompt, ctx, stops = builder.prepare(request, engine) prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...) agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops) if stream: self._handle_stream(agen, ctx, stops)
else: self._handle_non_stream(agen, ctx, stops) else: return await self._handle_non_stream(agen, ctx, stops)
``` ```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`. `ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
@ -96,12 +100,14 @@ Response:
{ {
"id": "chatcmpl-abc123", "id": "chatcmpl-abc123",
"object": "chat.completion", "object": "chat.completion",
"choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}], "created": 1717000000,
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15} "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
} }
``` ```
Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]` Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
### Anthropic ### Anthropic
@ -121,7 +127,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) | | `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold | | `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count | | `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length | | `max_tokens` | Optional[int] | None | Max generation length |
| `stream` | bool | False | Stream output | | `stream` | bool | False | Stream output |
## Engine API ## Engine API
@ -139,4 +145,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
``` ```
> Document Update Time: 2026-05-17 > Document Update Time: 2026-05-28

View File

@ -74,7 +74,8 @@ on_train_begin
on_batch_begin on_batch_begin
with executor.accumulate(model): with executor.accumulate(model):
loss = strategy(batch) loss = strategy(batch)
(loss / grad_accum_steps).backward() stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
iteration += 1 iteration += 1
on_batch_end on_batch_end
@ -82,6 +83,7 @@ on_train_begin
on_optimizer_step on_optimizer_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
if scheduler:
scheduler.step() scheduler.step()
on_epoch_end on_epoch_end
on_train_end on_train_end
@ -169,20 +171,20 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
## Checkpoint ## Checkpoint
``` ```
Checkpoint(state_dict, epoch, iteration, extra, meta) Checkpoint(state_dict, epoch, iteration, extra, meta, config)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt ├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
└── load(save_dir) broadcasts metadata from rank-0 └── load(save_dir) broadcasts metadata from rank-0
``` ```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`. Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`. Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern) ## TrainContextBuilder (Builder Pattern)
```python ```python
context = ( context = (
TrainContextBuilder(config) TrainContextBuilder(config)
.with_checkpoint(checkpoint) .with_resume_dir(resume_dir)
.build() .build()
) )
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint # Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
@ -222,4 +224,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md). Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-24 > Document Update Time: 2026-05-28

View File

@ -13,12 +13,21 @@ class BaseConfig:
d[fld.name] = v d[fld.name] = v
elif v is None: elif v is None:
d[fld.name] = None d[fld.name] = None
elif isinstance(v, (dict, list)): elif isinstance(v, (dict, list, tuple)):
try: try:
json.dumps(v) val = list(v) if isinstance(v, tuple) else v
d[fld.name] = v json.dumps(val)
d[fld.name] = val
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
elif isinstance(v, BaseConfig):
d[fld.name] = v.to_dict()
elif hasattr(v, "__dataclass_fields__"):
sub = {}
for f in fields(v):
a = getattr(v, f.name)
sub[f.name] = list(a) if isinstance(a, tuple) else a
d[fld.name] = sub
return d return d
@classmethod @classmethod

View File

@ -5,18 +5,14 @@ from astrai.dataset.dataset import (
from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import ( from astrai.dataset.storage import (
H5Store, H5Store,
JSONStore,
MmapStore, MmapStore,
Store, Store,
StoreFactory, StoreFactory,
detect_format, detect_format,
json_to_bin,
load_bin, load_bin,
load_h5, load_h5,
load_json,
save_bin, save_bin,
save_h5, save_h5,
save_json,
) )
__all__ = [ __all__ = [
@ -25,15 +21,11 @@ __all__ = [
"Store", "Store",
"StoreFactory", "StoreFactory",
"H5Store", "H5Store",
"JSONStore",
"MmapStore", "MmapStore",
"detect_format", "detect_format",
"save_h5", "save_h5",
"load_h5", "load_h5",
"save_json",
"load_json",
"save_bin", "save_bin",
"load_bin", "load_bin",
"json_to_bin",
"ResumableDistributedSampler", "ResumableDistributedSampler",
] ]

View File

@ -48,17 +48,15 @@ class BaseDataset(Dataset, ABC):
f"Missing: {missing}" f"Missing: {missing}"
) )
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None): def load(self, load_path: str, storage_type: Optional[str] = None):
"""Load dataset from the given path. """Load dataset from the given path.
Auto-detects the storage format if not specified. Auto-detects the storage format if not specified.
Args: Args:
load_path: Path to the data directory or file load_path: Path to the data directory or file
storage_type: Force a specific storage type ("h5", "json"), storage_type: Force a specific storage type ("h5", "bin"),
or None for auto-detection or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
Raises: Raises:
KeyError: If the loaded storage is missing required keys. KeyError: If the loaded storage is missing required keys.
@ -67,18 +65,9 @@ class BaseDataset(Dataset, ABC):
storage_type = detect_format(load_path) storage_type = detect_format(load_path)
self.storage = StoreFactory.create(storage_type) self.storage = StoreFactory.create(storage_type)
self._load_path = load_path self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer) self.storage.load(load_path)
self._validate_keys() self._validate_keys()
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
Args:
load_path: Path to the JSON data file or directory
tokenizer: Optional tokenizer callable for raw text JSON.
"""
self.load(load_path, storage_type="json", tokenizer=tokenizer)
@property @property
def count(self) -> int: def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset.""" """Return the total number of raw elements (tokens) in the dataset."""
@ -175,7 +164,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
window_size: int, window_size: int,
stride: Optional[int] = None, stride: Optional[int] = None,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
tokenizer=None,
) -> "BaseDataset": ) -> "BaseDataset":
"""Create and load a dataset in one step. """Create and load a dataset in one step.
@ -184,8 +172,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file load_path: Path to the data file
window_size: Window size for data sampling window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size) stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "json") or None for auto-detection storage_type: Storage type ("h5", "bin") or None for auto-detection
tokenizer: Callable str -> List[int] for raw text JSON tokenization
Returns: Returns:
Loaded dataset instance Loaded dataset instance
@ -194,7 +181,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size stride = window_size
dataset = cls.create(train_type, window_size, stride) dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer) dataset.load(load_path, storage_type=storage_type)
return dataset return dataset
@ -306,9 +293,11 @@ class GRPODataset(BaseDataset):
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
prompts = self._fetch_data(begin_idx, end_idx, "prompts") prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
responses = self._fetch_data(begin_idx, end_idx, "responses") responses = self._fetch_data(begin_idx, end_idx, "responses").to(
masks = self._fetch_data(begin_idx, end_idx, "masks") dtype=torch.long
)
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
rewards = self._fetch_data(begin_idx, end_idx, "rewards") rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return { return {

View File

@ -1,7 +1,7 @@
"""Storage backends for different data formats. """Storage backends for different data formats.
Layers: Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin) - I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
return Dict[str, List[Tensor]] format-specific, no state return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into - Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(), Dict[str, List[Tensor]] per key via _normalize(),
@ -22,7 +22,7 @@ import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Dict, List, Union
import h5py import h5py
import numpy as np import numpy as np
@ -68,56 +68,6 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
return tensor_group return tensor_group
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.json")
json_data = {}
for key, tensors in tensor_group.items():
json_data[key] = [tensor.tolist() for tensor in tensors]
with open(full_file_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False)
def load_json(
file_path: str,
share_memory: bool = True,
tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
"""Load tensor data from JSON files.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
at load time. A ``tokenizer`` receives a str and returns List[int].
Non-data JSON files (e.g. config.json) with scalar/object values are
silently skipped.
"""
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
for json_file in json_files:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
continue
for key, sequences in data.items():
if not isinstance(sequences, list):
continue
tensors = []
for seq in sequences:
if tokenizer is not None and isinstance(seq, str):
seq = tokenizer(seq)
tensor = torch.tensor(seq, dtype=torch.long)
if share_memory:
tensor = tensor.share_memory_()
tensors.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(tensors)
return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]): def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
meta = {} meta = {}
@ -125,31 +75,25 @@ def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
cat = torch.cat(tensors, dim=0) cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]} meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin")) np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
save_json(meta, os.path.join(file_path, "meta.json")) with open(os.path.join(file_path, "meta.json"), "w") as f:
json.dump(meta, f)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Memory-map every ``<key>.bin`` array described by ``meta.json``.

    Args:
        file_path: Directory holding ``meta.json`` plus one ``<key>.bin``
            file per key listed in it.

    Returns:
        Mapping from key to a single-element list containing the tensor
        that wraps the memory-mapped array.
    """
    with open(os.path.join(file_path, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)

    segments: Dict[str, List[Tensor]] = {}
    for key, info in meta.items():
        # Copy-on-write mapping: the array is writable (so torch.from_numpy
        # does not warn about non-writable memory), but in-place tensor ops
        # never reach the file on disk and read-only files can still be loaded.
        arr = np.memmap(
            os.path.join(file_path, f"{key}.bin"),
            dtype=info["dtype"],
            mode="c",
            shape=tuple(info["shape"]),
        )
        segments[key] = [torch.from_numpy(arr)]
    return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert JSON data into the binary layout written by ``save_bin``.

    The data is read with ``load_json`` (shared memory disabled, ``tokenizer``
    forwarded), each key's tensors are concatenated along dim 0 into a single
    tensor, and the result is handed to ``save_bin``.
    """
    loaded = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    merged = {name: [torch.cat(parts, dim=0)] for name, parts in loaded.items()}
    save_bin(bin_path, merged)
def detect_format(load_path: str) -> str: def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory. """Auto-detect storage format from files in the directory.
@ -157,7 +101,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path load_path: Directory or file path
Returns: Returns:
Format string ("h5", "bin", or "json") Format string ("h5" or "bin")
Raises: Raises:
FileNotFoundError: If no supported data files are found FileNotFoundError: If no supported data files are found
@ -167,8 +111,6 @@ def detect_format(load_path: str) -> str:
suffix = root.suffix.lower() suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"): if suffix in (".h5", ".hdf5"):
return "h5" return "h5"
if suffix in (".json", ".jsonl"):
return "json"
raise ValueError(f"Unsupported file format: {suffix}") raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
@ -177,9 +119,6 @@ def detect_format(load_path: str) -> str:
bin_files = list(root.rglob("*.bin")) bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists(): if bin_files and (root / "meta.json").exists():
return "bin" return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}") raise FileNotFoundError(f"No supported data files found at {load_path}")
@ -200,7 +139,7 @@ class Store(ABC):
self._length: int = 0 self._length: int = 0
@abstractmethod @abstractmethod
def load(self, path: str, tokenizer=None) -> None: def load(self, path: str) -> None:
raise NotImplementedError raise NotImplementedError
@property @property
@ -257,7 +196,11 @@ class Store(ABC):
total += t.shape[0] total += t.shape[0]
cum.append(total) cum.append(total)
self._cum[key] = cum self._cum[key] = cum
self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0 self._length = (
min((cum[-1] if cum else 0) for cum in self._cum.values())
if self._cum
else 0
)
class StoreFactory(BaseFactory["Store"]): class StoreFactory(BaseFactory["Store"]):
@ -280,24 +223,10 @@ class StoreFactory(BaseFactory["Store"]):
class H5Store(Store): class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data).""" """HDF5-based storage backend (pre-tokenized data)."""
def load(self, path: str, tokenizer=None): def load(self, path: str):
self._normalize(load_h5(path)) self._normalize(load_h5(path))
@StoreFactory.register("json")
class JSONStore(Store):
    """Storage backend that reads its data from JSON files.

    Two input shapes are accepted:
    - Pre-tokenized: JSON values are List[List[int]] and are used as-is.
    - Raw text: JSON values are List[str]; each string is converted by the
      ``tokenizer`` callable (str -> List[int]) while loading.
    """

    def load(self, path: str, tokenizer=None):
        # Parsing and optional tokenization are delegated to load_json.
        tensor_group = load_json(path, tokenizer=tokenizer)
        self._normalize(tensor_group)
@StoreFactory.register("bin") @StoreFactory.register("bin")
class MmapStore(Store): class MmapStore(Store):
"""Memory-mapped binary storage backend. """Memory-mapped binary storage backend.
@ -313,7 +242,7 @@ class MmapStore(Store):
<key>.bin # raw numpy array, one per key <key>.bin # raw numpy array, one per key
""" """
def load(self, path: str, tokenizer=None): def load(self, path: str):
self._mmap_refs = [] self._mmap_refs = []
raw = load_bin(path) raw = load_bin(path)
self._normalize(raw) self._normalize(raw)

View File

@ -3,7 +3,7 @@ import json
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple from typing import Any, Dict, Union
import safetensors.torch as st import safetensors.torch as st
import torch import torch
@ -16,29 +16,50 @@ _CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors" _WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: Union[str, Path]):
    """Write ``state_dict`` to ``path`` using ``safetensors.torch.save_file``."""
    target = str(path)
    st.save_file(state_dict, target)
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
    """Load a safetensors file, optionally distributing it from rank 0.

    Args:
        path: Location of the safetensors file.
        broadcast: When True and the process group is initialized, only
            rank 0 reads the file and the result is sent to every rank with
            ``broadcast_object_list``. Otherwise the file is read locally.

    Returns:
        The loaded state dict.
    """
    if broadcast and dist.is_initialized():
        # Non-zero ranks contribute an empty placeholder that the broadcast
        # overwrites with rank 0's object.
        payload = [st.load_file(str(path)) if get_rank() == 0 else {}]
        dist.broadcast_object_list(payload, src=0)
        return payload[0]
    return st.load_file(str(path))
def save_json(data: dict, path: Union[str, Path]):
    """Serialize ``data`` as JSON (2-space indent) into the file at ``path``."""
    target = str(path)
    with open(target, "w") as out:
        json.dump(data, out, indent=2)
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
    """Read a JSON file, optionally distributing it from rank 0.

    Args:
        path: Location of the JSON file.
        broadcast: When True and the process group is initialized, only
            rank 0 reads the file and the parsed object is sent to every
            rank with ``broadcast_object_list``. Otherwise every caller
            reads the file itself.

    Returns:
        The parsed JSON content.
    """

    def _read() -> dict:
        # JSON is UTF-8 by specification; do not rely on the locale default,
        # which breaks on non-ASCII content under non-UTF-8 locales.
        with open(str(path), "r", encoding="utf-8") as f:
            return json.load(f)

    if not broadcast or not dist.is_initialized():
        return _read()

    # Non-zero ranks contribute an empty placeholder that the broadcast
    # overwrites with rank 0's object.
    data = _read() if get_rank() == 0 else {}
    tmp = [data]
    dist.broadcast_object_list(tmp, src=0)
    return tmp[0]
def save_torch(obj: Any, path: Union[str, Path]):
    """Persist ``obj`` at ``path`` with ``torch.save``."""
    target = str(path)
    torch.save(obj, target)
def load_torch(path: str | Path, broadcast: bool = False) -> Any: def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized(): if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False) return torch.load(str(path), map_location="cpu", weights_only=False)
@ -76,28 +97,18 @@ def load_model_config(save_directory: str) -> dict:
def load_model_weights(save_directory: str) -> dict:
    """Load the state dict stored as the weights file inside ``save_directory``."""
    weights_path = Path(save_directory) / _WEIGHTS_FILE
    return load_state_dict(weights_path)
def _get_meta(save_path: Path) -> dict: def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
meta = {} path = Path(path)
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
return meta
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized(): if not broadcast or not dist.is_initialized():
return load_safetensors(save_path / _WEIGHTS_FILE) return load_safetensors(path)
rank = get_rank() rank = get_rank()
if rank == 0: if rank == 0:
state_dict = load_safetensors(save_path / _WEIGHTS_FILE) state_dict = load_safetensors(path)
specs: List[Tuple[str, List[int], str]] = [ specs = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1]) (k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict) for k in sorted(state_dict)
] ]
@ -128,6 +139,7 @@ class Checkpoint:
iteration: int = 0 iteration: int = 0
extra: Dict[str, Any] = field(default_factory=dict) extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
config: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str): def save(self, save_dir: str):
save_path = Path(save_dir) save_path = Path(save_dir)
@ -143,6 +155,7 @@ class Checkpoint:
**self.meta, **self.meta,
} }
save_json(meta, save_path / _META_FILE) save_json(meta, save_path / _META_FILE)
save_json(self.config, save_path / _CONFIG_FILE)
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE) save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
for key, value in self.extra.items(): for key, value in self.extra.items():
save_torch(value, save_path / f"{key}.pt") save_torch(value, save_path / f"{key}.pt")
@ -151,8 +164,9 @@ class Checkpoint:
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir) save_path = Path(save_dir)
meta = _get_meta(save_path) meta = load_json(save_path / _META_FILE, broadcast)
state_dict = _load_state_dict(save_path, broadcast=broadcast) config = load_json(save_path / _CONFIG_FILE, broadcast)
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
extra = {} extra = {}
for f in sorted(save_path.iterdir()): for f in sorted(save_path.iterdir()):
@ -164,4 +178,5 @@ class Checkpoint:
epoch=meta.get("epoch", 0), epoch=meta.get("epoch", 0),
iteration=meta.get("iteration", 0), iteration=meta.get("iteration", 0),
extra=extra, extra=extra,
config=config,
) )

View File

@ -137,23 +137,17 @@ class CheckpointCallback(TrainCallback):
save_dir: str, save_dir: str,
interval: int, interval: int,
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
# All ranks gather state_dict — collective for FSDP, local for DDP unwrapped = context.executor.unwrap_model(context.model)
state_dict = ( state_dict = unwrapped.state_dict()
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
if get_rank() == 0: if get_rank() == 0:
@ -166,7 +160,7 @@ class CheckpointCallback(TrainCallback):
epoch=context.epoch, epoch=context.epoch,
iteration=context.iteration, iteration=context.iteration,
extra=extra, extra=extra,
meta=context.config.to_dict(), config=context.model_config,
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)

View File

@ -11,7 +11,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_model_weights from astrai.serialization import Checkpoint, load_json, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -24,6 +24,7 @@ class TrainContext:
scheduler: SchedulerProtocol = field(default=None) scheduler: SchedulerProtocol = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None) config: TrainConfig = field(default=None)
model_config: dict = field(default_factory=dict)
executor: BaseExecutor = field(default=None) executor: BaseExecutor = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
@ -62,11 +63,21 @@ class TrainContextBuilder:
model = cfg.model_fn() model = cfg.model_fn()
model = model.to(device=device) model = model.to(device=device)
model_config = {}
if self._resume_dir:
config_path = Path(self._resume_dir) / "config.json"
if config_path.exists():
model_config = load_json(config_path)
if not model_config and hasattr(model, "config"):
model_config = model.config.to_dict()
context = TrainContext( context = TrainContext(
model=model, model=model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
config=cfg, config=cfg,
model_config=model_config,
executor=executor, executor=executor,
) )
@ -75,13 +86,15 @@ class TrainContextBuilder:
if (resume_path / "meta.json").exists(): if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir) checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict state_dict = checkpoint.state_dict
if checkpoint.config:
context.model_config = checkpoint.config
else: else:
checkpoint = None checkpoint = None
state_dict = load_model_weights(self._resume_dir) state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
if checkpoint is not None: if checkpoint is not None:
context.epoch = max(checkpoint.epoch, cfg.start_epoch) context.epoch = cfg.start_epoch
context.iteration = max(checkpoint.iteration, cfg.start_batch) context.iteration = cfg.start_batch
context.checkpoint = checkpoint context.checkpoint = checkpoint
if cfg.lora is not None: if cfg.lora is not None:

View File

@ -8,9 +8,11 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
H5Store, H5Store,
MmapStore,
StoreFactory, StoreFactory,
detect_format, detect_format,
load_json, load_bin,
save_bin,
save_h5, save_h5,
) )
@ -155,111 +157,6 @@ def test_dataset_with_custom_stride(base_test_env):
assert len(dataset) > len(default_stride_dataset) assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
    """SEQ dataset built from raw-text JSON is tokenized while loading."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_text")
    os.makedirs(data_dir, exist_ok=True)

    payload = {
        "sequence": [
            "hello world this is a test sentence for tokenizer",
            "another sentence with different words and tokens",
            "machine learning is fascinating and powerful",
        ]
    }
    with open(os.path.join(data_dir, "seq_data.json"), "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False)

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=16,
        tokenizer=encode,
    )

    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count > 0
    assert "sequence" in dataset.keys

    first = dataset[0]
    for field_name in ("input_ids", "target_ids"):
        assert field_name in first
    assert first["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
    """SFT dataset built from raw-text JSON exposes a ``loss_mask`` entry."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_sft")
    os.makedirs(data_dir, exist_ok=True)

    texts = [
        "user asks a question about the weather",
        "assistant provides a helpful response to the user",
    ]
    # The same strings are stored under both keys.
    payload = {"sequence": texts, "loss_mask": texts}
    with open(os.path.join(data_dir, "sft_data.json"), "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False)

    dataset = DatasetFactory.load(
        train_type="sft",
        load_path=data_dir,
        window_size=16,
        tokenizer=encode,
    )

    assert dataset is not None
    assert len(dataset) > 0
    assert "loss_mask" in dataset[0]
def test_json_storage_explicit_tokenizer(base_test_env):
    """Explicit ``storage_type="json"`` with a tokenizer yields the tokenized count."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_explicit")
    os.makedirs(data_dir, exist_ok=True)

    texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
    with open(os.path.join(data_dir, "data.json"), "w", encoding="utf-8") as f:
        json.dump({"sequence": texts}, f, ensure_ascii=False)

    # Reference count taken straight from the tokenizer callable.
    expected_tokens = len(encode(texts[0]))

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=32,
        storage_type="json",
        tokenizer=encode,
    )

    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count == expected_tokens
def test_dataset_count_property(base_test_env): def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count""" """Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
@ -334,25 +231,6 @@ def test_store_fetch_begin_equals_end(base_test_env):
assert result.numel() == 0 assert result.numel() == 0
def test_store_empty_data_len(base_test_env):
    """A loaded JSON store has positive length; a never-loaded H5Store has length 0."""
    import os

    data_dir = os.path.join(base_test_env["test_dir"], "empty_store")
    os.makedirs(data_dir, exist_ok=True)
    with open(os.path.join(data_dir, "data.json"), "w") as f:
        json.dump({"sequence": [[1, 2, 3]]}, f)

    loaded = StoreFactory.create("json")
    loaded.load(data_dir)
    assert len(loaded) > 0

    # A store that never had load() called reports zero length.
    assert len(H5Store()) == 0
def test_store_fetch_before_load(): def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError""" """Store.fetch before load raises RuntimeError"""
store = H5Store() store = H5Store()
@ -382,40 +260,6 @@ def test_create_store_invalid_type():
StoreFactory.create("parquet") StoreFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
    """Pre-tokenized JSON (List[List[int]]) loads with no tokenizer supplied."""
    data_dir = os.path.join(base_test_env["test_dir"], "json_pretok")
    os.makedirs(data_dir, exist_ok=True)

    payload = {"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}
    with open(os.path.join(data_dir, "data.json"), "w", encoding="utf-8") as f:
        json.dump(payload, f)

    dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
    assert len(dataset) > 0
    assert dataset.count == 10

    first = dataset[0]
    assert first["input_ids"].tolist() == [1, 2, 3, 4]
    assert first["target_ids"].tolist() == [2, 3, 4, 5]
def test_load_json_skips_config_file(base_test_env):
    """``load_json`` ignores JSON files whose values are scalars."""
    test_dir = base_test_env["test_dir"]
    fixtures = {
        "config.json": {"vocab_size": 1000, "dim": 16},
        "data.json": {"sequence": [[1, 2, 3, 4, 5]]},
    }
    for file_name, content in fixtures.items():
        with open(os.path.join(test_dir, file_name), "w") as f:
            json.dump(content, f)

    result = load_json(test_dir)
    assert "sequence" in result
    assert "vocab_size" not in result
    assert len(result["sequence"]) == 1
def test_store_multi_segment_concat(base_test_env): def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time""" """Multi-segment H5 data is concatenated into single tensor at load time"""
import os import os
@ -436,3 +280,166 @@ def test_store_multi_segment_concat(base_test_env):
assert len(store) == 9 assert len(store) == 9
result = store.fetch(2, 7, "sequence") result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7] assert result.tolist() == [3, 4, 5, 6, 7]
def test_save_load_bin_roundtrip(base_test_env):
    """Data written by ``save_bin`` is read back unchanged by ``load_bin``."""
    test_dir = base_test_env["test_dir"]
    expected = {
        "sequence": [1, 2, 3, 4, 5],
        "loss_mask": [0, 1, 1, 0, 1],
    }
    tensors = {
        key: [torch.tensor(values, dtype=torch.int64)]
        for key, values in expected.items()
    }
    save_bin(test_dir, tensors)

    result = load_bin(test_dir)
    for key in expected:
        assert key in result
    for key, values in expected.items():
        assert result[key][0].tolist() == values
def test_mmap_store_load_and_fetch(base_test_env):
    """The ``"bin"`` store loads ``save_bin`` output and fetches the right slice."""
    test_dir = base_test_env["test_dir"]
    source = torch.randint(0, 1000, (200,), dtype=torch.int64)
    save_bin(test_dir, {"sequence": [source]})

    store = StoreFactory.create("bin")
    store.load(test_dir)
    assert len(store) == 200
    assert "sequence" in store.keys

    fetched = store.fetch(10, 20, "sequence")
    assert fetched.tolist() == source[10:20].tolist()
def test_mmap_dataset_load(base_test_env):
    """``DatasetFactory.load`` without a storage type picks up bin-format data."""
    test_dir = base_test_env["test_dir"]
    tokens = torch.randint(0, 1000, (200,), dtype=torch.int64)
    save_bin(test_dir, {"sequence": [tokens]})

    dataset = DatasetFactory.load("seq", test_dir, window_size=64)
    assert len(dataset) > 0
    assert dataset.count == 200

    first = dataset[0]
    assert first["input_ids"].shape[0] == 64
def test_normalize_empty_key():
    """``_normalize`` accepts a key whose tensor list is empty."""
    store = H5Store()
    empty_group = {"sequence": []}
    store._normalize(empty_group)

    assert len(store) == 0
    assert store.keys == ["sequence"]
def test_normalize_mixed_empty_key():
    """With one empty and one non-empty key the store length is 0."""
    store = H5Store()
    mixed_group = {
        "sequence": [torch.tensor([1, 2, 3])],
        "loss_mask": [],
    }
    store._normalize(mixed_group)

    assert len(store) == 0
    assert set(store.keys) == {"sequence", "loss_mask"}
def test_grpo_dataset_dtype(base_test_env):
    """GRPO dataset items come back with the expected dtype per field."""
    test_dir = base_test_env["test_dir"]
    seq_len = 100
    # Stored as int32 / float32; the dataset is expected to convert on access.
    stored = {
        "prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
        "responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
        "masks": [torch.ones(seq_len, dtype=torch.int32)],
        "rewards": [torch.ones(seq_len, dtype=torch.float32)],
    }
    save_h5(test_dir, "grpo_dtype", stored)

    dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
    item = dataset[0]

    expected_dtypes = (
        ("prompts", torch.long),
        ("responses", torch.long),
        ("masks", torch.bool),
        ("rewards", torch.float32),
    )
    for field_name, dtype in expected_dtypes:
        assert item[field_name].dtype == dtype
def test_grpo_dataset_load(base_test_env):
    """GRPO dataset loads from H5 and returns all four fields per item."""
    test_dir = base_test_env["test_dir"]
    seq_len = 200
    stored = {
        "prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
        "responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
        "masks": [torch.ones(seq_len, dtype=torch.int64)],
        "rewards": [torch.rand(seq_len, dtype=torch.float32)],
    }
    save_h5(test_dir, "grpo_test", stored)

    dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
    assert len(dataset) > 0

    item = dataset[0]
    for field_name in ("prompts", "responses", "masks", "rewards"):
        assert field_name in item
    for field_name in ("prompts", "responses"):
        assert item[field_name].shape[0] == 64
def test_detect_format_bin_dir(base_test_env):
    """A directory written by ``save_bin`` is detected as ``"bin"``."""
    test_dir = base_test_env["test_dir"]
    tokens = torch.randint(0, 100, (10,))
    save_bin(test_dir, {"sequence": [tokens]})

    detected = detect_format(test_dir)
    assert detected == "bin"
def test_store_fetch_multi_key(base_test_env):
    """Fetching with a list of keys returns a dict with one slice per key."""
    test_dir = base_test_env["test_dir"]
    stored = {
        "sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
        "loss_mask": [torch.ones(100, dtype=torch.int64)],
    }
    save_h5(test_dir, "multi_key", stored)

    store = StoreFactory.create("h5")
    store.load(test_dir)

    wanted = ["sequence", "loss_mask"]
    result = store.fetch(10, 20, wanted)
    assert isinstance(result, dict)
    for key in wanted:
        assert result[key].shape[0] == 10
def test_store_fetch_out_of_bounds(base_test_env):
    """``Store.fetch`` raises ValueError for ranges outside the 50 stored tokens."""
    test_dir = base_test_env["test_dir"]
    save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})

    store = StoreFactory.create("h5")
    store.load(test_dir)

    # Negative begin, end past the data, and begin at the data length.
    invalid_ranges = ((-1, 10), (0, 51), (50, 50))
    for begin, end in invalid_ranges:
        with pytest.raises(ValueError, match="out of bounds"):
            store.fetch(begin, end, "sequence")
def test_dataset_load_explicit_storage_type(base_test_env):
    """``DatasetFactory.load`` works when ``storage_type="h5"`` is given explicitly."""
    test_dir = base_test_env["test_dir"]
    tokens = torch.randint(0, 100, (200,))
    save_h5(test_dir, "explicit", {"sequence": [tokens]})

    dataset = DatasetFactory.load(
        "seq",
        test_dir,
        window_size=64,
        storage_type="h5",
    )
    assert len(dataset) > 0
    assert dataset.count == 200