refactor : 移除 -> None 返回值标注,拆分 FSDP 参数,新增 mmap 数据集存储
- 删除所有 def 函数 -> None 返回值类型标注
- FSDPExecutor 参数从 **kwargs 拆为显式声明,None 值自动过滤
- 新增 MmapStorage (bin) 存储后端,基于 numpy.memmap 零拷贝加载
- 新增 save_bin/load_bin/json_to_bin 工具函数
- detect_format 支持 bin 格式自动检测
This commit is contained in:
parent
2d5dc93b3d
commit
cb8dcb97ea
|
|
@ -8,11 +8,15 @@ from astrai.dataset.storage import (
|
|||
BaseStorage,
|
||||
H5Storage,
|
||||
JSONStorage,
|
||||
MmapStorage,
|
||||
MultiSegmentFetcher,
|
||||
StorageFactory,
|
||||
detect_format,
|
||||
json_to_bin,
|
||||
load_bin,
|
||||
load_h5,
|
||||
load_json,
|
||||
save_bin,
|
||||
save_h5,
|
||||
save_json,
|
||||
)
|
||||
|
|
@ -25,11 +29,15 @@ __all__ = [
|
|||
"BaseStorage",
|
||||
"H5Storage",
|
||||
"JSONStorage",
|
||||
"MmapStorage",
|
||||
"StorageFactory",
|
||||
"detect_format",
|
||||
"save_h5",
|
||||
"load_h5",
|
||||
"save_json",
|
||||
"load_json",
|
||||
"save_bin",
|
||||
"load_bin",
|
||||
"json_to_bin",
|
||||
"ResumableDistributedSampler",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, dataset_cls: type) -> None:
|
||||
def _validate_component(cls, dataset_cls: type):
|
||||
"""Validate that the dataset class inherits from BaseDataset."""
|
||||
if not issubclass(dataset_cls, BaseDataset):
|
||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from pathlib import Path
|
|||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
|
@ -104,6 +105,38 @@ def load_json(
|
|||
return tensor_group
|
||||
|
||||
|
||||
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group to a directory as raw ``.bin`` files plus ``meta.json``.

    For every key, the tensors are concatenated along dim 0 and dumped as a
    flat binary file named ``<key>.bin``. ``meta.json`` records, per key, the
    concatenated shape and the dtype name (e.g. ``"int64"``), which is what
    ``load_bin`` needs to map the file back into an array.

    Args:
        file_path: Output directory; created if it does not exist.
        tensor_group: Mapping from key to the list of tensors to store.
    """
    # Local import: the metadata file is plain JSON and must not go through
    # the dataset-level save_json/load_json helpers of this module.
    import json

    os.makedirs(file_path, exist_ok=True)

    meta = {}
    for key, tensors in tensor_group.items():
        merged = torch.cat(tensors, dim=0)
        # str(torch.int64) == "torch.int64"; keep only the bare dtype name.
        dtype_name = str(merged.dtype).split(".")[-1]
        meta[key] = {"shape": list(merged.shape), "dtype": dtype_name}
        # detach() so tensors that require grad can still be converted.
        merged.detach().cpu().numpy().tofile(os.path.join(file_path, f"{key}.bin"))

    with open(os.path.join(file_path, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
|
||||
|
||||
|
||||
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Load a ``bin`` dataset directory as memory-mapped tensors.

    Reads ``meta.json`` from ``file_path`` and, for every key listed there,
    maps ``<key>.bin`` read-only with ``numpy.memmap`` using the recorded
    dtype and shape, then wraps the array with ``torch.from_numpy``.

    Args:
        file_path: Directory containing ``meta.json`` and the ``.bin`` files.

    Returns:
        Mapping from key to a single-element list holding the tensor for
        that key. The tensors are backed by read-only mappings and must not
        be written to.
    """
    # Local import: meta.json is plain JSON; the module-level load_json is
    # the dataset loader (returns a tensor group), not a metadata reader.
    import json

    with open(os.path.join(file_path, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)

    segments: Dict[str, List[Tensor]] = {}
    for key, info in meta.items():
        shape = tuple(info["shape"])
        if 0 in shape:
            # np.memmap refuses to map a zero-byte file, so build the empty
            # array directly.
            arr = np.empty(shape, dtype=info["dtype"])
        else:
            arr = np.memmap(
                os.path.join(file_path, f"{key}.bin"),
                dtype=info["dtype"],
                mode="r",
                shape=shape,
            )
        segments[key] = [torch.from_numpy(arr)]
    return segments
|
||||
|
||||
|
||||
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert a JSON dataset into the ``bin`` (raw binary + meta.json) layout.

    The data is loaded with ``load_json(json_path, share_memory=False,
    tokenizer=tokenizer)`` and written with ``save_bin``.

    Args:
        json_path: Path handed to ``load_json``.
        bin_path: Output directory handed to ``save_bin``.
        tokenizer: Optional tokenizer forwarded to ``load_json``.
    """
    segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    # save_bin already concatenates each key's tensors along dim 0, so
    # pre-merging here would only create one more full copy of the data.
    save_bin(bin_path, segments)
|
||||
|
||||
|
||||
def detect_format(load_path: str) -> str:
|
||||
"""Auto-detect storage format from files in the directory.
|
||||
|
||||
|
|
@ -128,6 +161,9 @@ def detect_format(load_path: str) -> str:
|
|||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||
if h5_files:
|
||||
return "h5"
|
||||
bin_files = list(root.rglob("*.bin"))
|
||||
if bin_files and (root / "meta.json").exists():
|
||||
return "bin"
|
||||
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
||||
if json_files:
|
||||
return "json"
|
||||
|
|
@ -227,7 +263,7 @@ class BaseStorage(ABC):
|
|||
self._fetcher: Optional[MultiSegmentFetcher] = None
|
||||
|
||||
@abstractmethod
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
def load(self, load_path: str, tokenizer=None):
|
||||
"""Load data from the given path into internal fetcher."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -272,7 +308,7 @@ class StorageFactory(BaseFactory["BaseStorage"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, storage_cls: type) -> None:
|
||||
def _validate_component(cls, storage_cls: type):
|
||||
if not issubclass(storage_cls, BaseStorage):
|
||||
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
||||
|
||||
|
|
@ -281,7 +317,7 @@ class StorageFactory(BaseFactory["BaseStorage"]):
|
|||
class H5Storage(BaseStorage):
|
||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
def load(self, load_path: str, tokenizer=None):
|
||||
segments = load_h5(load_path)
|
||||
self._fetcher = MultiSegmentFetcher(segments)
|
||||
|
||||
|
|
@ -296,6 +332,26 @@ class JSONStorage(BaseStorage):
|
|||
callable (str -> List[int]) at load time.
|
||||
"""
|
||||
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
def load(self, load_path: str, tokenizer=None):
|
||||
segments = load_json(load_path, tokenizer=tokenizer)
|
||||
self._fetcher = MultiSegmentFetcher(segments)
|
||||
|
||||
|
||||
@StorageFactory.register("bin")
class MmapStorage(BaseStorage):
    """Storage backend for the memory-mapped ``bin`` format.

    A dataset consists of one concatenated raw binary file (``.bin``) per
    key and a ``meta.json`` describing them. The files are memory-mapped on
    load, so all processes share the same physical pages through the OS page
    cache instead of each holding a private copy.
    """

    def load(self, load_path: str, tokenizer=None):
        """Map the dataset at ``load_path`` and build the segment fetcher.

        ``tokenizer`` is accepted for interface compatibility and is not used.
        """
        self._mmap_refs = []
        loaded = load_bin(load_path)
        # Hold a reference to every mapped tensor on the storage object.
        self._mmap_refs.extend(
            tensor for tensors in loaded.values() for tensor in tensors
        )
        self._fetcher = MultiSegmentFetcher(dict(loaded))
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class Registry:
|
|||
component_cls: Type,
|
||||
category: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
):
|
||||
"""Register a component class with optional category and priority."""
|
||||
if name in self._entries:
|
||||
raise ValueError(f"Component '{name}' is already registered")
|
||||
|
|
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
|
|||
return component_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: Type[T]) -> None:
|
||||
def _validate_component(cls, component_cls: Type[T]):
|
||||
"""Validate that the component class is valid for this factory.
|
||||
|
||||
Override this method in subclasses to add custom validation.
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class Allocator:
|
|||
return idx
|
||||
return -1
|
||||
|
||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||
def free(self, idx: int, keep_cached: bool = False):
|
||||
with self._lock:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
|
|
@ -51,7 +51,7 @@ class Allocator:
|
|||
else:
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
def inc_ref(self, idx: int):
|
||||
with self._lock:
|
||||
self._refs[idx] += 1
|
||||
self._lru.pop(idx, None)
|
||||
|
|
@ -60,7 +60,7 @@ class Allocator:
|
|||
with self._lock:
|
||||
return self._refs[idx]
|
||||
|
||||
def touch(self, idx: int) -> None:
|
||||
def touch(self, idx: int):
|
||||
with self._lock:
|
||||
self._lru.move_to_end(idx)
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ class PrefixCache:
|
|||
self._hash_to_page: Dict[int, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def evict(self, idx: int) -> None:
|
||||
def evict(self, idx: int):
|
||||
with self._lock:
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
|
|
@ -96,9 +96,7 @@ class PrefixCache:
|
|||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||
with self._lock:
|
||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
|
|
@ -127,13 +125,13 @@ class PagePool:
|
|||
def alloc(self) -> int:
|
||||
return self._alloc.alloc()
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
def free(self, idx: int):
|
||||
keep = self._prefix.has_page(idx)
|
||||
self._alloc.free(idx, keep_cached=keep)
|
||||
if not keep:
|
||||
self._prefix.evict(idx)
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
def inc_ref(self, idx: int):
|
||||
self._alloc.inc_ref(idx)
|
||||
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
|
|
@ -142,9 +140,7 @@ class PagePool:
|
|||
self._alloc.touch(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||
|
||||
|
||||
|
|
@ -157,7 +153,7 @@ class TaskTable:
|
|||
self._cached: Dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||
def set(self, task_id: str, page_table: List[int], cached: int):
|
||||
with self._lock:
|
||||
self._pages[task_id] = page_table
|
||||
self._cached[task_id] = cached
|
||||
|
|
@ -220,7 +216,7 @@ class Storage:
|
|||
start_pos: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
) -> None:
|
||||
):
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
|
|
@ -286,7 +282,7 @@ class KvcacheView:
|
|||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor):
|
||||
start_pos = self._total_len - k.size(1)
|
||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
|
|
@ -339,7 +335,7 @@ class KVCache:
|
|||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str) -> None:
|
||||
def task_free(self, task_id: str):
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self._pool.free(idx)
|
||||
|
|
@ -359,7 +355,7 @@ class KVCache:
|
|||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
) -> None:
|
||||
):
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
|
|
|
|||
|
|
@ -29,9 +29,7 @@ class Executor:
|
|||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
def execute_prefill(
|
||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||
) -> None:
|
||||
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
|
||||
if start_pos >= prompt_len:
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -75,14 +75,14 @@ class InferenceScheduler:
|
|||
def add_task(self, prompt: str, **kwargs) -> str:
|
||||
return self._task_mgr.add_task(prompt, **kwargs)
|
||||
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
def remove_task(self, task_id: str):
|
||||
for task in self._task_mgr.remove_task(task_id):
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self._task_mgr.get_stats()
|
||||
|
||||
def _run_generation_loop(self) -> None:
|
||||
def _run_generation_loop(self):
|
||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||
try:
|
||||
while self._running:
|
||||
|
|
@ -186,14 +186,14 @@ class InferenceScheduler:
|
|||
self._task_mgr.clear_queues()
|
||||
raise
|
||||
|
||||
def start(self) -> None:
|
||||
def start(self):
|
||||
if not self._running:
|
||||
self._running = True
|
||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||
t.start()
|
||||
self._loop_thread = t
|
||||
|
||||
def stop(self) -> None:
|
||||
def stop(self):
|
||||
self._running = False
|
||||
self._task_mgr.wake()
|
||||
if hasattr(self, "_loop_thread"):
|
||||
|
|
|
|||
|
|
@ -172,12 +172,12 @@ class TaskManager:
|
|||
to_add.append(self.waiting_queue.popleft())
|
||||
return to_add
|
||||
|
||||
def activate(self, task: Task) -> None:
|
||||
def activate(self, task: Task):
|
||||
task.status = TaskStatus.RUNNING
|
||||
with self._lock:
|
||||
self.active_tasks.append(task)
|
||||
|
||||
def return_to_waiting(self, tasks: List[Task]) -> None:
|
||||
def return_to_waiting(self, tasks: List[Task]):
|
||||
with self._lock:
|
||||
for task in reversed(tasks):
|
||||
self.waiting_queue.appendleft(task)
|
||||
|
|
@ -185,7 +185,7 @@ class TaskManager:
|
|||
def has_work(self) -> bool:
|
||||
return bool(self.active_tasks or self.waiting_queue)
|
||||
|
||||
def wait_for_tasks(self, timeout: float = 1.0) -> None:
|
||||
def wait_for_tasks(self, timeout: float = 1.0):
|
||||
self._task_event.clear()
|
||||
self._task_event.wait(timeout=timeout)
|
||||
|
||||
|
|
@ -197,10 +197,10 @@ class TaskManager:
|
|||
with self._lock:
|
||||
return list(self.waiting_queue)
|
||||
|
||||
def clear_queues(self) -> None:
|
||||
def clear_queues(self):
|
||||
with self._lock:
|
||||
self.waiting_queue.clear()
|
||||
self.active_tasks.clear()
|
||||
|
||||
def wake(self) -> None:
|
||||
def wake(self):
|
||||
self._task_event.set()
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class GenerateResult:
|
|||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
def wait_completion(self, timeout: float = 300.0) -> None:
|
||||
def wait_completion(self, timeout: float = 300.0):
|
||||
with self._cond:
|
||||
if not self._cond.wait_for(
|
||||
lambda: self._completed >= self._total, timeout=timeout
|
||||
|
|
@ -281,7 +281,7 @@ class InferenceEngine:
|
|||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.scheduler.get_stats()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
def shutdown(self):
|
||||
self.scheduler.stop()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, Path],
|
||||
) -> None:
|
||||
):
|
||||
save_model(
|
||||
config=self.config.to_dict(),
|
||||
state_dict=self.state_dict(),
|
||||
|
|
|
|||
|
|
@ -203,9 +203,45 @@ class DDPExecutor(BaseExecutor):
|
|||
|
||||
@ExecutorFactory.register("fsdp")
|
||||
class FSDPExecutor(BaseExecutor):
|
||||
def __init__(
    self,
    grad_accum_steps: int = 1,
    process_group=None,
    sharding_strategy=None,
    cpu_offload=None,
    auto_wrap_policy=None,
    backward_prefetch=None,
    mixed_precision=None,
    ignored_modules=None,
    param_init_fn=None,
    sync_module_states: bool = False,
    forward_prefetch: bool = False,
    limit_all_gathers: bool = True,
    use_orig_params: bool = False,
    ignored_states=None,
    device_mesh=None,
):
    """Store the executor configuration.

    ``grad_accum_steps`` is forwarded to the base executor. Every other
    argument is collected into ``self._fsdp_kwargs``, except those left as
    ``None``, which are dropped.
    NOTE(review): presumably these are later forwarded to the FSDP wrapper;
    that code is not visible here.
    """
    super().__init__(grad_accum_steps=grad_accum_steps)

    candidates = {
        "process_group": process_group,
        "sharding_strategy": sharding_strategy,
        "cpu_offload": cpu_offload,
        "auto_wrap_policy": auto_wrap_policy,
        "backward_prefetch": backward_prefetch,
        "mixed_precision": mixed_precision,
        "ignored_modules": ignored_modules,
        "param_init_fn": param_init_fn,
        "sync_module_states": sync_module_states,
        "forward_prefetch": forward_prefetch,
        "limit_all_gathers": limit_all_gathers,
        "use_orig_params": use_orig_params,
        "ignored_states": ignored_states,
        "device_mesh": device_mesh,
    }
    # Keep only the options that were explicitly given a value.
    self._fsdp_kwargs = {}
    for option, value in candidates.items():
        if value is None:
            continue
        self._fsdp_kwargs[option] = value

    self._original_model: Optional[nn.Module] = None
|
||||
|
||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ _CONFIG_FILE = "config.json"
|
|||
_WEIGHTS_FILE = "model.safetensors"
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, path: str | Path) -> None:
|
||||
def save_safetensors(state_dict: dict, path: str | Path):
|
||||
st.save_file(state_dict, str(path))
|
||||
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ def load_safetensors(path: str | Path) -> dict:
|
|||
return st.load_file(str(path))
|
||||
|
||||
|
||||
def save_json(data: dict, path: str | Path) -> None:
|
||||
def save_json(data: dict, path: str | Path):
|
||||
with open(str(path), "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ def load_json(path: str | Path) -> dict:
|
|||
return json.load(f)
|
||||
|
||||
|
||||
def save_torch(obj: Any, path: str | Path) -> None:
|
||||
def save_torch(obj: Any, path: str | Path):
|
||||
torch.save(obj, str(path))
|
||||
|
||||
|
||||
|
|
@ -64,7 +64,7 @@ def load_torch(path: str | Path, broadcast: bool = False) -> Any:
|
|||
return torch.load(buf, map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str):
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
save_json(config, save_path / _CONFIG_FILE)
|
||||
|
|
@ -129,7 +129,7 @@ class Checkpoint:
|
|||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def save(self, save_dir: str) -> None:
|
||||
def save(self, save_dir: str):
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||
if not issubclass(scheduler_cls, BaseScheduler):
|
||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, strategy_cls: type) -> None:
|
||||
def _validate_component(cls, strategy_cls: type):
|
||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||
if not issubclass(strategy_cls, BaseStrategy):
|
||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||
|
|
|
|||
Loading…
Reference in New Issue