docs : 更新架构文档与 storage 注释,同步 Store 重构
- architecture.md: 类图/关系线全部更新 (BaseStorage→Store, StorageFactory→StoreFactory, 新增 MmapStore) - architecture.md: 移除 BaseSegmentFetcher/MultiSegmentFetcher 类图与关系 - dataflow.md: 管线加入 .bin 格式, Store._data + _cum 架构 - storage.py: module docstring 改用缩进式注释风格
This commit is contained in:
parent
6e150ea6d0
commit
0a708fff24
|
|
@ -107,7 +107,7 @@ classDiagram
|
|||
class BaseDataset {
|
||||
+int window_size
|
||||
+int stride
|
||||
+Optional[BaseStorage] storage
|
||||
+Optional[Store] storage
|
||||
+load(load_path, storage_type, tokenizer)
|
||||
+__getitem__(index)
|
||||
+__len__()
|
||||
|
|
@ -129,38 +129,29 @@ classDiagram
|
|||
+__getitem__(index) Dict
|
||||
}
|
||||
|
||||
class BaseSegmentFetcher {
|
||||
+List[Tensor] segments
|
||||
+List[int] cum_lengths
|
||||
+int total_length
|
||||
+fetch_data(begin_idx, end_idx) Tensor
|
||||
}
|
||||
|
||||
class BaseStorage {
|
||||
+MultiSegmentFetcher _fetcher
|
||||
class Store {
|
||||
+Dict[str, List[Tensor]] _data
|
||||
+Dict[str, List[int]] _cum
|
||||
+int _length
|
||||
+keys (property)
|
||||
+load(load_path, tokenizer)
|
||||
+load(path, tokenizer)
|
||||
+fetch(begin, end, keys)
|
||||
+__len__()
|
||||
-_fetch_key(key, begin, end) Tensor
|
||||
-_normalize(raw)
|
||||
}
|
||||
|
||||
class H5Storage {
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys) Dict
|
||||
+keys() List
|
||||
class H5Store {
|
||||
+load(path, tokenizer)
|
||||
}
|
||||
|
||||
class JSONStorage {
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys) Dict
|
||||
+keys() List
|
||||
class JSONStore {
|
||||
+load(path, tokenizer)
|
||||
}
|
||||
|
||||
class MultiSegmentFetcher {
|
||||
+Dict multi_fetchers
|
||||
+List multi_keys
|
||||
+key_fetch(begin_idx, end_idx, keys) Dict
|
||||
+fetch_data(begin_idx, end_idx) Dict
|
||||
class MmapStore {
|
||||
+List _mmap_refs
|
||||
+load(path, tokenizer)
|
||||
}
|
||||
|
||||
class ResumableDistributedSampler {
|
||||
|
|
@ -168,10 +159,10 @@ classDiagram
|
|||
+int iter
|
||||
}
|
||||
|
||||
class StorageFactory {
|
||||
class StoreFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(storage_type) BaseStorage
|
||||
+create(storage_type) Store
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
|
|
@ -925,8 +916,9 @@ classDiagram
|
|||
BaseDataset <|-- SFTDataset
|
||||
BaseDataset <|-- DPODataset
|
||||
BaseDataset <|-- GRPODataset
|
||||
BaseStorage <|-- H5Storage
|
||||
BaseStorage <|-- JSONStorage
|
||||
Store <|-- H5Store
|
||||
Store <|-- JSONStore
|
||||
Store <|-- MmapStore
|
||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||
BaseSamplingStrategy <|-- TopKStrategy
|
||||
BaseSamplingStrategy <|-- TopPStrategy
|
||||
|
|
@ -946,7 +938,7 @@ classDiagram
|
|||
BaseFactory <|-- StrategyFactory
|
||||
BaseFactory <|-- SchedulerFactory
|
||||
BaseFactory <|-- CallbackFactory
|
||||
BaseFactory <|-- StorageFactory
|
||||
BaseFactory <|-- StoreFactory
|
||||
BaseFactory <|-- ExecutorFactory
|
||||
BaseFactory <|-- ConfigFactory
|
||||
BaseExecutor <|-- NoneExecutor
|
||||
|
|
@ -989,7 +981,7 @@ classDiagram
|
|||
TrainContext o-- BaseExecutor
|
||||
KvcacheView o-- Storage
|
||||
SamplingPipeline o-- BaseSamplingStrategy
|
||||
BaseDataset o-- BaseStorage
|
||||
BaseDataset o-- Store
|
||||
|
||||
%% --- Dependency (uses temporarily) ---
|
||||
TrainConfig ..> BaseStrategy : selects
|
||||
|
|
@ -1003,8 +995,9 @@ classDiagram
|
|||
FFNFactory ..> DeepSeekMoE : creates
|
||||
DecoderBlock ..> AttnFactory : uses
|
||||
DecoderBlock ..> FFNFactory : uses
|
||||
StorageFactory ..> H5Storage : creates
|
||||
StorageFactory ..> JSONStorage : creates
|
||||
StoreFactory ..> H5Store : creates
|
||||
StoreFactory ..> JSONStore : creates
|
||||
StoreFactory ..> MmapStore : creates
|
||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||
ConfigFactory ..> EncoderConfig : creates
|
||||
ExecutorFactory ..> NoneExecutor : creates
|
||||
|
|
@ -1037,7 +1030,6 @@ classDiagram
|
|||
Executor --> AutoModel
|
||||
Executor --> AutoTokenizer
|
||||
TaskManager --> AutoTokenizer
|
||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -1047,7 +1039,7 @@ classDiagram
|
|||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
|
|
@ -1061,7 +1053,7 @@ classDiagram
|
|||
|
||||
| Pattern | Classes | Purpose |
|
||||
|---------|---------|---------|
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||
|
|
@ -1071,7 +1063,7 @@ classDiagram
|
|||
| **Context** | `TrainContext` | Unified training state bag |
|
||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||
|
||||
|
|
@ -1083,7 +1075,7 @@ classDiagram
|
|||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
|
|
|
|||
|
|
@ -5,21 +5,22 @@ This document describes the data pipeline: from raw text to model input tensors.
|
|||
## Overview
|
||||
|
||||
```
|
||||
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
|
||||
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
|
||||
```
|
||||
|
||||
## Data Preparation
|
||||
|
||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
|
||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||
|
||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||
|
||||
```
|
||||
StorageFactory.create("h5") → H5Storage
|
||||
StorageFactory.create("json") → JSONStorage
|
||||
StoreFactory.create("h5") → H5Store
|
||||
StoreFactory.create("json") → JSONStore
|
||||
StoreFactory.create("bin") → MmapStore
|
||||
```
|
||||
|
||||
Both support shared memory via `.share_memory_()`.
|
||||
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||
|
||||
## Data Keys by Training Type
|
||||
|
||||
|
|
@ -34,8 +35,8 @@ Both support shared memory via `.share_memory_()`.
|
|||
|
||||
```
|
||||
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
|
||||
→ StorageFactory.create(detect_format(path))
|
||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||
→ StoreFactory.create(detect_format(path))
|
||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||
→ BaseDataset.__getitem__(idx)
|
||||
→ sliding window [begin, end) via get_index(idx)
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,36 +1,20 @@
|
|||
"""Storage backends for different data formats.
|
||||
|
||||
Design
|
||||
------
|
||||
|
||||
Three-layer architecture:
|
||||
|
||||
1. **I/O layer** — ``save_*`` / ``load_*`` functions that read/write raw files
|
||||
(HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment).
|
||||
These are format-specific, low-level helpers — no abstraction, no state.
|
||||
|
||||
2. **Store (ABC)** — the central abstraction. Each concrete ``Store`` calls the
|
||||
I/O layer during ``load()``, then **normalizes** multi-segment data into a
|
||||
single contiguous tensor per key via ``_normalize()``. After that, ``fetch()``
|
||||
is just a vanilla slice — no ``bisect``, no segment bookkeeping.
|
||||
|
||||
Data format inside a ``Store``::
|
||||
|
||||
self._data = {"sequence": Tensor, "loss_mask": Tensor, ...}
|
||||
self._length = N # min first-dim size across keys, O(1)
|
||||
|
||||
3. **Dataset layer** — ``BaseDataset`` owns a ``Store`` and only calls
|
||||
``store.fetch(begin, end, key)``. It never knows whether the data came
|
||||
from HDF5, JSON, or mmap.
|
||||
Layers:
|
||||
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
|
||||
return Dict[str, List[Tensor]] — format-specific, no state
|
||||
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||
Dict[str, List[Tensor]] per key via _normalize(),
|
||||
fetch() uses bisect across segments — no forced concat
|
||||
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
||||
|
||||
Key properties:
|
||||
|
||||
- **Explicit length**: ``_length`` is set during ``load()`` and exposed via
|
||||
``__len__`` (O(1)). No hidden computation inside a fetcher.
|
||||
- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors.
|
||||
Multiple DataLoader workers share the same OS page-cache pages.
|
||||
- **Lazy concat**: ``H5Store`` / ``JSONStore`` concatenate segments at load
|
||||
time, so fetch-time logic is trivial.
|
||||
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
||||
datasets larger than RAM
|
||||
- Explicit length: _length = min(total elements across keys), set at load,
|
||||
__len__ returns O(1)
|
||||
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
||||
workers share OS page-cache pages
|
||||
"""
|
||||
|
||||
import bisect
|
||||
|
|
|
|||
Loading…
Reference in New Issue