refactor: 移除 -> None 返回值标注,拆分 FSDP 参数,新增 mmap 数据集存储

- 删除所有函数定义中的 `-> None` 返回值类型标注
- FSDPExecutor 参数从 **kwargs 拆为显式声明,None 值自动过滤
- 新增 MmapStorage (bin) 存储后端,基于 numpy.memmap 零拷贝加载
- 新增 save_bin/load_bin/json_to_bin 工具函数
- detect_format 支持 bin 格式自动检测
This commit is contained in:
ViperEkura 2026-05-28 13:57:06 +08:00
parent 2d5dc93b3d
commit cb8dcb97ea
14 changed files with 142 additions and 48 deletions

View File

@ -8,11 +8,15 @@ from astrai.dataset.storage import (
BaseStorage,
H5Storage,
JSONStorage,
MmapStorage,
MultiSegmentFetcher,
StorageFactory,
detect_format,
json_to_bin,
load_bin,
load_h5,
load_json,
save_bin,
save_h5,
save_json,
)
@ -25,11 +29,15 @@ __all__ = [
"BaseStorage",
"H5Storage",
"JSONStorage",
"MmapStorage",
"StorageFactory",
"detect_format",
"save_h5",
"load_h5",
"save_json",
"load_json",
"save_bin",
"load_bin",
"json_to_bin",
"ResumableDistributedSampler",
]

View File

@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
"""
@classmethod
def _validate_component(cls, dataset_cls: type) -> None:
def _validate_component(cls, dataset_cls: type):
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")

View File

@ -12,6 +12,7 @@ from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import h5py
import numpy as np
import torch
from torch import Tensor
@ -104,6 +105,38 @@ def load_json(
return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group as raw ``.bin`` files plus a ``meta.json`` index.

    For every key the tensors are concatenated along dim 0 and dumped as a
    flat C-order binary file ``<key>.bin`` inside ``file_path``. The shape and
    dtype needed to map each file back are recorded in ``meta.json``.

    Args:
        file_path: Output directory; created if it does not exist.
        tensor_group: Mapping of key -> list of tensors. All tensors under one
            key must be concatenable along dim 0 and non-empty as a list
            (``torch.cat`` raises otherwise).

    Raises:
        RuntimeError / TypeError: Propagated from ``torch.cat`` or
            ``Tensor.numpy()``; dtypes without a NumPy equivalent
            (presumably e.g. bfloat16) cannot be stored this way.
    """
    # Local import: the metadata is written with the stdlib json module so
    # its on-disk format does not depend on which ``save_json`` helper happens
    # to be in scope in this module.
    import json

    os.makedirs(file_path, exist_ok=True)
    meta = {}
    for key, tensors in tensor_group.items():
        merged = torch.cat(tensors, dim=0)
        # detach(): tensors that require grad cannot be converted to NumPy.
        # contiguous(): guarantees the bytes written match the recorded shape.
        array = merged.detach().cpu().contiguous().numpy()
        # Record the NumPy dtype name, since that is what np.memmap consumes
        # when the file is loaded back.
        meta[key] = {"shape": list(array.shape), "dtype": str(array.dtype)}
        array.tofile(os.path.join(file_path, f"{key}.bin"))

    with open(os.path.join(file_path, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Memory-map a directory written by ``save_bin`` into tensors.

    Reads ``meta.json`` from ``file_path`` and, for every key listed there,
    maps ``<key>.bin`` read-only with ``numpy.memmap`` using the recorded
    shape and dtype, then wraps the mapping with ``torch.from_numpy``.

    Args:
        file_path: Directory containing ``meta.json`` and the ``.bin`` files.

    Returns:
        Mapping of key -> single-element list holding the mapped tensor.

    Raises:
        FileNotFoundError: If ``meta.json`` or a referenced ``.bin`` is missing.
        KeyError: If a metadata entry lacks ``"shape"`` or ``"dtype"``.
    """
    # Local import: ``load_json`` in this module is the dataset loader (it
    # returns a tensor group), so the metadata is parsed with the stdlib json
    # module instead.
    import json

    with open(os.path.join(file_path, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)

    segments: Dict[str, List[Tensor]] = {}
    for key, info in meta.items():
        shape = tuple(info["shape"])
        dtype = np.dtype(info["dtype"])
        if 0 in shape:
            # An entry with zero elements corresponds to an empty file, which
            # cannot be memory-mapped; an empty array is equivalent.
            arr = np.empty(shape, dtype=dtype)
        else:
            arr = np.memmap(
                os.path.join(file_path, f"{key}.bin"),
                dtype=dtype,
                mode="r",
                shape=shape,
            )
        # NOTE(review): the mapping is read-only; torch may warn that the
        # array is not writable. Callers are expected not to write in place.
        segments[key] = [torch.from_numpy(arr)]
    return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert a JSON dataset into the binary layout written by ``save_bin``.

    Args:
        json_path: Location passed to ``load_json``.
        bin_path: Output directory passed to ``save_bin``.
        tokenizer: Forwarded unchanged to ``load_json``.
    """
    loaded = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    # Collapse each key's tensors into one tensor along dim 0 before saving.
    merged = {name: [torch.cat(parts, dim=0)] for name, parts in loaded.items()}
    save_bin(bin_path, merged)
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
@ -128,6 +161,9 @@ def detect_format(load_path: str) -> str:
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files:
return "h5"
bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists():
return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
@ -227,7 +263,7 @@ class BaseStorage(ABC):
self._fetcher: Optional[MultiSegmentFetcher] = None
@abstractmethod
def load(self, load_path: str, tokenizer=None) -> None:
def load(self, load_path: str, tokenizer=None):
"""Load data from the given path into internal fetcher."""
raise NotImplementedError
@ -272,7 +308,7 @@ class StorageFactory(BaseFactory["BaseStorage"]):
"""
@classmethod
def _validate_component(cls, storage_cls: type) -> None:
def _validate_component(cls, storage_cls: type):
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@ -281,7 +317,7 @@ class StorageFactory(BaseFactory["BaseStorage"]):
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None) -> None:
def load(self, load_path: str, tokenizer=None):
segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
@ -296,6 +332,26 @@ class JSONStorage(BaseStorage):
callable (str -> List[int]) at load time.
"""
def load(self, load_path: str, tokenizer=None) -> None:
def load(self, load_path: str, tokenizer=None):
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("bin")
class MmapStorage(BaseStorage):
    """Storage backend for the raw-binary (``bin``) dataset layout.

    Every key lives in one concatenated ``.bin`` file, described by an entry
    in ``meta.json``. Loading goes through ``load_bin``; the intent is that
    processes share the same physical pages via the OS page cache -- no
    per-process memory duplication.
    """

    def load(self, load_path: str, tokenizer=None):
        """Load ``load_path`` via ``load_bin`` and build the segment fetcher.

        ``tokenizer`` is accepted for interface compatibility and is not used
        by this backend.
        """
        # Hold direct references to every loaded tensor on the instance.
        self._mmap_refs = []
        loaded = load_bin(load_path)
        for tensors in loaded.values():
            self._mmap_refs.extend(tensors)
        self._fetcher = MultiSegmentFetcher(dict(loaded))

View File

@ -23,7 +23,7 @@ class Registry:
component_cls: Type,
category: Optional[str] = None,
priority: int = 0,
) -> None:
):
"""Register a component class with optional category and priority."""
if name in self._entries:
raise ValueError(f"Component '{name}' is already registered")
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
return component_cls(*args, **kwargs)
@classmethod
def _validate_component(cls, component_cls: Type[T]) -> None:
def _validate_component(cls, component_cls: Type[T]):
"""Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation.

View File

@ -42,7 +42,7 @@ class Allocator:
return idx
return -1
def free(self, idx: int, keep_cached: bool = False) -> None:
def free(self, idx: int, keep_cached: bool = False):
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
@ -51,7 +51,7 @@ class Allocator:
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None:
def inc_ref(self, idx: int):
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
@ -60,7 +60,7 @@ class Allocator:
with self._lock:
return self._refs[idx]
def touch(self, idx: int) -> None:
def touch(self, idx: int):
with self._lock:
self._lru.move_to_end(idx)
@ -74,7 +74,7 @@ class PrefixCache:
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int) -> None:
def evict(self, idx: int):
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
@ -96,9 +96,7 @@ class PrefixCache:
hits.append(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
@ -127,13 +125,13 @@ class PagePool:
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int) -> None:
def free(self, idx: int):
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int) -> None:
def inc_ref(self, idx: int):
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
@ -142,9 +140,7 @@ class PagePool:
self._alloc.touch(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self._prefix.record(page_idx, token_ids, logical_page_idx)
@ -157,7 +153,7 @@ class TaskTable:
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
def set(self, task_id: str, page_table: List[int], cached: int):
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
@ -220,7 +216,7 @@ class Storage:
start_pos: int,
k: Tensor,
v: Tensor,
) -> None:
):
seq_len = k.size(1)
if seq_len == 0:
return
@ -286,7 +282,7 @@ class KvcacheView:
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
def write(self, layer_id: int, k: Tensor, v: Tensor):
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
@ -339,7 +335,7 @@ class KVCache:
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
def task_free(self, task_id: str):
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
@ -359,7 +355,7 @@ class KVCache:
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
):
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):

View File

@ -29,9 +29,7 @@ class Executor:
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
if start_pos >= prompt_len:
return

View File

@ -75,14 +75,14 @@ class InferenceScheduler:
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str) -> None:
def remove_task(self, task_id: str):
for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None:
def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
@ -186,14 +186,14 @@ class InferenceScheduler:
self._task_mgr.clear_queues()
raise
def start(self) -> None:
def start(self):
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self) -> None:
def stop(self):
self._running = False
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):

View File

@ -172,12 +172,12 @@ class TaskManager:
to_add.append(self.waiting_queue.popleft())
return to_add
def activate(self, task: Task) -> None:
def activate(self, task: Task):
task.status = TaskStatus.RUNNING
with self._lock:
self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]) -> None:
def return_to_waiting(self, tasks: List[Task]):
with self._lock:
for task in reversed(tasks):
self.waiting_queue.appendleft(task)
@ -185,7 +185,7 @@ class TaskManager:
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0) -> None:
def wait_for_tasks(self, timeout: float = 1.0):
self._task_event.clear()
self._task_event.wait(timeout=timeout)
@ -197,10 +197,10 @@ class TaskManager:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self) -> None:
def clear_queues(self):
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self) -> None:
def wake(self):
self._task_event.set()

View File

@ -48,7 +48,7 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0) -> None:
def wait_completion(self, timeout: float = 300.0):
with self._cond:
if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout
@ -281,7 +281,7 @@ class InferenceEngine:
def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats()
def shutdown(self) -> None:
def shutdown(self):
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -83,7 +83,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
def save_pretrained(
self,
save_directory: Union[str, Path],
) -> None:
):
save_model(
config=self.config.to_dict(),
state_dict=self.state_dict(),

View File

@ -203,9 +203,45 @@ class DDPExecutor(BaseExecutor):
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
def __init__(
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = fsdp_kwargs
self._fsdp_kwargs = {
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module:

View File

@ -16,7 +16,7 @@ _CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: str | Path) -> None:
def save_safetensors(state_dict: dict, path: str | Path):
st.save_file(state_dict, str(path))
@ -24,7 +24,7 @@ def load_safetensors(path: str | Path) -> dict:
return st.load_file(str(path))
def save_json(data: dict, path: str | Path) -> None:
def save_json(data: dict, path: str | Path):
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
@ -34,7 +34,7 @@ def load_json(path: str | Path) -> dict:
return json.load(f)
def save_torch(obj: Any, path: str | Path) -> None:
def save_torch(obj: Any, path: str | Path):
torch.save(obj, str(path))
@ -64,7 +64,7 @@ def load_torch(path: str | Path, broadcast: bool = False) -> Any:
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
@ -129,7 +129,7 @@ class Checkpoint:
extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str) -> None:
def save(self, save_dir: str):
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)

View File

@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

View File

@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
"""
@classmethod
def _validate_component(cls, strategy_cls: type) -> None:
def _validate_component(cls, strategy_cls: type):
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")