refactor : 移除 -> None 返回值标注,拆分 FSDP 参数,新增 mmap 数据集存储

- 删除所有 def 函数 -> None 返回值类型标注
- FSDPExecutor 参数从 **kwargs 拆为显式声明,None 值自动过滤
- 新增 MmapStorage (bin) 存储后端,基于 numpy.memmap 零拷贝加载
- 新增 save_bin/load_bin/json_to_bin 工具函数
- detect_format 支持 bin 格式自动检测
This commit is contained in:
ViperEkura 2026-05-28 13:57:06 +08:00
parent 2d5dc93b3d
commit cb8dcb97ea
14 changed files with 142 additions and 48 deletions

View File

@ -8,11 +8,15 @@ from astrai.dataset.storage import (
BaseStorage, BaseStorage,
H5Storage, H5Storage,
JSONStorage, JSONStorage,
MmapStorage,
MultiSegmentFetcher, MultiSegmentFetcher,
StorageFactory, StorageFactory,
detect_format, detect_format,
json_to_bin,
load_bin,
load_h5, load_h5,
load_json, load_json,
save_bin,
save_h5, save_h5,
save_json, save_json,
) )
@ -25,11 +29,15 @@ __all__ = [
"BaseStorage", "BaseStorage",
"H5Storage", "H5Storage",
"JSONStorage", "JSONStorage",
"MmapStorage",
"StorageFactory", "StorageFactory",
"detect_format", "detect_format",
"save_h5", "save_h5",
"load_h5", "load_h5",
"save_json", "save_json",
"load_json", "load_json",
"save_bin",
"load_bin",
"json_to_bin",
"ResumableDistributedSampler", "ResumableDistributedSampler",
] ]

View File

@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
""" """
@classmethod @classmethod
def _validate_component(cls, dataset_cls: type) -> None: def _validate_component(cls, dataset_cls: type):
"""Validate that the dataset class inherits from BaseDataset.""" """Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset): if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")

View File

@ -12,6 +12,7 @@ from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import h5py import h5py
import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
@ -104,6 +105,38 @@ def load_json(
return tensor_group return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group to a directory as raw binary files plus metadata.

    For every key, the tensors are concatenated along dim 0 and dumped as a
    flat ``<key>.bin`` file. The concatenated shape and the NumPy dtype name
    are recorded in ``meta.json`` so the data can be mapped back later.

    Args:
        file_path: Output directory; created if it does not exist.
        tensor_group: Mapping from key to the list of tensors stored under it.
    """
    import json

    os.makedirs(file_path, exist_ok=True)
    meta = {}
    for key, tensors in tensor_group.items():
        # detach() allows tensors that still track gradients to be exported;
        # Tensor.numpy() refuses them otherwise.
        array = torch.cat(tensors, dim=0).detach().cpu().numpy()
        # Record the dtype of the array that is actually written to disk.
        meta[key] = {"shape": list(array.shape), "dtype": array.dtype.name}
        array.tofile(os.path.join(file_path, f"{key}.bin"))

    # Metadata is a plain dict, so it is written with the stdlib json module
    # directly instead of going through a tensor-oriented JSON helper.
    with open(os.path.join(file_path, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Load a directory of ``<key>.bin`` files described by ``meta.json``.

    Each entry of ``meta.json`` supplies the ``dtype`` and ``shape`` used to
    open ``<key>.bin`` as a read-only ``numpy.memmap``, which is then wrapped
    with ``torch.from_numpy``.

    Args:
        file_path: Directory containing ``meta.json`` and the ``.bin`` files.

    Returns:
        Mapping from key to a single-element list holding the tensor for
        that key.
    """
    import json

    # meta.json is a plain metadata dict, so it is read with the stdlib json
    # module directly instead of a tensor-oriented JSON loader.
    with open(os.path.join(file_path, "meta.json"), encoding="utf-8") as f:
        meta = json.load(f)

    segments: Dict[str, List[Tensor]] = {}
    for key, info in meta.items():
        shape = tuple(info["shape"])
        if 0 in shape:
            # np.memmap cannot map an empty file; an empty array of the
            # recorded dtype/shape carries the same (zero) content.
            arr = np.empty(shape, dtype=info["dtype"])
        else:
            arr = np.memmap(
                os.path.join(file_path, f"{key}.bin"),
                dtype=info["dtype"],
                mode="r",
                shape=shape,
            )
        segments[key] = [torch.from_numpy(arr)]
    return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert a JSON dataset into the binary layout written by ``save_bin``.

    Args:
        json_path: Path passed to ``load_json`` as the data source.
        bin_path: Output directory passed to ``save_bin``.
        tokenizer: Forwarded unchanged to ``load_json``.
    """
    loaded = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    # Collapse each key's tensor list into one dim-0 concatenation before saving.
    collapsed = {name: [torch.cat(parts, dim=0)] for name, parts in loaded.items()}
    save_bin(bin_path, collapsed)
def detect_format(load_path: str) -> str: def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory. """Auto-detect storage format from files in the directory.
@ -128,6 +161,9 @@ def detect_format(load_path: str) -> str:
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files: if h5_files:
return "h5" return "h5"
bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists():
return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl")) json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files: if json_files:
return "json" return "json"
@ -227,7 +263,7 @@ class BaseStorage(ABC):
self._fetcher: Optional[MultiSegmentFetcher] = None self._fetcher: Optional[MultiSegmentFetcher] = None
@abstractmethod @abstractmethod
def load(self, load_path: str, tokenizer=None) -> None: def load(self, load_path: str, tokenizer=None):
"""Load data from the given path into internal fetcher.""" """Load data from the given path into internal fetcher."""
raise NotImplementedError raise NotImplementedError
@ -272,7 +308,7 @@ class StorageFactory(BaseFactory["BaseStorage"]):
""" """
@classmethod @classmethod
def _validate_component(cls, storage_cls: type) -> None: def _validate_component(cls, storage_cls: type):
if not issubclass(storage_cls, BaseStorage): if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage") raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@ -281,7 +317,7 @@ class StorageFactory(BaseFactory["BaseStorage"]):
class H5Storage(BaseStorage): class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data).""" """HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None) -> None: def load(self, load_path: str, tokenizer=None):
segments = load_h5(load_path) segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments) self._fetcher = MultiSegmentFetcher(segments)
@ -296,6 +332,26 @@ class JSONStorage(BaseStorage):
callable (str -> List[int]) at load time. callable (str -> List[int]) at load time.
""" """
def load(self, load_path: str, tokenizer=None) -> None: def load(self, load_path: str, tokenizer=None):
segments = load_json(load_path, tokenizer=tokenizer) segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments) self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("bin")
class MmapStorage(BaseStorage):
    """Memory-mapped binary storage backend, registered under ``"bin"``.

    Data is read through ``load_bin``: one concatenated raw ``.bin`` file per
    key, described by ``meta.json``. The files are memory-mapped rather than
    read into private buffers; the intent is that processes share the same
    physical pages through the OS page cache instead of each holding its own
    copy.
    """

    def load(self, load_path: str, tokenizer=None):
        """Load ``load_path`` with ``load_bin`` and build the segment fetcher.

        ``tokenizer`` is accepted for interface compatibility and is not used.
        """
        self._mmap_refs = []
        loaded = load_bin(load_path)
        # Keep a reference to every loaded tensor on the instance
        # (presumably so the underlying mappings stay alive with the storage).
        self._mmap_refs.extend(
            tensor for group in loaded.values() for tensor in group
        )
        self._fetcher = MultiSegmentFetcher(dict(loaded))

View File

@ -23,7 +23,7 @@ class Registry:
component_cls: Type, component_cls: Type,
category: Optional[str] = None, category: Optional[str] = None,
priority: int = 0, priority: int = 0,
) -> None: ):
"""Register a component class with optional category and priority.""" """Register a component class with optional category and priority."""
if name in self._entries: if name in self._entries:
raise ValueError(f"Component '{name}' is already registered") raise ValueError(f"Component '{name}' is already registered")
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
return component_cls(*args, **kwargs) return component_cls(*args, **kwargs)
@classmethod @classmethod
def _validate_component(cls, component_cls: Type[T]) -> None: def _validate_component(cls, component_cls: Type[T]):
"""Validate that the component class is valid for this factory. """Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation. Override this method in subclasses to add custom validation.

View File

@ -42,7 +42,7 @@ class Allocator:
return idx return idx
return -1 return -1
def free(self, idx: int, keep_cached: bool = False) -> None: def free(self, idx: int, keep_cached: bool = False):
with self._lock: with self._lock:
self._refs[idx] -= 1 self._refs[idx] -= 1
if self._refs[idx] == 0: if self._refs[idx] == 0:
@ -51,7 +51,7 @@ class Allocator:
else: else:
self._free_mask |= 1 << idx self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None: def inc_ref(self, idx: int):
with self._lock: with self._lock:
self._refs[idx] += 1 self._refs[idx] += 1
self._lru.pop(idx, None) self._lru.pop(idx, None)
@ -60,7 +60,7 @@ class Allocator:
with self._lock: with self._lock:
return self._refs[idx] return self._refs[idx]
def touch(self, idx: int) -> None: def touch(self, idx: int):
with self._lock: with self._lock:
self._lru.move_to_end(idx) self._lru.move_to_end(idx)
@ -74,7 +74,7 @@ class PrefixCache:
self._hash_to_page: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
def evict(self, idx: int) -> None: def evict(self, idx: int):
with self._lock: with self._lock:
h = self._page_to_hash.pop(idx, None) h = self._page_to_hash.pop(idx, None)
if h is not None: if h is not None:
@ -96,9 +96,7 @@ class PrefixCache:
hits.append(p) hits.append(p)
return hits return hits
def record( def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
with self._lock: with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size) h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None) old_h = self._page_to_hash.pop(page_idx, None)
@ -127,13 +125,13 @@ class PagePool:
def alloc(self) -> int: def alloc(self) -> int:
return self._alloc.alloc() return self._alloc.alloc()
def free(self, idx: int) -> None: def free(self, idx: int):
keep = self._prefix.has_page(idx) keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep) self._alloc.free(idx, keep_cached=keep)
if not keep: if not keep:
self._prefix.evict(idx) self._prefix.evict(idx)
def inc_ref(self, idx: int) -> None: def inc_ref(self, idx: int):
self._alloc.inc_ref(idx) self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]: def lookup(self, token_ids: List[int]) -> List[int]:
@ -142,9 +140,7 @@ class PagePool:
self._alloc.touch(p) self._alloc.touch(p)
return hits return hits
def record( def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
self._prefix.record(page_idx, token_ids, logical_page_idx) self._prefix.record(page_idx, token_ids, logical_page_idx)
@ -157,7 +153,7 @@ class TaskTable:
self._cached: Dict[str, int] = {} self._cached: Dict[str, int] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None: def set(self, task_id: str, page_table: List[int], cached: int):
with self._lock: with self._lock:
self._pages[task_id] = page_table self._pages[task_id] = page_table
self._cached[task_id] = cached self._cached[task_id] = cached
@ -220,7 +216,7 @@ class Storage:
start_pos: int, start_pos: int,
k: Tensor, k: Tensor,
v: Tensor, v: Tensor,
) -> None: ):
seq_len = k.size(1) seq_len = k.size(1)
if seq_len == 0: if seq_len == 0:
return return
@ -286,7 +282,7 @@ class KvcacheView:
self._page_table = page_table self._page_table = page_table
self._total_len = total_len self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None: def write(self, layer_id: int, k: Tensor, v: Tensor):
start_pos = self._total_len - k.size(1) start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v) self._storage.write(layer_id, self._page_table, start_pos, k, v)
@ -339,7 +335,7 @@ class KVCache:
self._table.set(task_id, hits + new_pages, cached) self._table.set(task_id, hits + new_pages, cached)
return True return True
def task_free(self, task_id: str) -> None: def task_free(self, task_id: str):
page_table, _ = self._table.pop(task_id) page_table, _ = self._table.pop(task_id)
for idx in page_table: for idx in page_table:
self._pool.free(idx) self._pool.free(idx)
@ -359,7 +355,7 @@ class KVCache:
def task_record_hashes( def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0 self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None: ):
page_table = self._table.get(task_id) page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages): for i in range(start_logical_page, full_pages):

View File

@ -29,9 +29,7 @@ class Executor:
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill( def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
if start_pos >= prompt_len: if start_pos >= prompt_len:
return return

View File

@ -75,14 +75,14 @@ class InferenceScheduler:
def add_task(self, prompt: str, **kwargs) -> str: def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs) return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str) -> None: def remove_task(self, task_id: str):
for task in self._task_mgr.remove_task(task_id): for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats() return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None: def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids stop_ids = self._task_mgr.tokenizer.stop_ids
try: try:
while self._running: while self._running:
@ -186,14 +186,14 @@ class InferenceScheduler:
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
raise raise
def start(self) -> None: def start(self):
if not self._running: if not self._running:
self._running = True self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True) t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start() t.start()
self._loop_thread = t self._loop_thread = t
def stop(self) -> None: def stop(self):
self._running = False self._running = False
self._task_mgr.wake() self._task_mgr.wake()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):

View File

@ -172,12 +172,12 @@ class TaskManager:
to_add.append(self.waiting_queue.popleft()) to_add.append(self.waiting_queue.popleft())
return to_add return to_add
def activate(self, task: Task) -> None: def activate(self, task: Task):
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
with self._lock: with self._lock:
self.active_tasks.append(task) self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]) -> None: def return_to_waiting(self, tasks: List[Task]):
with self._lock: with self._lock:
for task in reversed(tasks): for task in reversed(tasks):
self.waiting_queue.appendleft(task) self.waiting_queue.appendleft(task)
@ -185,7 +185,7 @@ class TaskManager:
def has_work(self) -> bool: def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue) return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0) -> None: def wait_for_tasks(self, timeout: float = 1.0):
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=timeout) self._task_event.wait(timeout=timeout)
@ -197,10 +197,10 @@ class TaskManager:
with self._lock: with self._lock:
return list(self.waiting_queue) return list(self.waiting_queue)
def clear_queues(self) -> None: def clear_queues(self):
with self._lock: with self._lock:
self.waiting_queue.clear() self.waiting_queue.clear()
self.active_tasks.clear() self.active_tasks.clear()
def wake(self) -> None: def wake(self):
self._task_event.set() self._task_event.set()

View File

@ -48,7 +48,7 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool: def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0) -> None: def wait_completion(self, timeout: float = 300.0):
with self._cond: with self._cond:
if not self._cond.wait_for( if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout lambda: self._completed >= self._total, timeout=timeout
@ -281,7 +281,7 @@ class InferenceEngine:
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self) -> None: def shutdown(self):
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -83,7 +83,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, Path], save_directory: Union[str, Path],
) -> None: ):
save_model( save_model(
config=self.config.to_dict(), config=self.config.to_dict(),
state_dict=self.state_dict(), state_dict=self.state_dict(),

View File

@ -203,9 +203,45 @@ class DDPExecutor(BaseExecutor):
@ExecutorFactory.register("fsdp") @ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor): class FSDPExecutor(BaseExecutor):
def __init__(
    self,
    grad_accum_steps: int = 1,
    process_group=None,
    sharding_strategy=None,
    cpu_offload=None,
    auto_wrap_policy=None,
    backward_prefetch=None,
    mixed_precision=None,
    ignored_modules=None,
    param_init_fn=None,
    sync_module_states: bool = False,
    forward_prefetch: bool = False,
    limit_all_gathers: bool = True,
    use_orig_params: bool = False,
    ignored_states=None,
    device_mesh=None,
):
    """Store the FSDP options, dropping every option left as ``None``.

    ``grad_accum_steps`` is passed to the base executor. All other
    arguments are collected into ``self._fsdp_kwargs``; entries whose value
    is ``None`` are omitted. NOTE(review): these are presumably forwarded
    to the FSDP wrapper when the model is prepared — confirm against
    ``_prepare_model``.
    """
    super().__init__(grad_accum_steps=grad_accum_steps)

    candidates = {
        "process_group": process_group,
        "sharding_strategy": sharding_strategy,
        "cpu_offload": cpu_offload,
        "auto_wrap_policy": auto_wrap_policy,
        "backward_prefetch": backward_prefetch,
        "mixed_precision": mixed_precision,
        "ignored_modules": ignored_modules,
        "param_init_fn": param_init_fn,
        "sync_module_states": sync_module_states,
        "forward_prefetch": forward_prefetch,
        "limit_all_gathers": limit_all_gathers,
        "use_orig_params": use_orig_params,
        "ignored_states": ignored_states,
        "device_mesh": device_mesh,
    }
    # Only explicitly provided (non-None) options are kept.
    self._fsdp_kwargs = {
        name: value for name, value in candidates.items() if value is not None
    }
    self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module: def _prepare_model(self, model: nn.Module) -> nn.Module:

View File

@ -16,7 +16,7 @@ _CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors" _WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: str | Path) -> None: def save_safetensors(state_dict: dict, path: str | Path):
st.save_file(state_dict, str(path)) st.save_file(state_dict, str(path))
@ -24,7 +24,7 @@ def load_safetensors(path: str | Path) -> dict:
return st.load_file(str(path)) return st.load_file(str(path))
def save_json(data: dict, path: str | Path) -> None: def save_json(data: dict, path: str | Path):
with open(str(path), "w") as f: with open(str(path), "w") as f:
json.dump(data, f, indent=2) json.dump(data, f, indent=2)
@ -34,7 +34,7 @@ def load_json(path: str | Path) -> dict:
return json.load(f) return json.load(f)
def save_torch(obj: Any, path: str | Path) -> None: def save_torch(obj: Any, path: str | Path):
torch.save(obj, str(path)) torch.save(obj, str(path))
@ -64,7 +64,7 @@ def load_torch(path: str | Path, broadcast: bool = False) -> Any:
return torch.load(buf, map_location="cpu", weights_only=False) return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str) -> None: def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory) save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE) save_json(config, save_path / _CONFIG_FILE)
@ -129,7 +129,7 @@ class Checkpoint:
extra: Dict[str, Any] = field(default_factory=dict) extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str) -> None: def save(self, save_dir: str):
save_path = Path(save_dir) save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)

View File

@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
""" """
@classmethod @classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None: def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
"""Validate that the scheduler class inherits from BaseScheduler.""" """Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler): if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

View File

@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
""" """
@classmethod @classmethod
def _validate_component(cls, strategy_cls: type) -> None: def _validate_component(cls, strategy_cls: type):
"""Validate that the strategy class inherits from BaseStrategy.""" """Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy): if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")