Compare commits
7 Commits
0a708fff24
...
b37c3d000c
| Author | SHA1 | Date |
|---|---|---|
|
|
b37c3d000c | |
|
|
6031020e37 | |
|
|
c424dfc293 | |
|
|
3a28e52e98 | |
|
|
e371908b54 | |
|
|
7c99da155c | |
|
|
629e72385b |
|
|
@ -65,7 +65,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainConfig {
|
class TrainConfig {
|
||||||
+nn.Module model
|
+Callable[[], nn.Module] model_fn
|
||||||
+str strategy
|
+str strategy
|
||||||
+Dataset dataset
|
+Dataset dataset
|
||||||
+Callable optimizer_fn
|
+Callable optimizer_fn
|
||||||
|
|
@ -108,7 +108,7 @@ classDiagram
|
||||||
+int window_size
|
+int window_size
|
||||||
+int stride
|
+int stride
|
||||||
+Optional[Store] storage
|
+Optional[Store] storage
|
||||||
+load(load_path, storage_type, tokenizer)
|
+load(load_path, storage_type)
|
||||||
+__getitem__(index)
|
+__getitem__(index)
|
||||||
+__len__()
|
+__len__()
|
||||||
}
|
}
|
||||||
|
|
@ -134,7 +134,7 @@ classDiagram
|
||||||
+Dict[str, List[int]] _cum
|
+Dict[str, List[int]] _cum
|
||||||
+int _length
|
+int _length
|
||||||
+keys (property)
|
+keys (property)
|
||||||
+load(path, tokenizer)
|
+load(path)
|
||||||
+fetch(begin, end, keys)
|
+fetch(begin, end, keys)
|
||||||
+__len__()
|
+__len__()
|
||||||
-_fetch_key(key, begin, end) Tensor
|
-_fetch_key(key, begin, end) Tensor
|
||||||
|
|
@ -142,16 +142,12 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class H5Store {
|
class H5Store {
|
||||||
+load(path, tokenizer)
|
+load(path)
|
||||||
}
|
|
||||||
|
|
||||||
class JSONStore {
|
|
||||||
+load(path, tokenizer)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class MmapStore {
|
class MmapStore {
|
||||||
+List _mmap_refs
|
+List _mmap_refs
|
||||||
+load(path, tokenizer)
|
+load(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
class ResumableDistributedSampler {
|
||||||
|
|
@ -169,7 +165,7 @@ classDiagram
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
+create(train_type, window_size, stride) BaseDataset
|
||||||
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
|
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -180,8 +176,9 @@ classDiagram
|
||||||
+int iteration
|
+int iteration
|
||||||
+dict extra
|
+dict extra
|
||||||
+dict meta
|
+dict meta
|
||||||
|
+dict config
|
||||||
+save(save_dir)
|
+save(save_dir)
|
||||||
+load(save_dir) Checkpoint
|
+load(save_dir, broadcast) Checkpoint
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -189,8 +186,8 @@ classDiagram
|
||||||
class AutoModel {
|
class AutoModel {
|
||||||
+BaseModelConfig config
|
+BaseModelConfig config
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(model_type) decorator
|
+register(name) decorator
|
||||||
+get_component_class(model_type) Type
|
+get_component_class(name) Type
|
||||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||||
+save_pretrained(save_directory)
|
+save_pretrained(save_directory)
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
|
|
@ -204,7 +201,7 @@ classDiagram
|
||||||
+RMSNorm norm
|
+RMSNorm norm
|
||||||
+Linear lm_head
|
+Linear lm_head
|
||||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
||||||
+load_state_dict(state_dict)
|
+load_state_dict(state_dict, strict, assign)
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -229,6 +226,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class GQA {
|
class GQA {
|
||||||
|
+int dim
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
|
|
@ -243,6 +241,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class MLA {
|
class MLA {
|
||||||
|
+int dim
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
|
|
@ -303,6 +302,7 @@ classDiagram
|
||||||
+int dim
|
+int dim
|
||||||
+int max_len
|
+int max_len
|
||||||
+float base
|
+float base
|
||||||
|
+Optional[Dict] rope_scaling
|
||||||
+forward(x, position_ids=None) Tensor
|
+forward(x, position_ids=None) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -315,10 +315,10 @@ classDiagram
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+vocab_size int
|
+vocab_size int
|
||||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int]
|
||||||
+decode(tokens, skip_special_tokens) str
|
+decode(tokens, skip_special_tokens) str
|
||||||
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
|
||||||
+set_chat_template(template)
|
+set_chat_template(template)
|
||||||
+load(path)
|
+load(path)
|
||||||
+from_pretrained(path) AutoTokenizer
|
+from_pretrained(path) AutoTokenizer
|
||||||
|
|
@ -326,7 +326,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatTemplate {
|
class ChatTemplate {
|
||||||
+String template_str
|
+str template_str
|
||||||
+render(messages, system_prompt, **extra_variables) str
|
+render(messages, system_prompt, **extra_variables) str
|
||||||
+from_string(template) ChatTemplate
|
+from_string(template) ChatTemplate
|
||||||
}
|
}
|
||||||
|
|
@ -364,6 +364,7 @@ classDiagram
|
||||||
+SchedulerProtocol scheduler
|
+SchedulerProtocol scheduler
|
||||||
+Checkpoint checkpoint
|
+Checkpoint checkpoint
|
||||||
+TrainConfig config
|
+TrainConfig config
|
||||||
|
+dict model_config
|
||||||
+BaseExecutor executor
|
+BaseExecutor executor
|
||||||
+int epoch
|
+int epoch
|
||||||
+int iteration
|
+int iteration
|
||||||
|
|
@ -377,7 +378,7 @@ classDiagram
|
||||||
|
|
||||||
class TrainContextBuilder {
|
class TrainContextBuilder {
|
||||||
+TrainConfig config
|
+TrainConfig config
|
||||||
+with_checkpoint(checkpoint) TrainContextBuilder
|
+with_resume_dir(resume_dir) TrainContextBuilder
|
||||||
+build() TrainContext
|
+build() TrainContext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -472,16 +473,12 @@ classDiagram
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
+bool weight_only
|
+bool weight_only
|
||||||
+Callable state_dict_fn
|
|
||||||
+Callable save_extra_fn
|
+Callable save_extra_fn
|
||||||
+Callable load_extra_fn
|
|
||||||
+_save_checkpoint(context)
|
+_save_checkpoint(context)
|
||||||
+on_train_begin(context)
|
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
+on_error(context)
|
+on_error(context)
|
||||||
+save_extra(context)$
|
+save_extra(context)$
|
||||||
+load_extra(extra, context)$
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ProgressBarCallback {
|
class ProgressBarCallback {
|
||||||
|
|
@ -518,7 +515,12 @@ classDiagram
|
||||||
+float lr
|
+float lr
|
||||||
+float momentum
|
+float momentum
|
||||||
+float weight_decay
|
+float weight_decay
|
||||||
|
+bool nesterov
|
||||||
+int ns_steps
|
+int ns_steps
|
||||||
|
+float adamw_lr
|
||||||
|
+tuple adamw_betas
|
||||||
|
+float adamw_eps
|
||||||
|
+float adamw_wd
|
||||||
+step(closure) Optional[float]
|
+step(closure) Optional[float]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -539,6 +541,8 @@ classDiagram
|
||||||
+AutoModel model
|
+AutoModel model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+KVCache page_cache
|
+KVCache page_cache
|
||||||
|
+Optional[str] device
|
||||||
|
+Optional[torch.dtype] dtype
|
||||||
+execute_prefill(tasks, prompt_len, start_pos)
|
+execute_prefill(tasks, prompt_len, start_pos)
|
||||||
+execute_decode(tasks) List[int]
|
+execute_decode(tasks) List[int]
|
||||||
}
|
}
|
||||||
|
|
@ -550,7 +554,9 @@ classDiagram
|
||||||
+bool _running
|
+bool _running
|
||||||
+Thread _loop_thread
|
+Thread _loop_thread
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+str device
|
||||||
|
+torch.dtype dtype
|
||||||
|
+add_task(prompt, **kwargs) str
|
||||||
+remove_task(task_id)
|
+remove_task(task_id)
|
||||||
+start()
|
+start()
|
||||||
+stop()
|
+stop()
|
||||||
|
|
@ -653,15 +659,19 @@ classDiagram
|
||||||
|
|
||||||
class TaskManager {
|
class TaskManager {
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
|
+int max_batch_size
|
||||||
|
+int max_seq_len
|
||||||
|
+int max_prompt_len
|
||||||
+Deque waiting_queue
|
+Deque waiting_queue
|
||||||
+List active_tasks
|
+List active_tasks
|
||||||
+add_task(prompt, **kwargs) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
+remove_task(task_id) List[Task]
|
+remove_task(task_id) List[Task]
|
||||||
+remove_finished_tasks(stop_ids) List[Task]
|
+remove_finished_tasks(stop_ids) List[Task]
|
||||||
+pull_candidates(n) List[Task]
|
+pull_candidates(n) List[Task]
|
||||||
+activate(task)
|
+activate(task)
|
||||||
+return_to_waiting(tasks)
|
+return_to_waiting(tasks)
|
||||||
+get_active_tasks() List[Task]
|
+get_active_tasks() List[Task]
|
||||||
|
+get_stats() Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class GenerationRequest {
|
class GenerationRequest {
|
||||||
|
|
@ -917,7 +927,6 @@ classDiagram
|
||||||
BaseDataset <|-- DPODataset
|
BaseDataset <|-- DPODataset
|
||||||
BaseDataset <|-- GRPODataset
|
BaseDataset <|-- GRPODataset
|
||||||
Store <|-- H5Store
|
Store <|-- H5Store
|
||||||
Store <|-- JSONStore
|
|
||||||
Store <|-- MmapStore
|
Store <|-- MmapStore
|
||||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||||
BaseSamplingStrategy <|-- TopKStrategy
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
|
|
@ -996,7 +1005,6 @@ classDiagram
|
||||||
DecoderBlock ..> AttnFactory : uses
|
DecoderBlock ..> AttnFactory : uses
|
||||||
DecoderBlock ..> FFNFactory : uses
|
DecoderBlock ..> FFNFactory : uses
|
||||||
StoreFactory ..> H5Store : creates
|
StoreFactory ..> H5Store : creates
|
||||||
StoreFactory ..> JSONStore : creates
|
|
||||||
StoreFactory ..> MmapStore : creates
|
StoreFactory ..> MmapStore : creates
|
||||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||||
ConfigFactory ..> EncoderConfig : creates
|
ConfigFactory ..> EncoderConfig : creates
|
||||||
|
|
@ -1063,7 +1071,7 @@ classDiagram
|
||||||
| **Context** | `TrainContext` | Unified training state bag |
|
| **Context** | `TrainContext` | Unified training state bag |
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||||
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
||||||
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||||
|
|
||||||
|
|
@ -1075,10 +1083,10 @@ classDiagram
|
||||||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
||||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
||||||
|
|
||||||
> Document Update Time: 2026-05-24
|
> Document Update Time: 2026-05-28
|
||||||
|
|
|
||||||
|
|
@ -5,22 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
```
|
```
|
||||||
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
|
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference
|
||||||
```
|
```
|
||||||
|
|
||||||
## Data Preparation
|
## Data Preparation
|
||||||
|
|
||||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||||
|
|
||||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||||
|
|
||||||
```
|
```
|
||||||
StoreFactory.create("h5") → H5Store
|
StoreFactory.create("h5") → H5Store
|
||||||
StoreFactory.create("json") → JSONStore
|
|
||||||
StoreFactory.create("bin") → MmapStore
|
StoreFactory.create("bin") → MmapStore
|
||||||
```
|
```
|
||||||
|
|
||||||
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||||
|
|
||||||
## Data Keys by Training Type
|
## Data Keys by Training Type
|
||||||
|
|
||||||
|
|
@ -34,7 +33,7 @@ H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) us
|
||||||
## Dataset Architecture
|
## Dataset Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
|
DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
|
||||||
→ StoreFactory.create(detect_format(path))
|
→ StoreFactory.create(detect_format(path))
|
||||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||||
→ BaseDataset.__getitem__(idx)
|
→ BaseDataset.__getitem__(idx)
|
||||||
|
|
@ -55,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokeniz
|
||||||
|
|
||||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-28
|
||||||
|
|
|
||||||
|
|
@ -16,12 +16,12 @@ Six classes working together:
|
||||||
|
|
||||||
```
|
```
|
||||||
KVCache (facade)
|
KVCache (facade)
|
||||||
├── Allocator bitmask-based page allocator + ref-count + LRU eviction
|
├── PagePool orchestrates page allocation + prefix matching
|
||||||
├── PrefixCache hash-based prefix matching (page_hash via rolling hash)
|
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
|
||||||
├── PagePool orchestrates Allocator + PrefixCache
|
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
|
||||||
├── TaskTable maps task_id → page_table + cached token count
|
├── TaskTable maps task_id → page_table + cached token count
|
||||||
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
||||||
└── KvcacheView bundles Storage + page_table + total_len for attention layers
|
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
|
||||||
```
|
```
|
||||||
|
|
||||||
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
||||||
|
|
@ -40,7 +40,10 @@ KVCache (facade)
|
||||||
## Sampling (Strategy Pattern)
|
## Sampling (Strategy Pattern)
|
||||||
|
|
||||||
```
|
```
|
||||||
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
BaseSamplingStrategy (ABC)
|
||||||
|
├── TemperatureStrategy
|
||||||
|
├── TopKStrategy
|
||||||
|
└── TopPStrategy
|
||||||
```
|
```
|
||||||
|
|
||||||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||||
|
|
@ -50,11 +53,12 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class ProtocolHandler: # concrete orchestrator
|
class ProtocolHandler: # concrete orchestrator
|
||||||
def handle(self, request):
|
def __init__(self, request, engine, builder): ...
|
||||||
|
async def handle(self):
|
||||||
prompt, ctx, stops = builder.prepare(request, engine)
|
prompt, ctx, stops = builder.prepare(request, engine)
|
||||||
agen = engine.generate_async(prompt, ...)
|
agen = engine.generate_async(prompt, ...)
|
||||||
if stream: self._handle_stream(agen, ctx, stops)
|
if stream: self._handle_stream(agen, ctx, stops)
|
||||||
else: self._handle_non_stream(agen, ctx, stops)
|
else: return await self._handle_non_stream(agen, ctx, stops)
|
||||||
```
|
```
|
||||||
|
|
||||||
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||||
|
|
@ -96,12 +100,14 @@ Response:
|
||||||
{
|
{
|
||||||
"id": "chatcmpl-abc123",
|
"id": "chatcmpl-abc123",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
"choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
"created": 1717000000,
|
||||||
|
"model": "astrai",
|
||||||
|
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
||||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]`
|
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
|
||||||
|
|
||||||
### Anthropic
|
### Anthropic
|
||||||
|
|
||||||
|
|
@ -121,7 +127,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
||||||
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
|
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
|
||||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||||
| `top_k` | int | 50 | Top-k count |
|
| `top_k` | int | 50 | Top-k count |
|
||||||
| `max_tokens` | int | None | Max generation length |
|
| `max_tokens` | Optional[int] | None | Max generation length |
|
||||||
| `stream` | bool | False | Stream output |
|
| `stream` | bool | False | Stream output |
|
||||||
|
|
||||||
## Engine API
|
## Engine API
|
||||||
|
|
@ -139,4 +145,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||||
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
||||||
```
|
```
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-28
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,8 @@ on_train_begin
|
||||||
on_batch_begin
|
on_batch_begin
|
||||||
with executor.accumulate(model):
|
with executor.accumulate(model):
|
||||||
loss = strategy(batch)
|
loss = strategy(batch)
|
||||||
(loss / grad_accum_steps).backward()
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
|
executor.backward(stand_loss)
|
||||||
iteration += 1
|
iteration += 1
|
||||||
on_batch_end
|
on_batch_end
|
||||||
|
|
||||||
|
|
@ -82,6 +83,7 @@ on_train_begin
|
||||||
on_optimizer_step
|
on_optimizer_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
if scheduler:
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
|
|
@ -169,20 +171,20 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
|
||||||
## Checkpoint
|
## Checkpoint
|
||||||
|
|
||||||
```
|
```
|
||||||
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
||||||
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt
|
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
||||||
└── load(save_dir) broadcasts metadata from rank-0
|
└── load(save_dir) broadcasts metadata from rank-0
|
||||||
```
|
```
|
||||||
|
|
||||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||||
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
|
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
|
||||||
|
|
||||||
## TrainContextBuilder (Builder Pattern)
|
## TrainContextBuilder (Builder Pattern)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
context = (
|
context = (
|
||||||
TrainContextBuilder(config)
|
TrainContextBuilder(config)
|
||||||
.with_checkpoint(checkpoint)
|
.with_resume_dir(resume_dir)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
||||||
|
|
@ -222,4 +224,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-24
|
> Document Update Time: 2026-05-28
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,21 @@ class BaseConfig:
|
||||||
d[fld.name] = v
|
d[fld.name] = v
|
||||||
elif v is None:
|
elif v is None:
|
||||||
d[fld.name] = None
|
d[fld.name] = None
|
||||||
elif isinstance(v, (dict, list)):
|
elif isinstance(v, (dict, list, tuple)):
|
||||||
try:
|
try:
|
||||||
json.dumps(v)
|
val = list(v) if isinstance(v, tuple) else v
|
||||||
d[fld.name] = v
|
json.dumps(val)
|
||||||
|
d[fld.name] = val
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
elif isinstance(v, BaseConfig):
|
||||||
|
d[fld.name] = v.to_dict()
|
||||||
|
elif hasattr(v, "__dataclass_fields__"):
|
||||||
|
sub = {}
|
||||||
|
for f in fields(v):
|
||||||
|
a = getattr(v, f.name)
|
||||||
|
sub[f.name] = list(a) if isinstance(a, tuple) else a
|
||||||
|
d[fld.name] = sub
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -5,18 +5,14 @@ from astrai.dataset.dataset import (
|
||||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
H5Store,
|
H5Store,
|
||||||
JSONStore,
|
|
||||||
MmapStore,
|
MmapStore,
|
||||||
Store,
|
Store,
|
||||||
StoreFactory,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
json_to_bin,
|
|
||||||
load_bin,
|
load_bin,
|
||||||
load_h5,
|
load_h5,
|
||||||
load_json,
|
|
||||||
save_bin,
|
save_bin,
|
||||||
save_h5,
|
save_h5,
|
||||||
save_json,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -25,15 +21,11 @@ __all__ = [
|
||||||
"Store",
|
"Store",
|
||||||
"StoreFactory",
|
"StoreFactory",
|
||||||
"H5Store",
|
"H5Store",
|
||||||
"JSONStore",
|
|
||||||
"MmapStore",
|
"MmapStore",
|
||||||
"detect_format",
|
"detect_format",
|
||||||
"save_h5",
|
"save_h5",
|
||||||
"load_h5",
|
"load_h5",
|
||||||
"save_json",
|
|
||||||
"load_json",
|
|
||||||
"save_bin",
|
"save_bin",
|
||||||
"load_bin",
|
"load_bin",
|
||||||
"json_to_bin",
|
|
||||||
"ResumableDistributedSampler",
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -48,17 +48,15 @@ class BaseDataset(Dataset, ABC):
|
||||||
f"Missing: {missing}"
|
f"Missing: {missing}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
def load(self, load_path: str, storage_type: Optional[str] = None):
|
||||||
"""Load dataset from the given path.
|
"""Load dataset from the given path.
|
||||||
|
|
||||||
Auto-detects the storage format if not specified.
|
Auto-detects the storage format if not specified.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
load_path: Path to the data directory or file
|
load_path: Path to the data directory or file
|
||||||
storage_type: Force a specific storage type ("h5", "json"),
|
storage_type: Force a specific storage type ("h5", "bin"),
|
||||||
or None for auto-detection
|
or None for auto-detection
|
||||||
tokenizer: Callable str -> List[int], used to tokenize raw text
|
|
||||||
in JSON files. Ignored for HDF5.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyError: If the loaded storage is missing required keys.
|
KeyError: If the loaded storage is missing required keys.
|
||||||
|
|
@ -67,18 +65,9 @@ class BaseDataset(Dataset, ABC):
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = StoreFactory.create(storage_type)
|
self.storage = StoreFactory.create(storage_type)
|
||||||
self._load_path = load_path
|
self._load_path = load_path
|
||||||
self.storage.load(load_path, tokenizer=tokenizer)
|
self.storage.load(load_path)
|
||||||
self._validate_keys()
|
self._validate_keys()
|
||||||
|
|
||||||
def load_json(self, load_path: str, tokenizer=None):
|
|
||||||
"""Load dataset from JSON files explicitly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
load_path: Path to the JSON data file or directory
|
|
||||||
tokenizer: Optional tokenizer callable for raw text JSON.
|
|
||||||
"""
|
|
||||||
self.load(load_path, storage_type="json", tokenizer=tokenizer)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
"""Return the total number of raw elements (tokens) in the dataset."""
|
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||||
|
|
@ -175,7 +164,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: Optional[int] = None,
|
stride: Optional[int] = None,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
tokenizer=None,
|
|
||||||
) -> "BaseDataset":
|
) -> "BaseDataset":
|
||||||
"""Create and load a dataset in one step.
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
|
|
@ -184,8 +172,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
load_path: Path to the data file
|
load_path: Path to the data file
|
||||||
window_size: Window size for data sampling
|
window_size: Window size for data sampling
|
||||||
stride: Stride between consecutive samples (default: same as window_size)
|
stride: Stride between consecutive samples (default: same as window_size)
|
||||||
storage_type: Storage type ("h5", "json") or None for auto-detection
|
storage_type: Storage type ("h5", "bin") or None for auto-detection
|
||||||
tokenizer: Callable str -> List[int] for raw text JSON tokenization
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded dataset instance
|
Loaded dataset instance
|
||||||
|
|
@ -194,7 +181,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
stride = window_size
|
stride = window_size
|
||||||
|
|
||||||
dataset = cls.create(train_type, window_size, stride)
|
dataset = cls.create(train_type, window_size, stride)
|
||||||
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
|
dataset.load(load_path, storage_type=storage_type)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
@ -306,9 +293,11 @@ class GRPODataset(BaseDataset):
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
prompts = self._fetch_data(begin_idx, end_idx, "prompts")
|
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
|
||||||
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
|
||||||
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""Storage backends for different data formats.
|
"""Storage backends for different data formats.
|
||||||
|
|
||||||
Layers:
|
Layers:
|
||||||
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
|
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
|
||||||
return Dict[str, List[Tensor]] — format-specific, no state
|
return Dict[str, List[Tensor]] — format-specific, no state
|
||||||
- Store (ABC): central abstraction, normalizes multi-segment into
|
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||||
Dict[str, List[Tensor]] per key via _normalize(),
|
Dict[str, List[Tensor]] per key via _normalize(),
|
||||||
|
|
@ -22,7 +22,7 @@ import json
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -68,56 +68,6 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
return tensor_group
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
|
||||||
os.makedirs(file_path, exist_ok=True)
|
|
||||||
full_file_path = os.path.join(file_path, f"{file_name}.json")
|
|
||||||
json_data = {}
|
|
||||||
for key, tensors in tensor_group.items():
|
|
||||||
json_data[key] = [tensor.tolist() for tensor in tensors]
|
|
||||||
with open(full_file_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(json_data, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
def load_json(
|
|
||||||
file_path: str,
|
|
||||||
share_memory: bool = True,
|
|
||||||
tokenizer: Optional[Callable[[str], List[int]]] = None,
|
|
||||||
) -> Dict[str, List[Tensor]]:
|
|
||||||
"""Load tensor data from JSON files.
|
|
||||||
|
|
||||||
Supports two modes:
|
|
||||||
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
|
|
||||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
|
|
||||||
at load time. A ``tokenizer`` receives a str and returns List[int].
|
|
||||||
|
|
||||||
Non-data JSON files (e.g. config.json) with scalar/object values are
|
|
||||||
silently skipped.
|
|
||||||
"""
|
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
|
||||||
root_path = Path(file_path)
|
|
||||||
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
|
|
||||||
for json_file in json_files:
|
|
||||||
with open(json_file, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
continue
|
|
||||||
for key, sequences in data.items():
|
|
||||||
if not isinstance(sequences, list):
|
|
||||||
continue
|
|
||||||
tensors = []
|
|
||||||
for seq in sequences:
|
|
||||||
if tokenizer is not None and isinstance(seq, str):
|
|
||||||
seq = tokenizer(seq)
|
|
||||||
tensor = torch.tensor(seq, dtype=torch.long)
|
|
||||||
if share_memory:
|
|
||||||
tensor = tensor.share_memory_()
|
|
||||||
tensors.append(tensor)
|
|
||||||
if tensor_group.get(key) is None:
|
|
||||||
tensor_group[key] = []
|
|
||||||
tensor_group[key].extend(tensors)
|
|
||||||
return tensor_group
|
|
||||||
|
|
||||||
|
|
||||||
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
meta = {}
|
meta = {}
|
||||||
|
|
@ -125,31 +75,25 @@ def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
cat = torch.cat(tensors, dim=0)
|
cat = torch.cat(tensors, dim=0)
|
||||||
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
||||||
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
||||||
save_json(meta, os.path.join(file_path, "meta.json"))
|
with open(os.path.join(file_path, "meta.json"), "w") as f:
|
||||||
|
json.dump(meta, f)
|
||||||
|
|
||||||
|
|
||||||
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||||
meta = load_json(os.path.join(file_path, "meta.json"))
|
with open(os.path.join(file_path, "meta.json"), "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
segments: Dict[str, List[Tensor]] = {}
|
segments: Dict[str, List[Tensor]] = {}
|
||||||
for key, info in meta.items():
|
for key, info in meta.items():
|
||||||
arr = np.memmap(
|
arr = np.memmap(
|
||||||
os.path.join(file_path, f"{key}.bin"),
|
os.path.join(file_path, f"{key}.bin"),
|
||||||
dtype=info["dtype"],
|
dtype=info["dtype"],
|
||||||
mode="r",
|
mode="r+",
|
||||||
shape=tuple(info["shape"]),
|
shape=tuple(info["shape"]),
|
||||||
)
|
)
|
||||||
segments[key] = [torch.from_numpy(arr)]
|
segments[key] = [torch.from_numpy(arr)]
|
||||||
return segments
|
return segments
|
||||||
|
|
||||||
|
|
||||||
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
|
|
||||||
segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
|
|
||||||
merged = {}
|
|
||||||
for key, tensors in segments.items():
|
|
||||||
merged[key] = [torch.cat(tensors, dim=0)]
|
|
||||||
save_bin(bin_path, merged)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_format(load_path: str) -> str:
|
def detect_format(load_path: str) -> str:
|
||||||
"""Auto-detect storage format from files in the directory.
|
"""Auto-detect storage format from files in the directory.
|
||||||
|
|
||||||
|
|
@ -157,7 +101,7 @@ def detect_format(load_path: str) -> str:
|
||||||
load_path: Directory or file path
|
load_path: Directory or file path
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Format string ("h5", "bin", or "json")
|
Format string ("h5" or "bin")
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: If no supported data files are found
|
FileNotFoundError: If no supported data files are found
|
||||||
|
|
@ -167,8 +111,6 @@ def detect_format(load_path: str) -> str:
|
||||||
suffix = root.suffix.lower()
|
suffix = root.suffix.lower()
|
||||||
if suffix in (".h5", ".hdf5"):
|
if suffix in (".h5", ".hdf5"):
|
||||||
return "h5"
|
return "h5"
|
||||||
if suffix in (".json", ".jsonl"):
|
|
||||||
return "json"
|
|
||||||
raise ValueError(f"Unsupported file format: {suffix}")
|
raise ValueError(f"Unsupported file format: {suffix}")
|
||||||
|
|
||||||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||||
|
|
@ -177,9 +119,6 @@ def detect_format(load_path: str) -> str:
|
||||||
bin_files = list(root.rglob("*.bin"))
|
bin_files = list(root.rglob("*.bin"))
|
||||||
if bin_files and (root / "meta.json").exists():
|
if bin_files and (root / "meta.json").exists():
|
||||||
return "bin"
|
return "bin"
|
||||||
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
|
||||||
if json_files:
|
|
||||||
return "json"
|
|
||||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -200,7 +139,7 @@ class Store(ABC):
|
||||||
self._length: int = 0
|
self._length: int = 0
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, path: str, tokenizer=None) -> None:
|
def load(self, path: str) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -257,7 +196,11 @@ class Store(ABC):
|
||||||
total += t.shape[0]
|
total += t.shape[0]
|
||||||
cum.append(total)
|
cum.append(total)
|
||||||
self._cum[key] = cum
|
self._cum[key] = cum
|
||||||
self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
|
self._length = (
|
||||||
|
min((cum[-1] if cum else 0) for cum in self._cum.values())
|
||||||
|
if self._cum
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StoreFactory(BaseFactory["Store"]):
|
class StoreFactory(BaseFactory["Store"]):
|
||||||
|
|
@ -280,24 +223,10 @@ class StoreFactory(BaseFactory["Store"]):
|
||||||
class H5Store(Store):
|
class H5Store(Store):
|
||||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
def load(self, path: str, tokenizer=None):
|
def load(self, path: str):
|
||||||
self._normalize(load_h5(path))
|
self._normalize(load_h5(path))
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("json")
|
|
||||||
class JSONStore(Store):
|
|
||||||
"""JSON-based storage backend.
|
|
||||||
|
|
||||||
Supports two modes:
|
|
||||||
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
|
|
||||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
|
|
||||||
callable (str -> List[int]) at load time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def load(self, path: str, tokenizer=None):
|
|
||||||
self._normalize(load_json(path, tokenizer=tokenizer))
|
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("bin")
|
@StoreFactory.register("bin")
|
||||||
class MmapStore(Store):
|
class MmapStore(Store):
|
||||||
"""Memory-mapped binary storage backend.
|
"""Memory-mapped binary storage backend.
|
||||||
|
|
@ -313,7 +242,7 @@ class MmapStore(Store):
|
||||||
<key>.bin # raw numpy array, one per key
|
<key>.bin # raw numpy array, one per key
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load(self, path: str, tokenizer=None):
|
def load(self, path: str):
|
||||||
self._mmap_refs = []
|
self._mmap_refs = []
|
||||||
raw = load_bin(path)
|
raw = load_bin(path)
|
||||||
self._normalize(raw)
|
self._normalize(raw)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import json
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -16,29 +16,50 @@ _CONFIG_FILE = "config.json"
|
||||||
_WEIGHTS_FILE = "model.safetensors"
|
_WEIGHTS_FILE = "model.safetensors"
|
||||||
|
|
||||||
|
|
||||||
def save_safetensors(state_dict: dict, path: str | Path):
|
def save_safetensors(state_dict: dict, path: Union[str, Path]):
|
||||||
st.save_file(state_dict, str(path))
|
st.save_file(state_dict, str(path))
|
||||||
|
|
||||||
|
|
||||||
def load_safetensors(path: str | Path) -> dict:
|
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
return st.load_file(str(path))
|
return st.load_file(str(path))
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
state_dict = st.load_file(str(path))
|
||||||
|
else:
|
||||||
|
state_dict = {}
|
||||||
|
tmp = [state_dict]
|
||||||
|
dist.broadcast_object_list(tmp, src=0)
|
||||||
|
return tmp[0]
|
||||||
|
|
||||||
def save_json(data: dict, path: str | Path):
|
|
||||||
|
def save_json(data: dict, path: Union[str, Path]):
|
||||||
with open(str(path), "w") as f:
|
with open(str(path), "w") as f:
|
||||||
json.dump(data, f, indent=2)
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
def load_json(path: str | Path) -> dict:
|
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
with open(str(path), "r") as f:
|
with open(str(path), "r") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
with open(str(path), "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
else:
|
||||||
|
data = {}
|
||||||
|
tmp = [data]
|
||||||
|
dist.broadcast_object_list(tmp, src=0)
|
||||||
|
return tmp[0]
|
||||||
|
|
||||||
def save_torch(obj: Any, path: str | Path):
|
|
||||||
|
def save_torch(obj: Any, path: Union[str, Path]):
|
||||||
torch.save(obj, str(path))
|
torch.save(obj, str(path))
|
||||||
|
|
||||||
|
|
||||||
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
|
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
|
||||||
if not broadcast or not dist.is_initialized():
|
if not broadcast or not dist.is_initialized():
|
||||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
|
@ -76,28 +97,18 @@ def load_model_config(save_directory: str) -> dict:
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(save_directory: str) -> dict:
|
def load_model_weights(save_directory: str) -> dict:
|
||||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
def _get_meta(save_path: Path) -> dict:
|
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
meta = {}
|
path = Path(path)
|
||||||
if get_rank() == 0:
|
|
||||||
meta = load_json(save_path / _META_FILE)
|
|
||||||
if dist.is_initialized():
|
|
||||||
meta_list = [meta]
|
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
|
||||||
meta = meta_list[0]
|
|
||||||
return meta
|
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
|
|
||||||
if not broadcast or not dist.is_initialized():
|
if not broadcast or not dist.is_initialized():
|
||||||
return load_safetensors(save_path / _WEIGHTS_FILE)
|
return load_safetensors(path)
|
||||||
|
|
||||||
rank = get_rank()
|
rank = get_rank()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
state_dict = load_safetensors(path)
|
||||||
specs: List[Tuple[str, List[int], str]] = [
|
specs = [
|
||||||
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
||||||
for k in sorted(state_dict)
|
for k in sorted(state_dict)
|
||||||
]
|
]
|
||||||
|
|
@ -128,6 +139,7 @@ class Checkpoint:
|
||||||
iteration: int = 0
|
iteration: int = 0
|
||||||
extra: Dict[str, Any] = field(default_factory=dict)
|
extra: Dict[str, Any] = field(default_factory=dict)
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
config: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
def save(self, save_dir: str):
|
def save(self, save_dir: str):
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
@ -143,6 +155,7 @@ class Checkpoint:
|
||||||
**self.meta,
|
**self.meta,
|
||||||
}
|
}
|
||||||
save_json(meta, save_path / _META_FILE)
|
save_json(meta, save_path / _META_FILE)
|
||||||
|
save_json(self.config, save_path / _CONFIG_FILE)
|
||||||
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
||||||
for key, value in self.extra.items():
|
for key, value in self.extra.items():
|
||||||
save_torch(value, save_path / f"{key}.pt")
|
save_torch(value, save_path / f"{key}.pt")
|
||||||
|
|
@ -151,8 +164,9 @@ class Checkpoint:
|
||||||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = _get_meta(save_path)
|
meta = load_json(save_path / _META_FILE, broadcast)
|
||||||
state_dict = _load_state_dict(save_path, broadcast=broadcast)
|
config = load_json(save_path / _CONFIG_FILE, broadcast)
|
||||||
|
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
|
||||||
|
|
||||||
extra = {}
|
extra = {}
|
||||||
for f in sorted(save_path.iterdir()):
|
for f in sorted(save_path.iterdir()):
|
||||||
|
|
@ -164,4 +178,5 @@ class Checkpoint:
|
||||||
epoch=meta.get("epoch", 0),
|
epoch=meta.get("epoch", 0),
|
||||||
iteration=meta.get("iteration", 0),
|
iteration=meta.get("iteration", 0),
|
||||||
extra=extra,
|
extra=extra,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -137,23 +137,17 @@ class CheckpointCallback(TrainCallback):
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
|
||||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
|
||||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
# All ranks gather state_dict — collective for FSDP, local for DDP
|
unwrapped = context.executor.unwrap_model(context.model)
|
||||||
state_dict = (
|
state_dict = unwrapped.state_dict()
|
||||||
self.state_dict_fn(context.model)
|
|
||||||
if self.state_dict_fn
|
|
||||||
else context.model.state_dict()
|
|
||||||
)
|
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
if get_rank() == 0:
|
if get_rank() == 0:
|
||||||
|
|
@ -166,7 +160,7 @@ class CheckpointCallback(TrainCallback):
|
||||||
epoch=context.epoch,
|
epoch=context.epoch,
|
||||||
iteration=context.iteration,
|
iteration=context.iteration,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
meta=context.config.to_dict(),
|
config=context.model_config,
|
||||||
)
|
)
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from astrai.model.components.lora import inject_lora
|
||||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||||
from astrai.serialization import Checkpoint, load_model_weights
|
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -24,6 +24,7 @@ class TrainContext:
|
||||||
scheduler: SchedulerProtocol = field(default=None)
|
scheduler: SchedulerProtocol = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
config: TrainConfig = field(default=None)
|
config: TrainConfig = field(default=None)
|
||||||
|
model_config: dict = field(default_factory=dict)
|
||||||
executor: BaseExecutor = field(default=None)
|
executor: BaseExecutor = field(default=None)
|
||||||
|
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
|
|
@ -62,11 +63,21 @@ class TrainContextBuilder:
|
||||||
model = cfg.model_fn()
|
model = cfg.model_fn()
|
||||||
model = model.to(device=device)
|
model = model.to(device=device)
|
||||||
|
|
||||||
|
model_config = {}
|
||||||
|
if self._resume_dir:
|
||||||
|
config_path = Path(self._resume_dir) / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
model_config = load_json(config_path)
|
||||||
|
|
||||||
|
if not model_config and hasattr(model, "config"):
|
||||||
|
model_config = model.config.to_dict()
|
||||||
|
|
||||||
context = TrainContext(
|
context = TrainContext(
|
||||||
model=model,
|
model=model,
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
config=cfg,
|
config=cfg,
|
||||||
|
model_config=model_config,
|
||||||
executor=executor,
|
executor=executor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -75,13 +86,15 @@ class TrainContextBuilder:
|
||||||
if (resume_path / "meta.json").exists():
|
if (resume_path / "meta.json").exists():
|
||||||
checkpoint = Checkpoint.load(self._resume_dir)
|
checkpoint = Checkpoint.load(self._resume_dir)
|
||||||
state_dict = checkpoint.state_dict
|
state_dict = checkpoint.state_dict
|
||||||
|
if checkpoint.config:
|
||||||
|
context.model_config = checkpoint.config
|
||||||
else:
|
else:
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
state_dict = load_model_weights(self._resume_dir)
|
state_dict = load_model_weights(self._resume_dir)
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
|
context.epoch = cfg.start_epoch
|
||||||
context.iteration = max(checkpoint.iteration, cfg.start_batch)
|
context.iteration = cfg.start_batch
|
||||||
context.checkpoint = checkpoint
|
context.checkpoint = checkpoint
|
||||||
|
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,11 @@ import torch
|
||||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
H5Store,
|
H5Store,
|
||||||
|
MmapStore,
|
||||||
StoreFactory,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
load_json,
|
load_bin,
|
||||||
|
save_bin,
|
||||||
save_h5,
|
save_h5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -155,111 +157,6 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
assert len(dataset) > len(default_stride_dataset)
|
assert len(dataset) > len(default_stride_dataset)
|
||||||
|
|
||||||
|
|
||||||
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tokenizer_fn(tokenizer):
|
|
||||||
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
|
|
||||||
return lambda text: tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_seq_dataset_from_json_text(base_test_env):
|
|
||||||
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_text")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = [
|
|
||||||
"hello world this is a test sentence for tokenizer",
|
|
||||||
"another sentence with different words and tokens",
|
|
||||||
"machine learning is fascinating and powerful",
|
|
||||||
]
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "seq_data.json")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="seq",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=16,
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count > 0
|
|
||||||
assert "sequence" in dataset.keys
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert "input_ids" in item
|
|
||||||
assert "target_ids" in item
|
|
||||||
assert item["input_ids"].shape[0] == 16
|
|
||||||
|
|
||||||
|
|
||||||
def test_sft_dataset_from_json_text(base_test_env):
|
|
||||||
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_sft")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = [
|
|
||||||
"user asks a question about the weather",
|
|
||||||
"assistant provides a helpful response to the user",
|
|
||||||
]
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "sft_data.json")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(
|
|
||||||
{"sequence": texts, "loss_mask": texts},
|
|
||||||
f,
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="sft",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=16,
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert "loss_mask" in item
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_storage_explicit_tokenizer(base_test_env):
|
|
||||||
"""Test explicit JSON storage with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_explicit")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "data.json")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
token_count = len(tokenizer_fn(texts[0]))
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="seq",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=32,
|
|
||||||
storage_type="json",
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count == token_count
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_count_property(base_test_env):
|
def test_dataset_count_property(base_test_env):
|
||||||
"""Test the count property returns correct raw token count"""
|
"""Test the count property returns correct raw token count"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
@ -334,25 +231,6 @@ def test_store_fetch_begin_equals_end(base_test_env):
|
||||||
assert result.numel() == 0
|
assert result.numel() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_store_empty_data_len(base_test_env):
|
|
||||||
"""Store loaded with empty data has __len__ == 0"""
|
|
||||||
import os
|
|
||||||
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "empty_store")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
with open(os.path.join(data_dir, "data.json"), "w") as f:
|
|
||||||
json.dump({"sequence": [[1, 2, 3]]}, f)
|
|
||||||
|
|
||||||
store = StoreFactory.create("json")
|
|
||||||
store.load(data_dir)
|
|
||||||
assert len(store) > 0
|
|
||||||
|
|
||||||
empty_store = H5Store()
|
|
||||||
assert len(empty_store) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_fetch_before_load():
|
def test_store_fetch_before_load():
|
||||||
"""Store.fetch before load raises RuntimeError"""
|
"""Store.fetch before load raises RuntimeError"""
|
||||||
store = H5Store()
|
store = H5Store()
|
||||||
|
|
@ -382,40 +260,6 @@ def test_create_store_invalid_type():
|
||||||
StoreFactory.create("parquet")
|
StoreFactory.create("parquet")
|
||||||
|
|
||||||
|
|
||||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
|
||||||
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_pretok")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "data.json")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count == 10
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert item["input_ids"].tolist() == [1, 2, 3, 4]
|
|
||||||
assert item["target_ids"].tolist() == [2, 3, 4, 5]
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_json_skips_config_file(base_test_env):
|
|
||||||
"""load_json skips scalar-value config files"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
with open(os.path.join(test_dir, "config.json"), "w") as f:
|
|
||||||
json.dump({"vocab_size": 1000, "dim": 16}, f)
|
|
||||||
|
|
||||||
with open(os.path.join(test_dir, "data.json"), "w") as f:
|
|
||||||
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
|
|
||||||
|
|
||||||
result = load_json(test_dir)
|
|
||||||
assert "sequence" in result
|
|
||||||
assert "vocab_size" not in result
|
|
||||||
assert len(result["sequence"]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_multi_segment_concat(base_test_env):
|
def test_store_multi_segment_concat(base_test_env):
|
||||||
"""Multi-segment H5 data is concatenated into single tensor at load time"""
|
"""Multi-segment H5 data is concatenated into single tensor at load time"""
|
||||||
import os
|
import os
|
||||||
|
|
@ -436,3 +280,166 @@ def test_store_multi_segment_concat(base_test_env):
|
||||||
assert len(store) == 9
|
assert len(store) == 9
|
||||||
result = store.fetch(2, 7, "sequence")
|
result = store.fetch(2, 7, "sequence")
|
||||||
assert result.tolist() == [3, 4, 5, 6, 7]
|
assert result.tolist() == [3, 4, 5, 6, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_bin_roundtrip(base_test_env):
|
||||||
|
"""save_bin + load_bin roundtrip preserves data"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
|
||||||
|
"loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
result = load_bin(test_dir)
|
||||||
|
|
||||||
|
assert "sequence" in result
|
||||||
|
assert "loss_mask" in result
|
||||||
|
assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
|
||||||
|
assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmap_store_load_and_fetch(base_test_env):
|
||||||
|
"""MmapStore loads bin data and fetches correctly"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
|
||||||
|
store = StoreFactory.create("bin")
|
||||||
|
store.load(test_dir)
|
||||||
|
assert len(store) == 200
|
||||||
|
assert "sequence" in store.keys
|
||||||
|
|
||||||
|
result = store.fetch(10, 20, "sequence")
|
||||||
|
assert result.tolist() == data["sequence"][0][10:20].tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmap_dataset_load(base_test_env):
|
||||||
|
"""DatasetFactory.load auto-detects bin format"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count == 200
|
||||||
|
assert dataset[0]["input_ids"].shape[0] == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_empty_key():
|
||||||
|
"""_normalize with empty tensor list does not crash"""
|
||||||
|
store = H5Store()
|
||||||
|
store._normalize({"sequence": []})
|
||||||
|
assert len(store) == 0
|
||||||
|
assert store.keys == ["sequence"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_mixed_empty_key():
|
||||||
|
"""_normalize with empty + non-empty keys returns min=0"""
|
||||||
|
store = H5Store()
|
||||||
|
store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
|
||||||
|
assert len(store) == 0
|
||||||
|
assert set(store.keys) == {"sequence", "loss_mask"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_dataset_dtype(base_test_env):
|
||||||
|
"""GRPODataset returns correct dtypes"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
seq_len = 100
|
||||||
|
data = {
|
||||||
|
"prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
||||||
|
"responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
||||||
|
"masks": [torch.ones(seq_len, dtype=torch.int32)],
|
||||||
|
"rewards": [torch.ones(seq_len, dtype=torch.float32)],
|
||||||
|
}
|
||||||
|
save_h5(test_dir, "grpo_dtype", data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
|
||||||
|
item = dataset[0]
|
||||||
|
|
||||||
|
assert item["prompts"].dtype == torch.long
|
||||||
|
assert item["responses"].dtype == torch.long
|
||||||
|
assert item["masks"].dtype == torch.bool
|
||||||
|
assert item["rewards"].dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_dataset_load(base_test_env):
|
||||||
|
"""GRPODataset loads and returns correct keys"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
seq_len = 200
|
||||||
|
data = {
|
||||||
|
"prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
||||||
|
"responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
||||||
|
"masks": [torch.ones(seq_len, dtype=torch.int64)],
|
||||||
|
"rewards": [torch.rand(seq_len, dtype=torch.float32)],
|
||||||
|
}
|
||||||
|
save_h5(test_dir, "grpo_test", data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
item = dataset[0]
|
||||||
|
assert "prompts" in item
|
||||||
|
assert "responses" in item
|
||||||
|
assert "masks" in item
|
||||||
|
assert "rewards" in item
|
||||||
|
assert item["prompts"].shape[0] == 64
|
||||||
|
assert item["responses"].shape[0] == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_format_bin_dir(base_test_env):
|
||||||
|
"""detect_format returns 'bin' for directory with .bin + meta.json"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
|
||||||
|
assert detect_format(test_dir) == "bin"
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_fetch_multi_key(base_test_env):
|
||||||
|
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(
|
||||||
|
test_dir,
|
||||||
|
"multi_key",
|
||||||
|
{
|
||||||
|
"sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
|
||||||
|
"loss_mask": [torch.ones(100, dtype=torch.int64)],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
store = StoreFactory.create("h5")
|
||||||
|
store.load(test_dir)
|
||||||
|
result = store.fetch(10, 20, ["sequence", "loss_mask"])
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["sequence"].shape[0] == 10
|
||||||
|
assert result["loss_mask"].shape[0] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_fetch_out_of_bounds(base_test_env):
|
||||||
|
"""Store.fetch raises ValueError for out-of-bounds indices"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
|
||||||
|
|
||||||
|
store = StoreFactory.create("h5")
|
||||||
|
store.load(test_dir)
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(-1, 10, "sequence")
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(0, 51, "sequence")
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(50, 50, "sequence")
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_load_explicit_storage_type(base_test_env):
|
||||||
|
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count == 200
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue