Compare commits

..

No commits in common. "b37c3d000c3cbb4710993828bfd9353650e06e9e" and "0a708fff24b32f17a859bc6fcf3ba8f03b644aab" have entirely different histories.

12 changed files with 384 additions and 347 deletions

View File

@ -65,7 +65,7 @@ classDiagram
} }
class TrainConfig { class TrainConfig {
+Callable[[], nn.Module] model_fn +nn.Module model
+str strategy +str strategy
+Dataset dataset +Dataset dataset
+Callable optimizer_fn +Callable optimizer_fn
@ -108,7 +108,7 @@ classDiagram
+int window_size +int window_size
+int stride +int stride
+Optional[Store] storage +Optional[Store] storage
+load(load_path, storage_type) +load(load_path, storage_type, tokenizer)
+__getitem__(index) +__getitem__(index)
+__len__() +__len__()
} }
@ -134,7 +134,7 @@ classDiagram
+Dict[str, List[int]] _cum +Dict[str, List[int]] _cum
+int _length +int _length
+keys (property) +keys (property)
+load(path) +load(path, tokenizer)
+fetch(begin, end, keys) +fetch(begin, end, keys)
+__len__() +__len__()
-_fetch_key(key, begin, end) Tensor -_fetch_key(key, begin, end) Tensor
@ -142,12 +142,16 @@ classDiagram
} }
class H5Store { class H5Store {
+load(path) +load(path, tokenizer)
}
class JSONStore {
+load(path, tokenizer)
} }
class MmapStore { class MmapStore {
+List _mmap_refs +List _mmap_refs
+load(path) +load(path, tokenizer)
} }
class ResumableDistributedSampler { class ResumableDistributedSampler {
@ -165,7 +169,7 @@ classDiagram
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
+create(train_type, window_size, stride) BaseDataset +create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset +load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
} }
} }
@ -176,9 +180,8 @@ classDiagram
+int iteration +int iteration
+dict extra +dict extra
+dict meta +dict meta
+dict config
+save(save_dir) +save(save_dir)
+load(save_dir, broadcast) Checkpoint +load(save_dir) Checkpoint
} }
} }
@ -186,8 +189,8 @@ classDiagram
class AutoModel { class AutoModel {
+BaseModelConfig config +BaseModelConfig config
+Registry _registry +Registry _registry
+register(name) decorator +register(model_type) decorator
+get_component_class(name) Type +get_component_class(model_type) Type
+from_pretrained(path, disable_random_init, strict) nn.Module +from_pretrained(path, disable_random_init, strict) nn.Module
+save_pretrained(save_directory) +save_pretrained(save_directory)
+to(*args, **kwargs) Self +to(*args, **kwargs) Self
@ -201,7 +204,7 @@ classDiagram
+RMSNorm norm +RMSNorm norm
+Linear lm_head +Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor] +forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+load_state_dict(state_dict, strict, assign) +load_state_dict(state_dict)
+state_dict() +state_dict()
} }
@ -226,7 +229,6 @@ classDiagram
} }
class GQA { class GQA {
+int dim
+int n_heads +int n_heads
+int n_kv_heads +int n_kv_heads
+int head_dim +int head_dim
@ -241,7 +243,6 @@ classDiagram
} }
class MLA { class MLA {
+int dim
+int n_heads +int n_heads
+int n_kv_heads +int n_kv_heads
+int head_dim +int head_dim
@ -302,7 +303,6 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+Optional[Dict] rope_scaling
+forward(x, position_ids=None) Tensor +forward(x, position_ids=None) Tensor
} }
@ -315,10 +315,10 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+vocab_size int +vocab_size int
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int] +encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids) +__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]] +apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
+from_pretrained(path) AutoTokenizer +from_pretrained(path) AutoTokenizer
@ -326,7 +326,7 @@ classDiagram
} }
class ChatTemplate { class ChatTemplate {
+str template_str +String template_str
+render(messages, system_prompt, **extra_variables) str +render(messages, system_prompt, **extra_variables) str
+from_string(template) ChatTemplate +from_string(template) ChatTemplate
} }
@ -364,7 +364,6 @@ classDiagram
+SchedulerProtocol scheduler +SchedulerProtocol scheduler
+Checkpoint checkpoint +Checkpoint checkpoint
+TrainConfig config +TrainConfig config
+dict model_config
+BaseExecutor executor +BaseExecutor executor
+int epoch +int epoch
+int iteration +int iteration
@ -378,7 +377,7 @@ classDiagram
class TrainContextBuilder { class TrainContextBuilder {
+TrainConfig config +TrainConfig config
+with_resume_dir(resume_dir) TrainContextBuilder +with_checkpoint(checkpoint) TrainContextBuilder
+build() TrainContext +build() TrainContext
} }
@ -473,12 +472,16 @@ classDiagram
+str save_dir +str save_dir
+int interval +int interval
+bool weight_only +bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn +Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context) +_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_train_end(context) +on_train_end(context)
+on_error(context) +on_error(context)
+save_extra(context)$ +save_extra(context)$
+load_extra(extra, context)$
} }
class ProgressBarCallback { class ProgressBarCallback {
@ -515,12 +518,7 @@ classDiagram
+float lr +float lr
+float momentum +float momentum
+float weight_decay +float weight_decay
+bool nesterov
+int ns_steps +int ns_steps
+float adamw_lr
+tuple adamw_betas
+float adamw_eps
+float adamw_wd
+step(closure) Optional[float] +step(closure) Optional[float]
} }
} }
@ -541,8 +539,6 @@ classDiagram
+AutoModel model +AutoModel model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+KVCache page_cache +KVCache page_cache
+Optional[str] device
+Optional[torch.dtype] dtype
+execute_prefill(tasks, prompt_len, start_pos) +execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int] +execute_decode(tasks) List[int]
} }
@ -554,9 +550,7 @@ classDiagram
+bool _running +bool _running
+Thread _loop_thread +Thread _loop_thread
+int max_seq_len +int max_seq_len
+str device +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+torch.dtype dtype
+add_task(prompt, **kwargs) str
+remove_task(task_id) +remove_task(task_id)
+start() +start()
+stop() +stop()
@ -659,19 +653,15 @@ classDiagram
class TaskManager { class TaskManager {
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+Deque waiting_queue +Deque waiting_queue
+List active_tasks +List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, **kwargs) str
+remove_task(task_id) List[Task] +remove_task(task_id) List[Task]
+remove_finished_tasks(stop_ids) List[Task] +remove_finished_tasks(stop_ids) List[Task]
+pull_candidates(n) List[Task] +pull_candidates(n) List[Task]
+activate(task) +activate(task)
+return_to_waiting(tasks) +return_to_waiting(tasks)
+get_active_tasks() List[Task] +get_active_tasks() List[Task]
+get_stats() Dict
} }
class GenerationRequest { class GenerationRequest {
@ -927,6 +917,7 @@ classDiagram
BaseDataset <|-- DPODataset BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset BaseDataset <|-- GRPODataset
Store <|-- H5Store Store <|-- H5Store
Store <|-- JSONStore
Store <|-- MmapStore Store <|-- MmapStore
BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopKStrategy
@ -1005,6 +996,7 @@ classDiagram
DecoderBlock ..> AttnFactory : uses DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses DecoderBlock ..> FFNFactory : uses
StoreFactory ..> H5Store : creates StoreFactory ..> H5Store : creates
StoreFactory ..> JSONStore : creates
StoreFactory ..> MmapStore : creates StoreFactory ..> MmapStore : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates ConfigFactory ..> EncoderConfig : creates
@ -1071,7 +1063,7 @@ classDiagram
| **Context** | `TrainContext` | Unified training state bag | | **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support | | **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1083,10 +1075,10 @@ classDiagram
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data` 7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers 11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
> Document Update Time: 2026-05-28 > Document Update Time: 2026-05-24

View File

@ -5,21 +5,22 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview ## Overview
``` ```
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
``` ```
## Data Preparation ## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups. Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
``` ```
StoreFactory.create("h5") → H5Store StoreFactory.create("h5") → H5Store
StoreFactory.create("bin") → MmapStore StoreFactory.create("json") → JSONStore
StoreFactory.create("bin") → MmapStore
``` ```
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively. H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
## Data Keys by Training Type ## Data Keys by Training Type
@ -33,7 +34,7 @@ H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS pag
## Dataset Architecture ## Dataset Architecture
``` ```
DatasetFactory.load(train_type, load_path, window_size, stride, storage_type) DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
→ StoreFactory.create(detect_format(path)) → StoreFactory.create(detect_format(path))
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] → Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx) → BaseDataset.__getitem__(idx)
@ -54,4 +55,4 @@ DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-28 > Document Update Time: 2026-05-17

View File

@ -16,12 +16,12 @@ Six classes working together:
``` ```
KVCache (facade) KVCache (facade)
├── PagePool orchestrates page allocation + prefix matching ├── Allocator bitmask-based page allocator + ref-count + LRU eviction
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool) ├── PrefixCache hash-based prefix matching (page_hash via rolling hash)
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool) ├── PagePool orchestrates Allocator + PrefixCache
├── TaskTable maps task_id → page_table + cached token count ├── TaskTable maps task_id → page_table + cached token count
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim) ├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind()) └── KvcacheView bundles Storage + page_table + total_len for attention layers
``` ```
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`. `KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
@ -40,10 +40,7 @@ KVCache (facade)
## Sampling (Strategy Pattern) ## Sampling (Strategy Pattern)
``` ```
BaseSamplingStrategy (ABC) BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
├── TemperatureStrategy
├── TopKStrategy
└── TopPStrategy
``` ```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
@ -53,12 +50,11 @@ BaseSamplingStrategy (ABC)
```python ```python
class ProtocolHandler: # concrete orchestrator class ProtocolHandler: # concrete orchestrator
def __init__(self, request, engine, builder): ... def handle(self, request):
async def handle(self):
prompt, ctx, stops = builder.prepare(request, engine) prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...) agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops) if stream: self._handle_stream(agen, ctx, stops)
else: return await self._handle_non_stream(agen, ctx, stops) else: self._handle_non_stream(agen, ctx, stops)
``` ```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`. `ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
@ -100,14 +96,12 @@ Response:
{ {
"id": "chatcmpl-abc123", "id": "chatcmpl-abc123",
"object": "chat.completion", "object": "chat.completion",
"created": 1717000000, "choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15} "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
} }
``` ```
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`. Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]`
### Anthropic ### Anthropic
@ -127,7 +121,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) | | `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold | | `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count | | `top_k` | int | 50 | Top-k count |
| `max_tokens` | Optional[int] | None | Max generation length | | `max_tokens` | int | None | Max generation length |
| `stream` | bool | False | Stream output | | `stream` | bool | False | Stream output |
## Engine API ## Engine API
@ -145,4 +139,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
``` ```
> Document Update Time: 2026-05-28 > Document Update Time: 2026-05-17

View File

@ -74,17 +74,15 @@ on_train_begin
on_batch_begin on_batch_begin
with executor.accumulate(model): with executor.accumulate(model):
loss = strategy(batch) loss = strategy(batch)
stand_loss = loss / executor.grad_accum_steps (loss / grad_accum_steps).backward()
executor.backward(stand_loss)
iteration += 1 iteration += 1
on_batch_end on_batch_end
if executor.sync_gradients: if executor.sync_gradients:
on_optimizer_step on_optimizer_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
if scheduler: scheduler.step()
scheduler.step()
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
@ -171,20 +169,20 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
## Checkpoint ## Checkpoint
``` ```
Checkpoint(state_dict, epoch, iteration, extra, meta, config) Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt) ├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt
└── load(save_dir) broadcasts metadata from rank-0 └── load(save_dir) broadcasts metadata from rank-0
``` ```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`. Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`. Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern) ## TrainContextBuilder (Builder Pattern)
```python ```python
context = ( context = (
TrainContextBuilder(config) TrainContextBuilder(config)
.with_resume_dir(resume_dir) .with_checkpoint(checkpoint)
.build() .build()
) )
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint # Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
@ -224,4 +222,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md). Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-28 > Document Update Time: 2026-05-24

View File

@ -13,21 +13,12 @@ class BaseConfig:
d[fld.name] = v d[fld.name] = v
elif v is None: elif v is None:
d[fld.name] = None d[fld.name] = None
elif isinstance(v, (dict, list, tuple)): elif isinstance(v, (dict, list)):
try: try:
val = list(v) if isinstance(v, tuple) else v json.dumps(v)
json.dumps(val) d[fld.name] = v
d[fld.name] = val
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
elif isinstance(v, BaseConfig):
d[fld.name] = v.to_dict()
elif hasattr(v, "__dataclass_fields__"):
sub = {}
for f in fields(v):
a = getattr(v, f.name)
sub[f.name] = list(a) if isinstance(a, tuple) else a
d[fld.name] = sub
return d return d
@classmethod @classmethod

View File

@ -5,14 +5,18 @@ from astrai.dataset.dataset import (
from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import ( from astrai.dataset.storage import (
H5Store, H5Store,
JSONStore,
MmapStore, MmapStore,
Store, Store,
StoreFactory, StoreFactory,
detect_format, detect_format,
json_to_bin,
load_bin, load_bin,
load_h5, load_h5,
load_json,
save_bin, save_bin,
save_h5, save_h5,
save_json,
) )
__all__ = [ __all__ = [
@ -21,11 +25,15 @@ __all__ = [
"Store", "Store",
"StoreFactory", "StoreFactory",
"H5Store", "H5Store",
"JSONStore",
"MmapStore", "MmapStore",
"detect_format", "detect_format",
"save_h5", "save_h5",
"load_h5", "load_h5",
"save_json",
"load_json",
"save_bin", "save_bin",
"load_bin", "load_bin",
"json_to_bin",
"ResumableDistributedSampler", "ResumableDistributedSampler",
] ]

View File

@ -48,15 +48,17 @@ class BaseDataset(Dataset, ABC):
f"Missing: {missing}" f"Missing: {missing}"
) )
def load(self, load_path: str, storage_type: Optional[str] = None): def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
"""Load dataset from the given path. """Load dataset from the given path.
Auto-detects the storage format if not specified. Auto-detects the storage format if not specified.
Args: Args:
load_path: Path to the data directory or file load_path: Path to the data directory or file
storage_type: Force a specific storage type ("h5", "bin"), storage_type: Force a specific storage type ("h5", "json"),
or None for auto-detection or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
Raises: Raises:
KeyError: If the loaded storage is missing required keys. KeyError: If the loaded storage is missing required keys.
@ -65,9 +67,18 @@ class BaseDataset(Dataset, ABC):
storage_type = detect_format(load_path) storage_type = detect_format(load_path)
self.storage = StoreFactory.create(storage_type) self.storage = StoreFactory.create(storage_type)
self._load_path = load_path self._load_path = load_path
self.storage.load(load_path) self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys() self._validate_keys()
def load_json(self, load_path: str, tokenizer=None):
    """Load this dataset from JSON storage, skipping format auto-detection.

    Thin convenience wrapper: forwards to ``load`` with
    ``storage_type="json"`` so the caller does not have to spell it out.

    Args:
        load_path: Path to the JSON data file or directory.
        tokenizer: Optional tokenizer callable for raw-text JSON; passed
            through unchanged to ``load``.
    """
    self.load(load_path, storage_type="json", tokenizer=tokenizer)
@property @property
def count(self) -> int: def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset.""" """Return the total number of raw elements (tokens) in the dataset."""
@ -164,6 +175,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
window_size: int, window_size: int,
stride: Optional[int] = None, stride: Optional[int] = None,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
tokenizer=None,
) -> "BaseDataset": ) -> "BaseDataset":
"""Create and load a dataset in one step. """Create and load a dataset in one step.
@ -172,7 +184,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file load_path: Path to the data file
window_size: Window size for data sampling window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size) stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "bin") or None for auto-detection storage_type: Storage type ("h5", "json") or None for auto-detection
tokenizer: Callable str -> List[int] for raw text JSON tokenization
Returns: Returns:
Loaded dataset instance Loaded dataset instance
@ -181,7 +194,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size stride = window_size
dataset = cls.create(train_type, window_size, stride) dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path, storage_type=storage_type) dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
return dataset return dataset
@ -293,11 +306,9 @@ class GRPODataset(BaseDataset):
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long) prompts = self._fetch_data(begin_idx, end_idx, "prompts")
responses = self._fetch_data(begin_idx, end_idx, "responses").to( responses = self._fetch_data(begin_idx, end_idx, "responses")
dtype=torch.long masks = self._fetch_data(begin_idx, end_idx, "masks")
)
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
rewards = self._fetch_data(begin_idx, end_idx, "rewards") rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return { return {

View File

@ -1,20 +1,20 @@
"""Storage backends for different data formats. """Storage backends for different data formats.
Layers: Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin) - I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
return Dict[str, List[Tensor]] format-specific, no state return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into - Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(), Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key) - Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties: Key properties:
- Multi-segment: segments kept as-is, no forced concatenation safe for - Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load, - Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1) __len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader - Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages workers share OS page-cache pages
""" """
import bisect import bisect
@ -22,7 +22,7 @@ import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Dict, List, Union from typing import Callable, Dict, List, Optional, Union
import h5py import h5py
import numpy as np import numpy as np
@ -68,6 +68,56 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
return tensor_group return tensor_group
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group to ``<file_path>/<file_name>.json``.

    Each key maps to a list of tensors; every tensor is converted to a plain
    nested Python list so the result is ordinary JSON.

    Args:
        file_path: Target directory (created if it does not exist).
        file_name: File name without the ``.json`` extension.
        tensor_group: Mapping of key -> list of tensors to serialize.
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.json")
    payload = {
        key: [tensor.tolist() for tensor in tensors]
        for key, tensors in tensor_group.items()
    }
    with open(target, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False)
def load_json(
    file_path: str,
    share_memory: bool = True,
    tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
    """Load tensor data from JSON / JSONL files.

    ``file_path`` may be a directory (searched recursively for ``*.json`` and
    ``*.jsonl``) or a single file.

    Supports two modes:
    - Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
    - Raw text: JSON values are List[str], tokenized via the ``tokenizer``
      callable at load time. A ``tokenizer`` receives a str and returns
      List[int].

    Documents that are not JSON objects, and keys whose value is not a list
    (e.g. the scalar/object values found in a config.json), are silently
    skipped.

    Args:
        file_path: Directory or file to load from.
        share_memory: If True, call ``share_memory_()`` on every tensor.
        tokenizer: Optional callable used to tokenize raw-text entries.

    Returns:
        Mapping of key -> list of ``torch.long`` tensors, one per sequence,
        accumulated across all files.

    Raises:
        TypeError: If a raw-text (str) entry is found and no tokenizer was
            supplied.
        json.JSONDecodeError: If a file cannot be parsed.
    """
    root_path = Path(file_path)
    if root_path.is_file():
        # rglob() on a regular file yields nothing, so a path that points
        # directly at a data file has to be handled explicitly.
        json_files = [root_path]
    else:
        # Sorted so segment order does not depend on filesystem enumeration.
        json_files = sorted(root_path.rglob("*.json")) + sorted(
            root_path.rglob("*.jsonl")
        )

    tensor_group: Dict[str, List[Tensor]] = {}
    for json_file in json_files:
        for data in _read_json_documents(json_file):
            if not isinstance(data, dict):
                continue
            for key, sequences in data.items():
                if not isinstance(sequences, list):
                    continue
                bucket = tensor_group.setdefault(key, [])
                for seq in sequences:
                    if isinstance(seq, str):
                        if tokenizer is None:
                            # Fail with an actionable message instead of the
                            # opaque error torch.tensor() gives for a str.
                            raise TypeError(
                                f"Raw text found under key {key!r} in "
                                f"{json_file}; a tokenizer is required"
                            )
                        seq = tokenizer(seq)
                    tensor = torch.tensor(seq, dtype=torch.long)
                    if share_memory:
                        tensor = tensor.share_memory_()
                    bucket.append(tensor)
    return tensor_group


def _read_json_documents(json_file: Path) -> list:
    """Return the JSON documents stored in ``json_file``.

    The file is first parsed as one document. If that fails and the file is a
    ``.jsonl`` file, each non-blank line is parsed as its own document
    (assumed to use the same key -> sequences layout as a regular file).
    """
    text = json_file.read_text(encoding="utf-8")
    try:
        return [json.loads(text)]
    except json.JSONDecodeError:
        if json_file.suffix.lower() != ".jsonl":
            raise
    return [json.loads(line) for line in text.splitlines() if line.strip()]
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]): def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
meta = {} meta = {}
@ -75,25 +125,31 @@ def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
cat = torch.cat(tensors, dim=0) cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]} meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin")) np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
with open(os.path.join(file_path, "meta.json"), "w") as f: save_json(meta, os.path.join(file_path, "meta.json"))
json.dump(meta, f)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]: def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
with open(os.path.join(file_path, "meta.json"), "r") as f: meta = load_json(os.path.join(file_path, "meta.json"))
meta = json.load(f)
segments: Dict[str, List[Tensor]] = {} segments: Dict[str, List[Tensor]] = {}
for key, info in meta.items(): for key, info in meta.items():
arr = np.memmap( arr = np.memmap(
os.path.join(file_path, f"{key}.bin"), os.path.join(file_path, f"{key}.bin"),
dtype=info["dtype"], dtype=info["dtype"],
mode="r+", mode="r",
shape=tuple(info["shape"]), shape=tuple(info["shape"]),
) )
segments[key] = [torch.from_numpy(arr)] segments[key] = [torch.from_numpy(arr)]
return segments return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert JSON data at ``json_path`` into the binary layout at ``bin_path``.

    Loads every key via ``load_json`` (without shared memory), concatenates
    each key's tensors along dim 0 into a single tensor, and hands the result
    to ``save_bin``.

    Args:
        json_path: Source path passed to ``load_json``.
        bin_path: Destination directory passed to ``save_bin``.
        tokenizer: Optional tokenizer callable forwarded to ``load_json``.
    """
    loaded = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    save_bin(
        bin_path,
        {key: [torch.cat(parts, dim=0)] for key, parts in loaded.items()},
    )
def detect_format(load_path: str) -> str: def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory. """Auto-detect storage format from files in the directory.
@ -101,7 +157,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path load_path: Directory or file path
Returns: Returns:
Format string ("h5" or "bin") Format string ("h5", "bin", or "json")
Raises: Raises:
FileNotFoundError: If no supported data files are found FileNotFoundError: If no supported data files are found
@ -111,6 +167,8 @@ def detect_format(load_path: str) -> str:
suffix = root.suffix.lower() suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"): if suffix in (".h5", ".hdf5"):
return "h5" return "h5"
if suffix in (".json", ".jsonl"):
return "json"
raise ValueError(f"Unsupported file format: {suffix}") raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
@ -119,6 +177,9 @@ def detect_format(load_path: str) -> str:
bin_files = list(root.rglob("*.bin")) bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists(): if bin_files and (root / "meta.json").exists():
return "bin" return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}") raise FileNotFoundError(f"No supported data files found at {load_path}")
@ -139,7 +200,7 @@ class Store(ABC):
self._length: int = 0 self._length: int = 0
@abstractmethod @abstractmethod
def load(self, path: str) -> None: def load(self, path: str, tokenizer=None) -> None:
raise NotImplementedError raise NotImplementedError
@property @property
@ -196,11 +257,7 @@ class Store(ABC):
total += t.shape[0] total += t.shape[0]
cum.append(total) cum.append(total)
self._cum[key] = cum self._cum[key] = cum
self._length = ( self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
min((cum[-1] if cum else 0) for cum in self._cum.values())
if self._cum
else 0
)
class StoreFactory(BaseFactory["Store"]): class StoreFactory(BaseFactory["Store"]):
@ -223,10 +280,24 @@ class StoreFactory(BaseFactory["Store"]):
class H5Store(Store): class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data).""" """HDF5-based storage backend (pre-tokenized data)."""
def load(self, path: str): def load(self, path: str, tokenizer=None):
self._normalize(load_h5(path)) self._normalize(load_h5(path))
@StoreFactory.register("json")
class JSONStore(Store):
    """Storage backend that reads its data from JSON.

    Two input layouts are supported:
    - Pre-tokenized: values are List[List[int]] and are used as-is.
    - Raw text: values are List[str] and are converted at load time by the
      ``tokenizer`` callable (str -> List[int]).
    """

    def load(self, path: str, tokenizer=None):
        """Read JSON data from ``path`` and pass it to ``_normalize``."""
        segments = load_json(path, tokenizer=tokenizer)
        self._normalize(segments)
@StoreFactory.register("bin") @StoreFactory.register("bin")
class MmapStore(Store): class MmapStore(Store):
"""Memory-mapped binary storage backend. """Memory-mapped binary storage backend.
@ -242,7 +313,7 @@ class MmapStore(Store):
<key>.bin # raw numpy array, one per key <key>.bin # raw numpy array, one per key
""" """
def load(self, path: str): def load(self, path: str, tokenizer=None):
self._mmap_refs = [] self._mmap_refs = []
raw = load_bin(path) raw = load_bin(path)
self._normalize(raw) self._normalize(raw)

View File

@ -3,7 +3,7 @@ import json
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Union from typing import Any, Dict, List, Tuple
import safetensors.torch as st import safetensors.torch as st
import torch import torch
@ -16,50 +16,29 @@ _CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors" _WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: Union[str, Path]): def save_safetensors(state_dict: dict, path: str | Path):
st.save_file(state_dict, str(path)) st.save_file(state_dict, str(path))
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict: def load_safetensors(path: str | Path) -> dict:
if not broadcast or not dist.is_initialized(): return st.load_file(str(path))
return st.load_file(str(path))
rank = get_rank()
if rank == 0:
state_dict = st.load_file(str(path))
else:
state_dict = {}
tmp = [state_dict]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_json(data: dict, path: Union[str, Path]): def save_json(data: dict, path: str | Path):
with open(str(path), "w") as f: with open(str(path), "w") as f:
json.dump(data, f, indent=2) json.dump(data, f, indent=2)
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict: def load_json(path: str | Path) -> dict:
if not broadcast or not dist.is_initialized(): with open(str(path), "r") as f:
with open(str(path), "r") as f: return json.load(f)
return json.load(f)
rank = get_rank()
if rank == 0:
with open(str(path), "r") as f:
data = json.load(f)
else:
data = {}
tmp = [data]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_torch(obj: Any, path: Union[str, Path]): def save_torch(obj: Any, path: str | Path):
torch.save(obj, str(path)) torch.save(obj, str(path))
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any: def load_torch(path: str | Path, broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized(): if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False) return torch.load(str(path), map_location="cpu", weights_only=False)
@ -97,18 +76,28 @@ def load_model_config(save_directory: str) -> dict:
def load_model_weights(save_directory: str) -> dict: def load_model_weights(save_directory: str) -> dict:
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE) return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict: def _get_meta(save_path: Path) -> dict:
path = Path(path) meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
return meta
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized(): if not broadcast or not dist.is_initialized():
return load_safetensors(path) return load_safetensors(save_path / _WEIGHTS_FILE)
rank = get_rank() rank = get_rank()
if rank == 0: if rank == 0:
state_dict = load_safetensors(path) state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
specs = [ specs: List[Tuple[str, List[int], str]] = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1]) (k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict) for k in sorted(state_dict)
] ]
@ -139,7 +128,6 @@ class Checkpoint:
iteration: int = 0 iteration: int = 0
extra: Dict[str, Any] = field(default_factory=dict) extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
config: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str): def save(self, save_dir: str):
save_path = Path(save_dir) save_path = Path(save_dir)
@ -155,7 +143,6 @@ class Checkpoint:
**self.meta, **self.meta,
} }
save_json(meta, save_path / _META_FILE) save_json(meta, save_path / _META_FILE)
save_json(self.config, save_path / _CONFIG_FILE)
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE) save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
for key, value in self.extra.items(): for key, value in self.extra.items():
save_torch(value, save_path / f"{key}.pt") save_torch(value, save_path / f"{key}.pt")
@ -164,9 +151,8 @@ class Checkpoint:
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir) save_path = Path(save_dir)
meta = load_json(save_path / _META_FILE, broadcast) meta = _get_meta(save_path)
config = load_json(save_path / _CONFIG_FILE, broadcast) state_dict = _load_state_dict(save_path, broadcast=broadcast)
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
extra = {} extra = {}
for f in sorted(save_path.iterdir()): for f in sorted(save_path.iterdir()):
@ -178,5 +164,4 @@ class Checkpoint:
epoch=meta.get("epoch", 0), epoch=meta.get("epoch", 0),
iteration=meta.get("iteration", 0), iteration=meta.get("iteration", 0),
extra=extra, extra=extra,
config=config,
) )

View File

@ -137,17 +137,23 @@ class CheckpointCallback(TrainCallback):
save_dir: str, save_dir: str,
interval: int, interval: int,
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
unwrapped = context.executor.unwrap_model(context.model) # All ranks gather state_dict — collective for FSDP, local for DDP
state_dict = unwrapped.state_dict() state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
if get_rank() == 0: if get_rank() == 0:
@ -160,7 +166,7 @@ class CheckpointCallback(TrainCallback):
epoch=context.epoch, epoch=context.epoch,
iteration=context.iteration, iteration=context.iteration,
extra=extra, extra=extra,
config=context.model_config, meta=context.config.to_dict(),
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)

View File

@ -11,7 +11,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_json, load_model_weights from astrai.serialization import Checkpoint, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -24,7 +24,6 @@ class TrainContext:
scheduler: SchedulerProtocol = field(default=None) scheduler: SchedulerProtocol = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None) config: TrainConfig = field(default=None)
model_config: dict = field(default_factory=dict)
executor: BaseExecutor = field(default=None) executor: BaseExecutor = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
@ -63,21 +62,11 @@ class TrainContextBuilder:
model = cfg.model_fn() model = cfg.model_fn()
model = model.to(device=device) model = model.to(device=device)
model_config = {}
if self._resume_dir:
config_path = Path(self._resume_dir) / "config.json"
if config_path.exists():
model_config = load_json(config_path)
if not model_config and hasattr(model, "config"):
model_config = model.config.to_dict()
context = TrainContext( context = TrainContext(
model=model, model=model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
config=cfg, config=cfg,
model_config=model_config,
executor=executor, executor=executor,
) )
@ -86,15 +75,13 @@ class TrainContextBuilder:
if (resume_path / "meta.json").exists(): if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir) checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict state_dict = checkpoint.state_dict
if checkpoint.config:
context.model_config = checkpoint.config
else: else:
checkpoint = None checkpoint = None
state_dict = load_model_weights(self._resume_dir) state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
if checkpoint is not None: if checkpoint is not None:
context.epoch = cfg.start_epoch context.epoch = max(checkpoint.epoch, cfg.start_epoch)
context.iteration = cfg.start_batch context.iteration = max(checkpoint.iteration, cfg.start_batch)
context.checkpoint = checkpoint context.checkpoint = checkpoint
if cfg.lora is not None: if cfg.lora is not None:

View File

@ -8,11 +8,9 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
H5Store, H5Store,
MmapStore,
StoreFactory, StoreFactory,
detect_format, detect_format,
load_bin, load_json,
save_bin,
save_h5, save_h5,
) )
@ -157,6 +155,111 @@ def test_dataset_with_custom_stride(base_test_env):
assert len(dataset) > len(default_stride_dataset) assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
    """SEQ dataset can be built from raw-text JSON when a tokenizer is given"""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_text")
    os.makedirs(data_dir, exist_ok=True)

    payload = {
        "sequence": [
            "hello world this is a test sentence for tokenizer",
            "another sentence with different words and tokens",
            "machine learning is fascinating and powerful",
        ]
    }
    with open(
        os.path.join(data_dir, "seq_data.json"), "w", encoding="utf-8"
    ) as fh:
        json.dump(payload, fh, ensure_ascii=False)

    # No storage_type is passed here, so format selection is left to the loader.
    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=16,
        tokenizer=encode,
    )

    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count > 0
    assert "sequence" in dataset.keys

    first = dataset[0]
    assert "input_ids" in first
    assert "target_ids" in first
    assert first["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
    """SFT dataset can be built from raw-text JSON when a tokenizer is given"""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_sft")
    os.makedirs(data_dir, exist_ok=True)

    texts = [
        "user asks a question about the weather",
        "assistant provides a helpful response to the user",
    ]
    # Both keys are written with the same list of strings.
    payload = {"sequence": texts, "loss_mask": texts}
    with open(
        os.path.join(data_dir, "sft_data.json"), "w", encoding="utf-8"
    ) as fh:
        json.dump(payload, fh, ensure_ascii=False)

    dataset = DatasetFactory.load(
        train_type="sft",
        load_path=data_dir,
        window_size=16,
        tokenizer=encode,
    )

    assert dataset is not None
    assert len(dataset) > 0

    first = dataset[0]
    assert "loss_mask" in first
def test_json_storage_explicit_tokenizer(base_test_env):
    """JSON storage requested explicitly via storage_type, with a tokenizer"""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_explicit")
    os.makedirs(data_dir, exist_ok=True)

    document = "abcdefghijklmnopqrstuvwxyz" * 10
    with open(os.path.join(data_dir, "data.json"), "w", encoding="utf-8") as fh:
        json.dump({"sequence": [document]}, fh, ensure_ascii=False)

    # Reference count: tokenize the single document directly.
    expected_tokens = len(encode(document))

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=32,
        storage_type="json",
        tokenizer=encode,
    )

    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count == expected_tokens
def test_dataset_count_property(base_test_env): def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count""" """Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
@ -231,6 +334,25 @@ def test_store_fetch_begin_equals_end(base_test_env):
assert result.numel() == 0 assert result.numel() == 0
def test_store_empty_data_len(base_test_env):
    """A loaded JSON store has positive length; a never-loaded store has length 0"""
    test_dir = base_test_env["test_dir"]
    data_dir = os.path.join(test_dir, "empty_store")
    os.makedirs(data_dir, exist_ok=True)

    with open(os.path.join(data_dir, "data.json"), "w") as f:
        json.dump({"sequence": [[1, 2, 3]]}, f)

    store = StoreFactory.create("json")
    store.load(data_dir)
    assert len(store) > 0

    # load() is never called on this store, so its length must still be 0.
    empty_store = H5Store()
    assert len(empty_store) == 0
def test_store_fetch_before_load(): def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError""" """Store.fetch before load raises RuntimeError"""
store = H5Store() store = H5Store()
@ -260,6 +382,40 @@ def test_create_store_invalid_type():
StoreFactory.create("parquet") StoreFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
    """JSON holding List[List[int]] loads with no tokenizer supplied"""
    data_dir = os.path.join(base_test_env["test_dir"], "json_pretok")
    os.makedirs(data_dir, exist_ok=True)

    token_rows = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
    with open(os.path.join(data_dir, "data.json"), "w", encoding="utf-8") as fh:
        json.dump({"sequence": token_rows}, fh)

    dataset = DatasetFactory.load(
        "seq", data_dir, window_size=4, storage_type="json"
    )
    assert len(dataset) > 0
    assert dataset.count == 10

    # Targets are the inputs shifted by one position.
    first = dataset[0]
    assert first["input_ids"].tolist() == [1, 2, 3, 4]
    assert first["target_ids"].tolist() == [2, 3, 4, 5]
def test_load_json_skips_config_file(base_test_env):
    """load_json ignores a JSON file that only holds scalar config values"""
    test_dir = base_test_env["test_dir"]

    # A scalar-valued config file next to a real data file, written in this order.
    fixtures = {
        "config.json": {"vocab_size": 1000, "dim": 16},
        "data.json": {"sequence": [[1, 2, 3, 4, 5]]},
    }
    for filename, content in fixtures.items():
        with open(os.path.join(test_dir, filename), "w") as fh:
            json.dump(content, fh)

    loaded = load_json(test_dir)
    assert "sequence" in loaded
    assert "vocab_size" not in loaded
    assert len(loaded["sequence"]) == 1
def test_store_multi_segment_concat(base_test_env): def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time""" """Multi-segment H5 data is concatenated into single tensor at load time"""
import os import os
@ -280,166 +436,3 @@ def test_store_multi_segment_concat(base_test_env):
assert len(store) == 9 assert len(store) == 9
result = store.fetch(2, 7, "sequence") result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7] assert result.tolist() == [3, 4, 5, 6, 7]
def test_save_load_bin_roundtrip(base_test_env):
    """Data written by save_bin comes back unchanged from load_bin"""
    test_dir = base_test_env["test_dir"]
    sequence_values = [1, 2, 3, 4, 5]
    mask_values = [0, 1, 1, 0, 1]

    save_bin(
        test_dir,
        {
            "sequence": [torch.tensor(sequence_values, dtype=torch.int64)],
            "loss_mask": [torch.tensor(mask_values, dtype=torch.int64)],
        },
    )
    restored = load_bin(test_dir)

    assert "sequence" in restored
    assert "loss_mask" in restored
    assert restored["sequence"][0].tolist() == sequence_values
    assert restored["loss_mask"][0].tolist() == mask_values
def test_mmap_store_load_and_fetch(base_test_env):
    """A "bin" store loads save_bin output and fetches the right slice"""
    test_dir = base_test_env["test_dir"]
    tokens = torch.randint(0, 1000, (200,), dtype=torch.int64)
    save_bin(test_dir, {"sequence": [tokens]})

    store = StoreFactory.create("bin")
    store.load(test_dir)
    assert len(store) == 200
    assert "sequence" in store.keys

    # The fetched window must equal the same slice of the source tensor.
    fetched = store.fetch(10, 20, "sequence")
    assert fetched.tolist() == tokens[10:20].tolist()
def test_mmap_dataset_load(base_test_env):
    """DatasetFactory.load handles bin data when no storage_type is passed"""
    test_dir = base_test_env["test_dir"]
    tokens = torch.randint(0, 1000, (200,), dtype=torch.int64)
    save_bin(test_dir, {"sequence": [tokens]})

    dataset = DatasetFactory.load("seq", test_dir, window_size=64)
    assert len(dataset) > 0
    assert dataset.count == 200

    first = dataset[0]
    assert first["input_ids"].shape[0] == 64
def test_normalize_empty_key():
    """_normalize accepts a key whose tensor list is empty"""
    store = H5Store()
    empty_group = {"sequence": []}
    store._normalize(empty_group)

    assert len(store) == 0
    assert store.keys == ["sequence"]
def test_normalize_mixed_empty_key():
    """_normalize with one populated and one empty key yields length 0"""
    store = H5Store()
    mixed_group = {
        "sequence": [torch.tensor([1, 2, 3])],
        "loss_mask": [],
    }
    store._normalize(mixed_group)

    assert len(store) == 0
    assert set(store.keys) == {"sequence", "loss_mask"}
def test_grpo_dataset_dtype(base_test_env):
    """Items from a GRPO dataset carry the expected dtypes"""
    test_dir = base_test_env["test_dir"]
    seq_len = 100
    data = {
        "prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
        "responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
        "masks": [torch.ones(seq_len, dtype=torch.int32)],
        "rewards": [torch.ones(seq_len, dtype=torch.float32)],
    }
    save_h5(test_dir, "grpo_dtype", data)

    first = DatasetFactory.load("grpo", test_dir, window_size=32)[0]

    # Stored as int32/float32; checked against the dtypes of the returned item.
    expected_dtypes = {
        "prompts": torch.long,
        "responses": torch.long,
        "masks": torch.bool,
        "rewards": torch.float32,
    }
    for field_name, dtype in expected_dtypes.items():
        assert first[field_name].dtype == dtype
def test_grpo_dataset_load(base_test_env):
    """A GRPO dataset loads and exposes all four expected keys"""
    test_dir = base_test_env["test_dir"]
    seq_len = 200
    data = {
        "prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
        "responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
        "masks": [torch.ones(seq_len, dtype=torch.int64)],
        "rewards": [torch.rand(seq_len, dtype=torch.float32)],
    }
    save_h5(test_dir, "grpo_test", data)

    dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
    assert len(dataset) > 0

    first = dataset[0]
    for field_name in ("prompts", "responses", "masks", "rewards"):
        assert field_name in first
    # Windowed fields have the requested window length.
    for field_name in ("prompts", "responses"):
        assert first[field_name].shape[0] == 64
def test_detect_format_bin_dir(base_test_env):
    """A directory written by save_bin is detected as 'bin'"""
    test_dir = base_test_env["test_dir"]
    payload = {"sequence": [torch.randint(0, 100, (10,))]}
    save_bin(test_dir, payload)

    detected = detect_format(test_dir)
    assert detected == "bin"
def test_store_fetch_multi_key(base_test_env):
    """Fetching with a list of keys returns a dict with one tensor per key"""
    test_dir = base_test_env["test_dir"]
    data = {
        "sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
        "loss_mask": [torch.ones(100, dtype=torch.int64)],
    }
    save_h5(test_dir, "multi_key", data)

    store = StoreFactory.create("h5")
    store.load(test_dir)

    requested = ["sequence", "loss_mask"]
    fetched = store.fetch(10, 20, requested)
    assert isinstance(fetched, dict)
    # Range 10..20 gives ten elements for each requested key.
    for key in requested:
        assert fetched[key].shape[0] == 10
def test_store_fetch_out_of_bounds(base_test_env):
    """Out-of-range fetch indices are rejected with ValueError"""
    test_dir = base_test_env["test_dir"]
    save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})

    store = StoreFactory.create("h5")
    store.load(test_dir)

    # Negative begin, end past the 50 stored elements, and begin at the length.
    invalid_ranges = [(-1, 10), (0, 51), (50, 50)]
    for begin, end in invalid_ranges:
        with pytest.raises(ValueError, match="out of bounds"):
            store.fetch(begin, end, "sequence")
def test_dataset_load_explicit_storage_type(base_test_env):
    """DatasetFactory.load works when storage_type is given explicitly"""
    test_dir = base_test_env["test_dir"]
    payload = {"sequence": [torch.randint(0, 100, (200,))]}
    save_h5(test_dir, "explicit", payload)

    dataset = DatasetFactory.load(
        "seq", test_dir, window_size=64, storage_type="h5"
    )
    assert len(dataset) > 0
    assert dataset.count == 200