docs : 更新架构文档与 storage 注释,同步 Store 重构
- architecture.md: 类图/关系线全部更新 (BaseStorage→Store, StorageFactory→StoreFactory, 新增 MmapStore) - architecture.md: 移除 BaseSegmentFetcher/MultiSegmentFetcher 类图与关系 - dataflow.md: 管线加入 .bin 格式, Store._data + _cum 架构 - storage.py: module docstring 改用缩进式注释风格
This commit is contained in:
parent
6e150ea6d0
commit
0a708fff24
|
|
@ -107,7 +107,7 @@ classDiagram
|
||||||
class BaseDataset {
|
class BaseDataset {
|
||||||
+int window_size
|
+int window_size
|
||||||
+int stride
|
+int stride
|
||||||
+Optional[BaseStorage] storage
|
+Optional[Store] storage
|
||||||
+load(load_path, storage_type, tokenizer)
|
+load(load_path, storage_type, tokenizer)
|
||||||
+__getitem__(index)
|
+__getitem__(index)
|
||||||
+__len__()
|
+__len__()
|
||||||
|
|
@ -129,38 +129,29 @@ classDiagram
|
||||||
+__getitem__(index) Dict
|
+__getitem__(index) Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseSegmentFetcher {
|
class Store {
|
||||||
+List[Tensor] segments
|
+Dict[str, List[Tensor]] _data
|
||||||
+List[int] cum_lengths
|
+Dict[str, List[int]] _cum
|
||||||
+int total_length
|
+int _length
|
||||||
+fetch_data(begin_idx, end_idx) Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseStorage {
|
|
||||||
+MultiSegmentFetcher _fetcher
|
|
||||||
+keys (property)
|
+keys (property)
|
||||||
+load(load_path, tokenizer)
|
+load(path, tokenizer)
|
||||||
+fetch(begin, end, keys)
|
+fetch(begin, end, keys)
|
||||||
+__len__()
|
+__len__()
|
||||||
|
-_fetch_key(key, begin, end) Tensor
|
||||||
|
-_normalize(raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
class H5Storage {
|
class H5Store {
|
||||||
+load(load_path, tokenizer)
|
+load(path, tokenizer)
|
||||||
+fetch(begin, end, keys) Dict
|
|
||||||
+keys() List
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class JSONStorage {
|
class JSONStore {
|
||||||
+load(load_path, tokenizer)
|
+load(path, tokenizer)
|
||||||
+fetch(begin, end, keys) Dict
|
|
||||||
+keys() List
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class MultiSegmentFetcher {
|
class MmapStore {
|
||||||
+Dict multi_fetchers
|
+List _mmap_refs
|
||||||
+List multi_keys
|
+load(path, tokenizer)
|
||||||
+key_fetch(begin_idx, end_idx, keys) Dict
|
|
||||||
+fetch_data(begin_idx, end_idx) Dict
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
class ResumableDistributedSampler {
|
||||||
|
|
@ -168,10 +159,10 @@ classDiagram
|
||||||
+int iter
|
+int iter
|
||||||
}
|
}
|
||||||
|
|
||||||
class StorageFactory {
|
class StoreFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(storage_type) BaseStorage
|
+create(storage_type) Store
|
||||||
}
|
}
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
|
|
@ -925,8 +916,9 @@ classDiagram
|
||||||
BaseDataset <|-- SFTDataset
|
BaseDataset <|-- SFTDataset
|
||||||
BaseDataset <|-- DPODataset
|
BaseDataset <|-- DPODataset
|
||||||
BaseDataset <|-- GRPODataset
|
BaseDataset <|-- GRPODataset
|
||||||
BaseStorage <|-- H5Storage
|
Store <|-- H5Store
|
||||||
BaseStorage <|-- JSONStorage
|
Store <|-- JSONStore
|
||||||
|
Store <|-- MmapStore
|
||||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||||
BaseSamplingStrategy <|-- TopKStrategy
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
BaseSamplingStrategy <|-- TopPStrategy
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
|
|
@ -946,7 +938,7 @@ classDiagram
|
||||||
BaseFactory <|-- StrategyFactory
|
BaseFactory <|-- StrategyFactory
|
||||||
BaseFactory <|-- SchedulerFactory
|
BaseFactory <|-- SchedulerFactory
|
||||||
BaseFactory <|-- CallbackFactory
|
BaseFactory <|-- CallbackFactory
|
||||||
BaseFactory <|-- StorageFactory
|
BaseFactory <|-- StoreFactory
|
||||||
BaseFactory <|-- ExecutorFactory
|
BaseFactory <|-- ExecutorFactory
|
||||||
BaseFactory <|-- ConfigFactory
|
BaseFactory <|-- ConfigFactory
|
||||||
BaseExecutor <|-- NoneExecutor
|
BaseExecutor <|-- NoneExecutor
|
||||||
|
|
@ -989,7 +981,7 @@ classDiagram
|
||||||
TrainContext o-- BaseExecutor
|
TrainContext o-- BaseExecutor
|
||||||
KvcacheView o-- Storage
|
KvcacheView o-- Storage
|
||||||
SamplingPipeline o-- BaseSamplingStrategy
|
SamplingPipeline o-- BaseSamplingStrategy
|
||||||
BaseDataset o-- BaseStorage
|
BaseDataset o-- Store
|
||||||
|
|
||||||
%% --- Dependency (uses temporarily) ---
|
%% --- Dependency (uses temporarily) ---
|
||||||
TrainConfig ..> BaseStrategy : selects
|
TrainConfig ..> BaseStrategy : selects
|
||||||
|
|
@ -1003,8 +995,9 @@ classDiagram
|
||||||
FFNFactory ..> DeepSeekMoE : creates
|
FFNFactory ..> DeepSeekMoE : creates
|
||||||
DecoderBlock ..> AttnFactory : uses
|
DecoderBlock ..> AttnFactory : uses
|
||||||
DecoderBlock ..> FFNFactory : uses
|
DecoderBlock ..> FFNFactory : uses
|
||||||
StorageFactory ..> H5Storage : creates
|
StoreFactory ..> H5Store : creates
|
||||||
StorageFactory ..> JSONStorage : creates
|
StoreFactory ..> JSONStore : creates
|
||||||
|
StoreFactory ..> MmapStore : creates
|
||||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||||
ConfigFactory ..> EncoderConfig : creates
|
ConfigFactory ..> EncoderConfig : creates
|
||||||
ExecutorFactory ..> NoneExecutor : creates
|
ExecutorFactory ..> NoneExecutor : creates
|
||||||
|
|
@ -1037,7 +1030,6 @@ classDiagram
|
||||||
Executor --> AutoModel
|
Executor --> AutoModel
|
||||||
Executor --> AutoTokenizer
|
Executor --> AutoTokenizer
|
||||||
TaskManager --> AutoTokenizer
|
TaskManager --> AutoTokenizer
|
||||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -1047,7 +1039,7 @@ classDiagram
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
|
|
@ -1061,7 +1053,7 @@ classDiagram
|
||||||
|
|
||||||
| Pattern | Classes | Purpose |
|
| Pattern | Classes | Purpose |
|
||||||
|---------|---------|---------|
|
|---------|---------|---------|
|
||||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||||
|
|
@ -1071,7 +1063,7 @@ classDiagram
|
||||||
| **Context** | `TrainContext` | Unified training state bag |
|
| **Context** | `TrainContext` | Unified training state bag |
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||||
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
||||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||||
|
|
||||||
|
|
@ -1083,7 +1075,7 @@ classDiagram
|
||||||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
||||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
|
|
|
||||||
|
|
@ -5,21 +5,22 @@ This document describes the data pipeline: from raw text to model input tensors.
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
```
|
```
|
||||||
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
|
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
|
||||||
```
|
```
|
||||||
|
|
||||||
## Data Preparation
|
## Data Preparation
|
||||||
|
|
||||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
|
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||||
|
|
||||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||||
|
|
||||||
```
|
```
|
||||||
StorageFactory.create("h5") → H5Storage
|
StoreFactory.create("h5") → H5Store
|
||||||
StorageFactory.create("json") → JSONStorage
|
StoreFactory.create("json") → JSONStore
|
||||||
|
StoreFactory.create("bin") → MmapStore
|
||||||
```
|
```
|
||||||
|
|
||||||
Both support shared memory via `.share_memory_()`.
|
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||||
|
|
||||||
## Data Keys by Training Type
|
## Data Keys by Training Type
|
||||||
|
|
||||||
|
|
@ -34,8 +35,8 @@ Both support shared memory via `.share_memory_()`.
|
||||||
|
|
||||||
```
|
```
|
||||||
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
|
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
|
||||||
→ StorageFactory.create(detect_format(path))
|
→ StoreFactory.create(detect_format(path))
|
||||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||||
→ BaseDataset.__getitem__(idx)
|
→ BaseDataset.__getitem__(idx)
|
||||||
→ sliding window [begin, end) via get_index(idx)
|
→ sliding window [begin, end) via get_index(idx)
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -1,36 +1,20 @@
|
||||||
"""Storage backends for different data formats.
|
"""Storage backends for different data formats.
|
||||||
|
|
||||||
Design
|
Layers:
|
||||||
------
|
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
|
||||||
|
return Dict[str, List[Tensor]] — format-specific, no state
|
||||||
Three-layer architecture:
|
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||||
|
Dict[str, List[Tensor]] per key via _normalize(),
|
||||||
1. **I/O layer** — ``save_*`` / ``load_*`` functions that read/write raw files
|
fetch() uses bisect across segments — no forced concat
|
||||||
(HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment).
|
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
||||||
These are format-specific, low-level helpers — no abstraction, no state.
|
|
||||||
|
|
||||||
2. **Store (ABC)** — the central abstraction. Each concrete ``Store`` calls the
|
|
||||||
I/O layer during ``load()``, then **normalizes** multi-segment data into a
|
|
||||||
single contiguous tensor per key via ``_normalize()``. After that, ``fetch()``
|
|
||||||
is just a vanilla slice — no ``bisect``, no segment bookkeeping.
|
|
||||||
|
|
||||||
Data format inside a ``Store``::
|
|
||||||
|
|
||||||
self._data = {"sequence": Tensor, "loss_mask": Tensor, ...}
|
|
||||||
self._length = N # min first-dim size across keys, O(1)
|
|
||||||
|
|
||||||
3. **Dataset layer** — ``BaseDataset`` owns a ``Store`` and only calls
|
|
||||||
``store.fetch(begin, end, key)``. It never knows whether the data came
|
|
||||||
from HDF5, JSON, or mmap.
|
|
||||||
|
|
||||||
Key properties:
|
Key properties:
|
||||||
|
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
||||||
- **Explicit length**: ``_length`` is set during ``load()`` and exposed via
|
datasets larger than RAM
|
||||||
``__len__`` (O(1)). No hidden computation inside a fetcher.
|
- Explicit length: _length = min(total elements across keys), set at load,
|
||||||
- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors.
|
__len__ returns O(1)
|
||||||
Multiple DataLoader workers share the same OS page-cache pages.
|
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
||||||
- **Lazy concat**: ``H5Store`` / ``JSONStore`` concatenate segments at load
|
workers share OS page-cache pages
|
||||||
time, so fetch-time logic is trivial.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue