diff --git a/astrai/dataset/__init__.py b/astrai/dataset/__init__.py index 8207577..495b1f9 100644 --- a/astrai/dataset/__init__.py +++ b/astrai/dataset/__init__.py @@ -8,11 +8,15 @@ from astrai.dataset.storage import ( BaseStorage, H5Storage, JSONStorage, + MmapStorage, MultiSegmentFetcher, StorageFactory, detect_format, + json_to_bin, + load_bin, load_h5, load_json, + save_bin, save_h5, save_json, ) @@ -25,11 +29,15 @@ __all__ = [ "BaseStorage", "H5Storage", "JSONStorage", + "MmapStorage", "StorageFactory", "detect_format", "save_h5", "load_h5", "save_json", "load_json", + "save_bin", + "load_bin", + "json_to_bin", "ResumableDistributedSampler", ] diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 29844d2..3fda455 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]): """ @classmethod - def _validate_component(cls, dataset_cls: type) -> None: + def _validate_component(cls, dataset_cls: type): """Validate that the dataset class inherits from BaseDataset.""" if not issubclass(dataset_cls, BaseDataset): raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index 9afb808..1989311 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Callable, Dict, List, Optional, Union import h5py +import numpy as np import torch from torch import Tensor @@ -104,6 +105,38 @@ def load_json( return tensor_group +def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]): + os.makedirs(file_path, exist_ok=True) + meta = {} + for key, tensors in tensor_group.items(): + cat = torch.cat(tensors, dim=0) + meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]} + np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin")) + save_json(meta, os.path.join(file_path, "meta.json")) + + +def load_bin(file_path: str) -> Dict[str, List[Tensor]]: + meta = load_json(os.path.join(file_path, "meta.json")) + segments: Dict[str, List[Tensor]] = {} + for key, info in meta.items(): + arr = np.memmap( + os.path.join(file_path, f"{key}.bin"), + dtype=info["dtype"], + mode="r", + shape=tuple(info["shape"]), + ) + segments[key] = [torch.from_numpy(arr)] + return segments + + +def json_to_bin(json_path: str, bin_path: str, tokenizer=None): + segments = load_json(json_path, share_memory=False, tokenizer=tokenizer) + merged = {} + for key, tensors in segments.items(): + merged[key] = [torch.cat(tensors, dim=0)] + save_bin(bin_path, merged) + + def detect_format(load_path: str) -> str: """Auto-detect storage format from files in the directory. @@ -128,6 +161,9 @@ def detect_format(load_path: str) -> str: h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) if h5_files: return "h5" + bin_files = list(root.rglob("*.bin")) + if bin_files and (root / "meta.json").exists(): + return "bin" json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl")) if json_files: return "json" @@ -227,7 +263,7 @@ class BaseStorage(ABC): self._fetcher: Optional[MultiSegmentFetcher] = None @abstractmethod - def load(self, load_path: str, tokenizer=None) -> None: + def load(self, load_path: str, tokenizer=None): """Load data from the given path into internal fetcher.""" raise NotImplementedError @@ -272,7 +308,7 @@ class StorageFactory(BaseFactory["BaseStorage"]): """ @classmethod - def _validate_component(cls, storage_cls: type) -> None: + def _validate_component(cls, storage_cls: type): if not issubclass(storage_cls, BaseStorage): raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage") @@ -281,7 +317,7 @@ class StorageFactory(BaseFactory["BaseStorage"]): class H5Storage(BaseStorage): """HDF5-based storage backend (pre-tokenized data).""" - def load(self, load_path: str, tokenizer=None) -> None: + def load(self, load_path: str, tokenizer=None): segments = load_h5(load_path) self._fetcher = MultiSegmentFetcher(segments) @@ -296,6 +332,26 @@ class JSONStorage(BaseStorage): callable (str -> List[int]) at load time. """ - def load(self, load_path: str, tokenizer=None) -> None: + def load(self, load_path: str, tokenizer=None): segments = load_json(load_path, tokenizer=tokenizer) self._fetcher = MultiSegmentFetcher(segments) + + +@StorageFactory.register("bin") +class MmapStorage(BaseStorage): + """Memory-mapped binary storage backend. + + Each key is stored as a concatenated raw binary file (.bin) with + metadata in meta.json. Loading mmaps the files so each process + shares the same physical pages via the OS page cache — no per-process + memory duplication. + """ + + def load(self, load_path: str, tokenizer=None): + self._mmap_refs = [] + raw = load_bin(load_path) + segments = {} + for key, tensors in raw.items(): + self._mmap_refs.extend(tensors) + segments[key] = tensors + self._fetcher = MultiSegmentFetcher(segments) diff --git a/astrai/factory.py b/astrai/factory.py index 1bc3310..f0d8ccc 100644 --- a/astrai/factory.py +++ b/astrai/factory.py @@ -23,7 +23,7 @@ class Registry: component_cls: Type, category: Optional[str] = None, priority: int = 0, - ) -> None: + ): """Register a component class with optional category and priority.""" if name in self._entries: raise ValueError(f"Component '{name}' is already registered") @@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]): return component_cls(*args, **kwargs) @classmethod - def _validate_component(cls, component_cls: Type[T]) -> None: + def _validate_component(cls, component_cls: Type[T]): """Validate that the component class is valid for this factory. Override this method in subclasses to add custom validation. diff --git a/astrai/inference/core/cache.py b/astrai/inference/core/cache.py index a5d707f..0070b35 100644 --- a/astrai/inference/core/cache.py +++ b/astrai/inference/core/cache.py @@ -42,7 +42,7 @@ class Allocator: return idx return -1 - def free(self, idx: int, keep_cached: bool = False) -> None: + def free(self, idx: int, keep_cached: bool = False): with self._lock: self._refs[idx] -= 1 if self._refs[idx] == 0: @@ -51,7 +51,7 @@ class Allocator: else: self._free_mask |= 1 << idx - def inc_ref(self, idx: int) -> None: + def inc_ref(self, idx: int): with self._lock: self._refs[idx] += 1 self._lru.pop(idx, None) @@ -60,7 +60,7 @@ class Allocator: with self._lock: return self._refs[idx] - def touch(self, idx: int) -> None: + def touch(self, idx: int): with self._lock: self._lru.move_to_end(idx) @@ -74,7 +74,7 @@ class PrefixCache: self._hash_to_page: Dict[int, int] = {} self._lock = threading.Lock() - def evict(self, idx: int) -> None: + def evict(self, idx: int): with self._lock: h = self._page_to_hash.pop(idx, None) if h is not None: @@ -96,9 +96,7 @@ class PrefixCache: hits.append(p) return hits - def record( - self, page_idx: int, token_ids: List[int], logical_page_idx: int - ) -> None: + def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int): with self._lock: h = page_hash(token_ids, logical_page_idx, self._page_size) old_h = self._page_to_hash.pop(page_idx, None) @@ -127,13 +125,13 @@ class PagePool: def alloc(self) -> int: return self._alloc.alloc() - def free(self, idx: int) -> None: + def free(self, idx: int): keep = self._prefix.has_page(idx) self._alloc.free(idx, keep_cached=keep) if not keep: self._prefix.evict(idx) - def inc_ref(self, idx: int) -> None: + def inc_ref(self, idx: int): self._alloc.inc_ref(idx) def lookup(self, token_ids: List[int]) -> List[int]: @@ -142,9 +140,7 @@ class PagePool: self._alloc.touch(p) return hits - def record( - self, page_idx: int, token_ids: List[int], logical_page_idx: int - ) -> None: + def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int): self._prefix.record(page_idx, token_ids, logical_page_idx) @@ -157,7 +153,7 @@ class TaskTable: self._cached: Dict[str, int] = {} self._lock = threading.Lock() - def set(self, task_id: str, page_table: List[int], cached: int) -> None: + def set(self, task_id: str, page_table: List[int], cached: int): with self._lock: self._pages[task_id] = page_table self._cached[task_id] = cached @@ -220,7 +216,7 @@ class Storage: start_pos: int, k: Tensor, v: Tensor, - ) -> None: + ): seq_len = k.size(1) if seq_len == 0: return @@ -286,7 +282,7 @@ class KvcacheView: self._page_table = page_table self._total_len = total_len - def write(self, layer_id: int, k: Tensor, v: Tensor) -> None: + def write(self, layer_id: int, k: Tensor, v: Tensor): start_pos = self._total_len - k.size(1) self._storage.write(layer_id, self._page_table, start_pos, k, v) @@ -339,7 +335,7 @@ class KVCache: self._table.set(task_id, hits + new_pages, cached) return True - def task_free(self, task_id: str) -> None: + def task_free(self, task_id: str): page_table, _ = self._table.pop(task_id) for idx in page_table: self._pool.free(idx) @@ -359,7 +355,7 @@ class KVCache: def task_record_hashes( self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0 - ) -> None: + ): page_table = self._table.get(task_id) full_pages = len(prompt_ids) // self.page_size for i in range(start_logical_page, full_pages): diff --git a/astrai/inference/core/executor.py b/astrai/inference/core/executor.py index 3eebf81..9157cfe 100644 --- a/astrai/inference/core/executor.py +++ b/astrai/inference/core/executor.py @@ -29,9 +29,7 @@ class Executor: self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype - def execute_prefill( - self, tasks: List[Task], prompt_len: int, start_pos: int = 0 - ) -> None: + def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0): if start_pos >= prompt_len: return diff --git a/astrai/inference/core/scheduler.py b/astrai/inference/core/scheduler.py index 4ac63ce..3e76f77 100644 --- a/astrai/inference/core/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -75,14 +75,14 @@ class InferenceScheduler: def add_task(self, prompt: str, **kwargs) -> str: return self._task_mgr.add_task(prompt, **kwargs) - def remove_task(self, task_id: str) -> None: + def remove_task(self, task_id: str): for task in self._task_mgr.remove_task(task_id): self._page_cache.task_free(task.task_id) def get_stats(self) -> Dict[str, Any]: return self._task_mgr.get_stats() - def _run_generation_loop(self) -> None: + def _run_generation_loop(self): stop_ids = self._task_mgr.tokenizer.stop_ids try: while self._running: @@ -186,14 +186,14 @@ class InferenceScheduler: self._task_mgr.clear_queues() raise - def start(self) -> None: + def start(self): if not self._running: self._running = True t = threading.Thread(target=self._run_generation_loop, daemon=True) t.start() self._loop_thread = t - def stop(self) -> None: + def stop(self): self._running = False self._task_mgr.wake() if hasattr(self, "_loop_thread"): diff --git a/astrai/inference/core/task.py b/astrai/inference/core/task.py index 40e0da8..1b449c8 100644 --- a/astrai/inference/core/task.py +++ b/astrai/inference/core/task.py @@ -172,12 +172,12 @@ class TaskManager: to_add.append(self.waiting_queue.popleft()) return to_add - def activate(self, task: Task) -> None: + def activate(self, task: Task): task.status = TaskStatus.RUNNING with self._lock: self.active_tasks.append(task) - def return_to_waiting(self, tasks: List[Task]) -> None: + def return_to_waiting(self, tasks: List[Task]): with self._lock: for task in reversed(tasks): self.waiting_queue.appendleft(task) @@ -185,7 +185,7 @@ class TaskManager: def has_work(self) -> bool: return bool(self.active_tasks or self.waiting_queue) - def wait_for_tasks(self, timeout: float = 1.0) -> None: + def wait_for_tasks(self, timeout: float = 1.0): self._task_event.clear() self._task_event.wait(timeout=timeout) @@ -197,10 +197,10 @@ class TaskManager: with self._lock: return list(self.waiting_queue) - def clear_queues(self) -> None: + def clear_queues(self): with self._lock: self.waiting_queue.clear() self.active_tasks.clear() - def wake(self) -> None: + def wake(self): self._task_event.set() diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 2fb0343..63b28ed 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -48,7 +48,7 @@ class GenerateResult: def wait(self, timeout: Optional[float] = None) -> bool: return self._event.wait(timeout=timeout) - def wait_completion(self, timeout: float = 300.0) -> None: + def wait_completion(self, timeout: float = 300.0): with self._cond: if not self._cond.wait_for( lambda: self._completed >= self._total, timeout=timeout @@ -281,7 +281,7 @@ class InferenceEngine: def get_stats(self) -> Dict[str, Any]: return self.scheduler.get_stats() - def shutdown(self) -> None: + def shutdown(self): self.scheduler.stop() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index 650b94a..5459314 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -83,7 +83,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module): def save_pretrained( self, save_directory: Union[str, Path], - ) -> None: + ): save_model( config=self.config.to_dict(), state_dict=self.state_dict(), diff --git a/astrai/parallel/executor.py b/astrai/parallel/executor.py index ae12f9a..c1f2141 100644 --- a/astrai/parallel/executor.py +++ b/astrai/parallel/executor.py @@ -203,9 +203,45 @@ class DDPExecutor(BaseExecutor): @ExecutorFactory.register("fsdp") class FSDPExecutor(BaseExecutor): - def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs): + def __init__( + self, + grad_accum_steps: int = 1, + process_group=None, + sharding_strategy=None, + cpu_offload=None, + auto_wrap_policy=None, + backward_prefetch=None, + mixed_precision=None, + ignored_modules=None, + param_init_fn=None, + sync_module_states: bool = False, + forward_prefetch: bool = False, + limit_all_gathers: bool = True, + use_orig_params: bool = False, + ignored_states=None, + device_mesh=None, + ): super().__init__(grad_accum_steps=grad_accum_steps) - self._fsdp_kwargs = fsdp_kwargs + self._fsdp_kwargs = { + k: v + for k, v in dict( + process_group=process_group, + sharding_strategy=sharding_strategy, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + ignored_modules=ignored_modules, + param_init_fn=param_init_fn, + sync_module_states=sync_module_states, + forward_prefetch=forward_prefetch, + limit_all_gathers=limit_all_gathers, + use_orig_params=use_orig_params, + ignored_states=ignored_states, + device_mesh=device_mesh, + ).items() + if v is not None + } self._original_model: Optional[nn.Module] = None def _prepare_model(self, model: nn.Module) -> nn.Module: diff --git a/astrai/serialization.py b/astrai/serialization.py index a21d55b..2243d28 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -16,7 +16,7 @@ _CONFIG_FILE = "config.json" _WEIGHTS_FILE = "model.safetensors" -def save_safetensors(state_dict: dict, path: str | Path) -> None: +def save_safetensors(state_dict: dict, path: str | Path): st.save_file(state_dict, str(path)) @@ -24,7 +24,7 @@ def load_safetensors(path: str | Path) -> dict: return st.load_file(str(path)) -def save_json(data: dict, path: str | Path) -> None: +def save_json(data: dict, path: str | Path): with open(str(path), "w") as f: json.dump(data, f, indent=2) @@ -34,7 +34,7 @@ def load_json(path: str | Path) -> dict: return json.load(f) -def save_torch(obj: Any, path: str | Path) -> None: +def save_torch(obj: Any, path: str | Path): torch.save(obj, str(path)) @@ -64,7 +64,7 @@ def load_torch(path: str | Path, broadcast: bool = False) -> Any: return torch.load(buf, map_location="cpu", weights_only=False) -def save_model(config: dict, state_dict: dict, save_directory: str) -> None: +def save_model(config: dict, state_dict: dict, save_directory: str): save_path = Path(save_directory) save_path.mkdir(parents=True, exist_ok=True) save_json(config, save_path / _CONFIG_FILE) @@ -129,7 +129,7 @@ class Checkpoint: extra: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict) - def save(self, save_dir: str) -> None: + def save(self, save_dir: str): save_path = Path(save_dir) save_path.mkdir(parents=True, exist_ok=True) diff --git a/astrai/trainer/schedule.py b/astrai/trainer/schedule.py index 9727bf0..f4810ab 100644 --- a/astrai/trainer/schedule.py +++ b/astrai/trainer/schedule.py @@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]): """ @classmethod - def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None: + def _validate_component(cls, scheduler_cls: Type[BaseScheduler]): """Validate that the scheduler class inherits from BaseScheduler.""" if not issubclass(scheduler_cls, BaseScheduler): raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index e340691..37ee0a8 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]): """ @classmethod - def _validate_component(cls, strategy_cls: type) -> None: + def _validate_component(cls, strategy_cls: type): """Validate that the strategy class inherits from BaseStrategy.""" if not issubclass(strategy_cls, BaseStrategy): raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")