From 0a708fff24b32f17a859bc6fcf3ba8f03b644aab Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 28 May 2026 14:36:18 +0800 Subject: [PATCH] =?UTF-8?q?docs=20:=20=E6=9B=B4=E6=96=B0=E6=9E=B6=E6=9E=84?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E4=B8=8E=20storage=20=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=EF=BC=8C=E5=90=8C=E6=AD=A5=20Store=20=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - architecture.md: 类图/关系线全部更新 (BaseStorage→Store, StorageFactory→StoreFactory, 新增 MmapStore) - architecture.md: 移除 BaseSegmentFetcher/MultiSegmentFetcher 类图与关系 - dataflow.md: 管线加入 .bin 格式, Store._data + _cum 架构 - storage.py: module docstring 改用缩进式注释风格 --- assets/docs/architecture.md | 66 ++++++++++++++++--------------------- assets/docs/dataflow.md | 15 +++++---- astrai/dataset/storage.py | 42 ++++++++--------------- 3 files changed, 50 insertions(+), 73 deletions(-) diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index ac7ec18..b74d242 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -107,7 +107,7 @@ classDiagram class BaseDataset { +int window_size +int stride - +Optional[BaseStorage] storage + +Optional[Store] storage +load(load_path, storage_type, tokenizer) +__getitem__(index) +__len__() @@ -129,38 +129,29 @@ classDiagram +__getitem__(index) Dict } - class BaseSegmentFetcher { - +List[Tensor] segments - +List[int] cum_lengths - +int total_length - +fetch_data(begin_idx, end_idx) Tensor - } - - class BaseStorage { - +MultiSegmentFetcher _fetcher + class Store { + +Dict[str, List[Tensor]] _data + +Dict[str, List[int]] _cum + +int _length +keys (property) - +load(load_path, tokenizer) + +load(path, tokenizer) +fetch(begin, end, keys) +__len__() + -_fetch_key(key, begin, end) Tensor + -_normalize(raw) } - class H5Storage { - +load(load_path, tokenizer) - +fetch(begin, end, keys) Dict - +keys() List + class H5Store { + +load(path, tokenizer) } - class JSONStorage { - +load(load_path, tokenizer) - +fetch(begin, end, keys) Dict - +keys() List + class JSONStore { + +load(path, tokenizer) } - class MultiSegmentFetcher { - +Dict multi_fetchers - +List multi_keys - +key_fetch(begin_idx, end_idx, keys) Dict - +fetch_data(begin_idx, end_idx) Dict + class MmapStore { + +List _mmap_refs + +load(path, tokenizer) } class ResumableDistributedSampler { @@ -168,10 +159,10 @@ classDiagram +int iter } - class StorageFactory { + class StoreFactory { +Registry _registry +register(name) decorator - +create(storage_type) BaseStorage + +create(storage_type) Store } class DatasetFactory { @@ -925,8 +916,9 @@ classDiagram BaseDataset <|-- SFTDataset BaseDataset <|-- DPODataset BaseDataset <|-- GRPODataset - BaseStorage <|-- H5Storage - BaseStorage <|-- JSONStorage + Store <|-- H5Store + Store <|-- JSONStore + Store <|-- MmapStore BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopPStrategy @@ -946,7 +938,7 @@ classDiagram BaseFactory <|-- StrategyFactory BaseFactory <|-- SchedulerFactory BaseFactory <|-- CallbackFactory - BaseFactory <|-- StorageFactory + BaseFactory <|-- StoreFactory BaseFactory <|-- ExecutorFactory BaseFactory <|-- ConfigFactory BaseExecutor <|-- NoneExecutor @@ -989,7 +981,7 @@ classDiagram TrainContext o-- BaseExecutor KvcacheView o-- Storage SamplingPipeline o-- BaseSamplingStrategy - BaseDataset o-- BaseStorage + BaseDataset o-- Store %% --- Dependency (uses temporarily) --- TrainConfig ..> BaseStrategy : selects @@ -1003,8 +995,9 @@ classDiagram FFNFactory ..> DeepSeekMoE : creates DecoderBlock ..> AttnFactory : uses DecoderBlock ..> FFNFactory : uses - StorageFactory ..> H5Storage : creates - StorageFactory ..> JSONStorage : creates + StoreFactory ..> H5Store : creates + StoreFactory ..> JSONStore : creates + StoreFactory ..> MmapStore : creates ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> EncoderConfig : creates ExecutorFactory ..> NoneExecutor : creates @@ -1037,7 +1030,6 @@ classDiagram Executor --> AutoModel Executor --> AutoTokenizer TaskManager --> AutoTokenizer - MultiSegmentFetcher --> BaseSegmentFetcher ``` @@ -1047,7 +1039,7 @@ classDiagram | Module | Components | Description | |--------|------------|-------------| | **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | -| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | +| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.serialization** | Checkpoint | Model serialization | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | @@ -1061,7 +1053,7 @@ classDiagram | Pattern | Classes | Purpose | |---------|---------|---------| -| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation | +| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | @@ -1071,7 +1063,7 @@ classDiagram | **Context** | `TrainContext` | Unified training state bag | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | -| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access | +| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | @@ -1083,7 +1075,7 @@ classDiagram 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP -7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` +7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 7208373..ab391d2 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -5,21 +5,22 @@ This document describes the data pipeline: from raw text to model input tensors. ## Overview ``` -Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference +Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference ``` ## Data Preparation -Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups. +Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups. Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: ``` -StorageFactory.create("h5") → H5Storage -StorageFactory.create("json") → JSONStorage +StoreFactory.create("h5") → H5Store +StoreFactory.create("json") → JSONStore +StoreFactory.create("bin") → MmapStore ``` -Both support shared memory via `.share_memory_()`. +H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively. ## Data Keys by Training Type @@ -34,8 +35,8 @@ Both support shared memory via `.share_memory_()`. ``` DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer) - → StorageFactory.create(detect_format(path)) - → MultiSegmentFetcher(BaseSegmentFetcher per key) + → StoreFactory.create(detect_format(path)) + → Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] → BaseDataset.__getitem__(idx) → sliding window [begin, end) via get_index(idx) ``` diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index 761e6d0..3ba0cb8 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -1,36 +1,20 @@ """Storage backends for different data formats. -Design ------- - -Three-layer architecture: - -1. **I/O layer** — ``save_*`` / ``load_*`` functions that read/write raw files - (HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment). - These are format-specific, low-level helpers — no abstraction, no state. - -2. **Store (ABC)** — the central abstraction. Each concrete ``Store`` calls the - I/O layer during ``load()``, then **normalizes** multi-segment data into a - single contiguous tensor per key via ``_normalize()``. After that, ``fetch()`` - is just a vanilla slice — no ``bisect``, no segment bookkeeping. - - Data format inside a ``Store``:: - - self._data = {"sequence": Tensor, "loss_mask": Tensor, ...} - self._length = N # min first-dim size across keys, O(1) - -3. **Dataset layer** — ``BaseDataset`` owns a ``Store`` and only calls - ``store.fetch(begin, end, key)``. It never knows whether the data came - from HDF5, JSON, or mmap. +Layers: + - I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin) + return Dict[str, List[Tensor]] — format-specific, no state + - Store (ABC): central abstraction, normalizes multi-segment into + Dict[str, List[Tensor]] per key via _normalize(), + fetch() uses bisect across segments — no forced concat + - Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key) Key properties: - -- **Explicit length**: ``_length`` is set during ``load()`` and exposed via - ``__len__`` (O(1)). No hidden computation inside a fetcher. -- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors. - Multiple DataLoader workers share the same OS page-cache pages. -- **Lazy concat**: ``H5Store`` / ``JSONStore`` concatenate segments at load - time, so fetch-time logic is trivial. + - Multi-segment: segments kept as-is, no forced concatenation — safe for + datasets larger than RAM + - Explicit length: _length = min(total elements across keys), set at load, + __len__ returns O(1) + - Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader + workers share OS page-cache pages """ import bisect