refactor : 移除 -> None 返回值标注,拆分 FSDP 参数,新增 mmap 数据集存储
- 删除所有 def 函数 -> None 返回值类型标注 - FSDPExecutor 参数从 **kwargs 拆为显式声明,None 值自动过滤 - 新增 MmapStorage (bin) 存储后端,基于 numpy.memmap 零拷贝加载 - 新增 save_bin/load_bin/json_to_bin 工具函数 - detect_format 支持 bin 格式自动检测
This commit is contained in:
parent
2d5dc93b3d
commit
cb8dcb97ea
|
|
@ -8,11 +8,15 @@ from astrai.dataset.storage import (
|
||||||
BaseStorage,
|
BaseStorage,
|
||||||
H5Storage,
|
H5Storage,
|
||||||
JSONStorage,
|
JSONStorage,
|
||||||
|
MmapStorage,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
StorageFactory,
|
StorageFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
|
json_to_bin,
|
||||||
|
load_bin,
|
||||||
load_h5,
|
load_h5,
|
||||||
load_json,
|
load_json,
|
||||||
|
save_bin,
|
||||||
save_h5,
|
save_h5,
|
||||||
save_json,
|
save_json,
|
||||||
)
|
)
|
||||||
|
|
@ -25,11 +29,15 @@ __all__ = [
|
||||||
"BaseStorage",
|
"BaseStorage",
|
||||||
"H5Storage",
|
"H5Storage",
|
||||||
"JSONStorage",
|
"JSONStorage",
|
||||||
|
"MmapStorage",
|
||||||
"StorageFactory",
|
"StorageFactory",
|
||||||
"detect_format",
|
"detect_format",
|
||||||
"save_h5",
|
"save_h5",
|
||||||
"load_h5",
|
"load_h5",
|
||||||
"save_json",
|
"save_json",
|
||||||
"load_json",
|
"load_json",
|
||||||
|
"save_bin",
|
||||||
|
"load_bin",
|
||||||
|
"json_to_bin",
|
||||||
"ResumableDistributedSampler",
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, dataset_cls: type) -> None:
|
def _validate_component(cls, dataset_cls: type):
|
||||||
"""Validate that the dataset class inherits from BaseDataset."""
|
"""Validate that the dataset class inherits from BaseDataset."""
|
||||||
if not issubclass(dataset_cls, BaseDataset):
|
if not issubclass(dataset_cls, BaseDataset):
|
||||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Union
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
@ -104,6 +105,38 @@ def load_json(
|
||||||
return tensor_group
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
|
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
|
os.makedirs(file_path, exist_ok=True)
|
||||||
|
meta = {}
|
||||||
|
for key, tensors in tensor_group.items():
|
||||||
|
cat = torch.cat(tensors, dim=0)
|
||||||
|
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
||||||
|
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
||||||
|
save_json(meta, os.path.join(file_path, "meta.json"))
|
||||||
|
|
||||||
|
|
||||||
|
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||||
|
meta = load_json(os.path.join(file_path, "meta.json"))
|
||||||
|
segments: Dict[str, List[Tensor]] = {}
|
||||||
|
for key, info in meta.items():
|
||||||
|
arr = np.memmap(
|
||||||
|
os.path.join(file_path, f"{key}.bin"),
|
||||||
|
dtype=info["dtype"],
|
||||||
|
mode="r",
|
||||||
|
shape=tuple(info["shape"]),
|
||||||
|
)
|
||||||
|
segments[key] = [torch.from_numpy(arr)]
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
|
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
|
||||||
|
segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
|
||||||
|
merged = {}
|
||||||
|
for key, tensors in segments.items():
|
||||||
|
merged[key] = [torch.cat(tensors, dim=0)]
|
||||||
|
save_bin(bin_path, merged)
|
||||||
|
|
||||||
|
|
||||||
def detect_format(load_path: str) -> str:
|
def detect_format(load_path: str) -> str:
|
||||||
"""Auto-detect storage format from files in the directory.
|
"""Auto-detect storage format from files in the directory.
|
||||||
|
|
||||||
|
|
@ -128,6 +161,9 @@ def detect_format(load_path: str) -> str:
|
||||||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||||
if h5_files:
|
if h5_files:
|
||||||
return "h5"
|
return "h5"
|
||||||
|
bin_files = list(root.rglob("*.bin"))
|
||||||
|
if bin_files and (root / "meta.json").exists():
|
||||||
|
return "bin"
|
||||||
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
||||||
if json_files:
|
if json_files:
|
||||||
return "json"
|
return "json"
|
||||||
|
|
@ -227,7 +263,7 @@ class BaseStorage(ABC):
|
||||||
self._fetcher: Optional[MultiSegmentFetcher] = None
|
self._fetcher: Optional[MultiSegmentFetcher] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, load_path: str, tokenizer=None) -> None:
|
def load(self, load_path: str, tokenizer=None):
|
||||||
"""Load data from the given path into internal fetcher."""
|
"""Load data from the given path into internal fetcher."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -272,7 +308,7 @@ class StorageFactory(BaseFactory["BaseStorage"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, storage_cls: type) -> None:
|
def _validate_component(cls, storage_cls: type):
|
||||||
if not issubclass(storage_cls, BaseStorage):
|
if not issubclass(storage_cls, BaseStorage):
|
||||||
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
||||||
|
|
||||||
|
|
@ -281,7 +317,7 @@ class StorageFactory(BaseFactory["BaseStorage"]):
|
||||||
class H5Storage(BaseStorage):
|
class H5Storage(BaseStorage):
|
||||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
def load(self, load_path: str, tokenizer=None) -> None:
|
def load(self, load_path: str, tokenizer=None):
|
||||||
segments = load_h5(load_path)
|
segments = load_h5(load_path)
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
|
|
||||||
|
|
@ -296,6 +332,26 @@ class JSONStorage(BaseStorage):
|
||||||
callable (str -> List[int]) at load time.
|
callable (str -> List[int]) at load time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load(self, load_path: str, tokenizer=None) -> None:
|
def load(self, load_path: str, tokenizer=None):
|
||||||
segments = load_json(load_path, tokenizer=tokenizer)
|
segments = load_json(load_path, tokenizer=tokenizer)
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
|
|
||||||
|
|
||||||
|
@StorageFactory.register("bin")
|
||||||
|
class MmapStorage(BaseStorage):
|
||||||
|
"""Memory-mapped binary storage backend.
|
||||||
|
|
||||||
|
Each key is stored as a concatenated raw binary file (.bin) with
|
||||||
|
metadata in meta.json. Loading mmaps the files so each process
|
||||||
|
shares the same physical pages via the OS page cache — no per-process
|
||||||
|
memory duplication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load(self, load_path: str, tokenizer=None):
|
||||||
|
self._mmap_refs = []
|
||||||
|
raw = load_bin(load_path)
|
||||||
|
segments = {}
|
||||||
|
for key, tensors in raw.items():
|
||||||
|
self._mmap_refs.extend(tensors)
|
||||||
|
segments[key] = tensors
|
||||||
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class Registry:
|
||||||
component_cls: Type,
|
component_cls: Type,
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
) -> None:
|
):
|
||||||
"""Register a component class with optional category and priority."""
|
"""Register a component class with optional category and priority."""
|
||||||
if name in self._entries:
|
if name in self._entries:
|
||||||
raise ValueError(f"Component '{name}' is already registered")
|
raise ValueError(f"Component '{name}' is already registered")
|
||||||
|
|
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
return component_cls(*args, **kwargs)
|
return component_cls(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, component_cls: Type[T]) -> None:
|
def _validate_component(cls, component_cls: Type[T]):
|
||||||
"""Validate that the component class is valid for this factory.
|
"""Validate that the component class is valid for this factory.
|
||||||
|
|
||||||
Override this method in subclasses to add custom validation.
|
Override this method in subclasses to add custom validation.
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class Allocator:
|
||||||
return idx
|
return idx
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
def free(self, idx: int, keep_cached: bool = False):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._refs[idx] -= 1
|
self._refs[idx] -= 1
|
||||||
if self._refs[idx] == 0:
|
if self._refs[idx] == 0:
|
||||||
|
|
@ -51,7 +51,7 @@ class Allocator:
|
||||||
else:
|
else:
|
||||||
self._free_mask |= 1 << idx
|
self._free_mask |= 1 << idx
|
||||||
|
|
||||||
def inc_ref(self, idx: int) -> None:
|
def inc_ref(self, idx: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._refs[idx] += 1
|
self._refs[idx] += 1
|
||||||
self._lru.pop(idx, None)
|
self._lru.pop(idx, None)
|
||||||
|
|
@ -60,7 +60,7 @@ class Allocator:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._refs[idx]
|
return self._refs[idx]
|
||||||
|
|
||||||
def touch(self, idx: int) -> None:
|
def touch(self, idx: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._lru.move_to_end(idx)
|
self._lru.move_to_end(idx)
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ class PrefixCache:
|
||||||
self._hash_to_page: Dict[int, int] = {}
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def evict(self, idx: int) -> None:
|
def evict(self, idx: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
h = self._page_to_hash.pop(idx, None)
|
h = self._page_to_hash.pop(idx, None)
|
||||||
if h is not None:
|
if h is not None:
|
||||||
|
|
@ -96,9 +96,7 @@ class PrefixCache:
|
||||||
hits.append(p)
|
hits.append(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
def record(
|
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
|
||||||
) -> None:
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||||
old_h = self._page_to_hash.pop(page_idx, None)
|
old_h = self._page_to_hash.pop(page_idx, None)
|
||||||
|
|
@ -127,13 +125,13 @@ class PagePool:
|
||||||
def alloc(self) -> int:
|
def alloc(self) -> int:
|
||||||
return self._alloc.alloc()
|
return self._alloc.alloc()
|
||||||
|
|
||||||
def free(self, idx: int) -> None:
|
def free(self, idx: int):
|
||||||
keep = self._prefix.has_page(idx)
|
keep = self._prefix.has_page(idx)
|
||||||
self._alloc.free(idx, keep_cached=keep)
|
self._alloc.free(idx, keep_cached=keep)
|
||||||
if not keep:
|
if not keep:
|
||||||
self._prefix.evict(idx)
|
self._prefix.evict(idx)
|
||||||
|
|
||||||
def inc_ref(self, idx: int) -> None:
|
def inc_ref(self, idx: int):
|
||||||
self._alloc.inc_ref(idx)
|
self._alloc.inc_ref(idx)
|
||||||
|
|
||||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||||
|
|
@ -142,9 +140,7 @@ class PagePool:
|
||||||
self._alloc.touch(p)
|
self._alloc.touch(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
def record(
|
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
|
||||||
) -> None:
|
|
||||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -157,7 +153,7 @@ class TaskTable:
|
||||||
self._cached: Dict[str, int] = {}
|
self._cached: Dict[str, int] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
def set(self, task_id: str, page_table: List[int], cached: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._pages[task_id] = page_table
|
self._pages[task_id] = page_table
|
||||||
self._cached[task_id] = cached
|
self._cached[task_id] = cached
|
||||||
|
|
@ -220,7 +216,7 @@ class Storage:
|
||||||
start_pos: int,
|
start_pos: int,
|
||||||
k: Tensor,
|
k: Tensor,
|
||||||
v: Tensor,
|
v: Tensor,
|
||||||
) -> None:
|
):
|
||||||
seq_len = k.size(1)
|
seq_len = k.size(1)
|
||||||
if seq_len == 0:
|
if seq_len == 0:
|
||||||
return
|
return
|
||||||
|
|
@ -286,7 +282,7 @@ class KvcacheView:
|
||||||
self._page_table = page_table
|
self._page_table = page_table
|
||||||
self._total_len = total_len
|
self._total_len = total_len
|
||||||
|
|
||||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
def write(self, layer_id: int, k: Tensor, v: Tensor):
|
||||||
start_pos = self._total_len - k.size(1)
|
start_pos = self._total_len - k.size(1)
|
||||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||||
|
|
||||||
|
|
@ -339,7 +335,7 @@ class KVCache:
|
||||||
self._table.set(task_id, hits + new_pages, cached)
|
self._table.set(task_id, hits + new_pages, cached)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def task_free(self, task_id: str) -> None:
|
def task_free(self, task_id: str):
|
||||||
page_table, _ = self._table.pop(task_id)
|
page_table, _ = self._table.pop(task_id)
|
||||||
for idx in page_table:
|
for idx in page_table:
|
||||||
self._pool.free(idx)
|
self._pool.free(idx)
|
||||||
|
|
@ -359,7 +355,7 @@ class KVCache:
|
||||||
|
|
||||||
def task_record_hashes(
|
def task_record_hashes(
|
||||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||||
) -> None:
|
):
|
||||||
page_table = self._table.get(task_id)
|
page_table = self._table.get(task_id)
|
||||||
full_pages = len(prompt_ids) // self.page_size
|
full_pages = len(prompt_ids) // self.page_size
|
||||||
for i in range(start_logical_page, full_pages):
|
for i in range(start_logical_page, full_pages):
|
||||||
|
|
|
||||||
|
|
@ -29,9 +29,7 @@ class Executor:
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
def execute_prefill(
|
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
|
||||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
|
||||||
) -> None:
|
|
||||||
if start_pos >= prompt_len:
|
if start_pos >= prompt_len:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,14 +75,14 @@ class InferenceScheduler:
|
||||||
def add_task(self, prompt: str, **kwargs) -> str:
|
def add_task(self, prompt: str, **kwargs) -> str:
|
||||||
return self._task_mgr.add_task(prompt, **kwargs)
|
return self._task_mgr.add_task(prompt, **kwargs)
|
||||||
|
|
||||||
def remove_task(self, task_id: str) -> None:
|
def remove_task(self, task_id: str):
|
||||||
for task in self._task_mgr.remove_task(task_id):
|
for task in self._task_mgr.remove_task(task_id):
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
return self._task_mgr.get_stats()
|
return self._task_mgr.get_stats()
|
||||||
|
|
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self):
|
||||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||||
try:
|
try:
|
||||||
while self._running:
|
while self._running:
|
||||||
|
|
@ -186,14 +186,14 @@ class InferenceScheduler:
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self):
|
||||||
if not self._running:
|
if not self._running:
|
||||||
self._running = True
|
self._running = True
|
||||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||||
t.start()
|
t.start()
|
||||||
self._loop_thread = t
|
self._loop_thread = t
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self):
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task_mgr.wake()
|
self._task_mgr.wake()
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
|
|
|
||||||
|
|
@ -172,12 +172,12 @@ class TaskManager:
|
||||||
to_add.append(self.waiting_queue.popleft())
|
to_add.append(self.waiting_queue.popleft())
|
||||||
return to_add
|
return to_add
|
||||||
|
|
||||||
def activate(self, task: Task) -> None:
|
def activate(self, task: Task):
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
def return_to_waiting(self, tasks: List[Task]) -> None:
|
def return_to_waiting(self, tasks: List[Task]):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for task in reversed(tasks):
|
for task in reversed(tasks):
|
||||||
self.waiting_queue.appendleft(task)
|
self.waiting_queue.appendleft(task)
|
||||||
|
|
@ -185,7 +185,7 @@ class TaskManager:
|
||||||
def has_work(self) -> bool:
|
def has_work(self) -> bool:
|
||||||
return bool(self.active_tasks or self.waiting_queue)
|
return bool(self.active_tasks or self.waiting_queue)
|
||||||
|
|
||||||
def wait_for_tasks(self, timeout: float = 1.0) -> None:
|
def wait_for_tasks(self, timeout: float = 1.0):
|
||||||
self._task_event.clear()
|
self._task_event.clear()
|
||||||
self._task_event.wait(timeout=timeout)
|
self._task_event.wait(timeout=timeout)
|
||||||
|
|
||||||
|
|
@ -197,10 +197,10 @@ class TaskManager:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return list(self.waiting_queue)
|
return list(self.waiting_queue)
|
||||||
|
|
||||||
def clear_queues(self) -> None:
|
def clear_queues(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue.clear()
|
self.waiting_queue.clear()
|
||||||
self.active_tasks.clear()
|
self.active_tasks.clear()
|
||||||
|
|
||||||
def wake(self) -> None:
|
def wake(self):
|
||||||
self._task_event.set()
|
self._task_event.set()
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class GenerateResult:
|
||||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||||
return self._event.wait(timeout=timeout)
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
def wait_completion(self, timeout: float = 300.0) -> None:
|
def wait_completion(self, timeout: float = 300.0):
|
||||||
with self._cond:
|
with self._cond:
|
||||||
if not self._cond.wait_for(
|
if not self._cond.wait_for(
|
||||||
lambda: self._completed >= self._total, timeout=timeout
|
lambda: self._completed >= self._total, timeout=timeout
|
||||||
|
|
@ -281,7 +281,7 @@ class InferenceEngine:
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
return self.scheduler.get_stats()
|
return self.scheduler.get_stats()
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self):
|
||||||
self.scheduler.stop()
|
self.scheduler.stop()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
def save_pretrained(
|
def save_pretrained(
|
||||||
self,
|
self,
|
||||||
save_directory: Union[str, Path],
|
save_directory: Union[str, Path],
|
||||||
) -> None:
|
):
|
||||||
save_model(
|
save_model(
|
||||||
config=self.config.to_dict(),
|
config=self.config.to_dict(),
|
||||||
state_dict=self.state_dict(),
|
state_dict=self.state_dict(),
|
||||||
|
|
|
||||||
|
|
@ -203,9 +203,45 @@ class DDPExecutor(BaseExecutor):
|
||||||
|
|
||||||
@ExecutorFactory.register("fsdp")
|
@ExecutorFactory.register("fsdp")
|
||||||
class FSDPExecutor(BaseExecutor):
|
class FSDPExecutor(BaseExecutor):
|
||||||
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
grad_accum_steps: int = 1,
|
||||||
|
process_group=None,
|
||||||
|
sharding_strategy=None,
|
||||||
|
cpu_offload=None,
|
||||||
|
auto_wrap_policy=None,
|
||||||
|
backward_prefetch=None,
|
||||||
|
mixed_precision=None,
|
||||||
|
ignored_modules=None,
|
||||||
|
param_init_fn=None,
|
||||||
|
sync_module_states: bool = False,
|
||||||
|
forward_prefetch: bool = False,
|
||||||
|
limit_all_gathers: bool = True,
|
||||||
|
use_orig_params: bool = False,
|
||||||
|
ignored_states=None,
|
||||||
|
device_mesh=None,
|
||||||
|
):
|
||||||
super().__init__(grad_accum_steps=grad_accum_steps)
|
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||||
self._fsdp_kwargs = fsdp_kwargs
|
self._fsdp_kwargs = {
|
||||||
|
k: v
|
||||||
|
for k, v in dict(
|
||||||
|
process_group=process_group,
|
||||||
|
sharding_strategy=sharding_strategy,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
|
auto_wrap_policy=auto_wrap_policy,
|
||||||
|
backward_prefetch=backward_prefetch,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
ignored_modules=ignored_modules,
|
||||||
|
param_init_fn=param_init_fn,
|
||||||
|
sync_module_states=sync_module_states,
|
||||||
|
forward_prefetch=forward_prefetch,
|
||||||
|
limit_all_gathers=limit_all_gathers,
|
||||||
|
use_orig_params=use_orig_params,
|
||||||
|
ignored_states=ignored_states,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
).items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
self._original_model: Optional[nn.Module] = None
|
self._original_model: Optional[nn.Module] = None
|
||||||
|
|
||||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ _CONFIG_FILE = "config.json"
|
||||||
_WEIGHTS_FILE = "model.safetensors"
|
_WEIGHTS_FILE = "model.safetensors"
|
||||||
|
|
||||||
|
|
||||||
def save_safetensors(state_dict: dict, path: str | Path) -> None:
|
def save_safetensors(state_dict: dict, path: str | Path):
|
||||||
st.save_file(state_dict, str(path))
|
st.save_file(state_dict, str(path))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -24,7 +24,7 @@ def load_safetensors(path: str | Path) -> dict:
|
||||||
return st.load_file(str(path))
|
return st.load_file(str(path))
|
||||||
|
|
||||||
|
|
||||||
def save_json(data: dict, path: str | Path) -> None:
|
def save_json(data: dict, path: str | Path):
|
||||||
with open(str(path), "w") as f:
|
with open(str(path), "w") as f:
|
||||||
json.dump(data, f, indent=2)
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
|
@ -34,7 +34,7 @@ def load_json(path: str | Path) -> dict:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def save_torch(obj: Any, path: str | Path) -> None:
|
def save_torch(obj: Any, path: str | Path):
|
||||||
torch.save(obj, str(path))
|
torch.save(obj, str(path))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -64,7 +64,7 @@ def load_torch(path: str | Path, broadcast: bool = False) -> Any:
|
||||||
return torch.load(buf, map_location="cpu", weights_only=False)
|
return torch.load(buf, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
|
||||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
|
def save_model(config: dict, state_dict: dict, save_directory: str):
|
||||||
save_path = Path(save_directory)
|
save_path = Path(save_directory)
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
save_json(config, save_path / _CONFIG_FILE)
|
save_json(config, save_path / _CONFIG_FILE)
|
||||||
|
|
@ -129,7 +129,7 @@ class Checkpoint:
|
||||||
extra: Dict[str, Any] = field(default_factory=dict)
|
extra: Dict[str, Any] = field(default_factory=dict)
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
def save(self, save_dir: str) -> None:
|
def save(self, save_dir: str):
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
||||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||||
if not issubclass(scheduler_cls, BaseScheduler):
|
if not issubclass(scheduler_cls, BaseScheduler):
|
||||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, strategy_cls: type) -> None:
|
def _validate_component(cls, strategy_cls: type):
|
||||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||||
if not issubclass(strategy_cls, BaseStrategy):
|
if not issubclass(strategy_cls, BaseStrategy):
|
||||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue