docs : 更新架构文档与 storage 注释,同步 Store 重构

- architecture.md: 类图/关系线全部更新 (BaseStorage→Store, StorageFactory→StoreFactory, 新增 MmapStore)
- architecture.md: 移除 BaseSegmentFetcher/MultiSegmentFetcher 类图与关系
- dataflow.md: 管线加入 .bin 格式, Store._data + _cum 架构
- storage.py: module docstring 改用缩进式注释风格
This commit is contained in:
ViperEkura 2026-05-28 14:36:18 +08:00
parent 6e150ea6d0
commit 0a708fff24
3 changed files with 50 additions and 73 deletions

View File

@ -107,7 +107,7 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+Optional[BaseStorage] storage
+Optional[Store] storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
@ -129,38 +129,29 @@ classDiagram
+__getitem__(index) Dict
}
class BaseSegmentFetcher {
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
class Store {
+Dict[str, List[Tensor]] _data
+Dict[str, List[int]] _cum
+int _length
+keys (property)
+load(load_path, tokenizer)
+load(path, tokenizer)
+fetch(begin, end, keys)
+__len__()
-_fetch_key(key, begin, end) Tensor
-_normalize(raw)
}
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
class H5Store {
+load(path, tokenizer)
}
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
class JSONStore {
+load(path, tokenizer)
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
class MmapStore {
+List _mmap_refs
+load(path, tokenizer)
}
class ResumableDistributedSampler {
@ -168,10 +159,10 @@ classDiagram
+int iter
}
class StorageFactory {
class StoreFactory {
+Registry _registry
+register(name) decorator
+create(storage_type) BaseStorage
+create(storage_type) Store
}
class DatasetFactory {
@ -925,8 +916,9 @@ classDiagram
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
Store <|-- H5Store
Store <|-- JSONStore
Store <|-- MmapStore
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
@ -946,7 +938,7 @@ classDiagram
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- StoreFactory
BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory
BaseExecutor <|-- NoneExecutor
@ -989,7 +981,7 @@ classDiagram
TrainContext o-- BaseExecutor
KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage
BaseDataset o-- Store
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects
@ -1003,8 +995,9 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
StorageFactory ..> H5Storage : creates
StorageFactory ..> JSONStorage : creates
StoreFactory ..> H5Store : creates
StoreFactory ..> JSONStore : creates
StoreFactory ..> MmapStore : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates
@ -1037,7 +1030,6 @@ classDiagram
Executor --> AutoModel
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
```
@ -1047,7 +1039,7 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.dataset** | BaseDatasetGRPODataset, StoreMmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
@ -1061,7 +1053,7 @@ classDiagram
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
@ -1071,7 +1063,7 @@ classDiagram
| **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1083,7 +1075,7 @@ classDiagram
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)``NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops

View File

@ -5,21 +5,22 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview
```
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
```
## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
StorageFactory.create("h5") → H5Storage
StorageFactory.create("json") → JSONStorage
StoreFactory.create("h5") → H5Store
StoreFactory.create("json") → JSONStore
StoreFactory.create("bin") → MmapStore
```
Both support shared memory via `.share_memory_()`.
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
## Data Keys by Training Type
@ -34,8 +35,8 @@ Both support shared memory via `.share_memory_()`.
```
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
→ StorageFactory.create(detect_format(path))
MultiSegmentFetcher(BaseSegmentFetcher per key)
→ StoreFactory.create(detect_format(path))
Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx)
```

View File

@ -1,36 +1,20 @@
"""Storage backends for different data formats.
Design
------
Three-layer architecture:
1. **I/O layer** ``save_*`` / ``load_*`` functions that read/write raw files
(HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment).
These are format-specific, low-level helpers no abstraction, no state.
2. **Store (ABC)** the central abstraction. Each concrete ``Store`` calls the
I/O layer during ``load()``, then **normalizes** multi-segment data into a
single contiguous tensor per key via ``_normalize()``. After that, ``fetch()``
is just a vanilla slice no ``bisect``, no segment bookkeeping.
Data format inside a ``Store``::
self._data = {"sequence": Tensor, "loss_mask": Tensor, ...}
self._length = N # min first-dim size across keys, O(1)
3. **Dataset layer** ``BaseDataset`` owns a ``Store`` and only calls
``store.fetch(begin, end, key)``. It never knows whether the data came
from HDF5, JSON, or mmap.
Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- **Explicit length**: ``_length`` is set during ``load()`` and exposed via
``__len__`` (O(1)). No hidden computation inside a fetcher.
- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors.
Multiple DataLoader workers share the same OS page-cache pages.
- **Lazy concat**: ``H5Store`` / ``JSONStore`` concatenate segments at load
time, so fetch-time logic is trivial.
- Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
"""
import bisect