76 lines
2.1 KiB
Python
76 lines
2.1 KiB
Python
"""Storage writer strategies for pipeline output.
|
|
|
|
The :class:`StoreWriter` abstraction decouples the pipeline from the
|
|
concrete storage format (bin / h5). The pipeline builds a ``{key:
|
|
List[Tensor]}`` dict and delegates the write to the writer selected
|
|
by ``output.storage_format``.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import shutil
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List
|
|
|
|
import torch
|
|
|
|
from astrai.dataset.storage import save_bin, save_h5
|
|
from astrai.factory import BaseFactory
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class StoreWriter(ABC):
|
|
"""Write pre-tokenized tensors to disk in a format-specific way."""
|
|
|
|
@abstractmethod
|
|
def save(
|
|
self,
|
|
output_dir: str,
|
|
domain: str,
|
|
shard_idx: int,
|
|
tensors: Dict[str, List[torch.Tensor]],
|
|
) -> None: ...
|
|
|
|
|
|
class StoreWriterFactory(BaseFactory["StoreWriter"]):
|
|
pass
|
|
|
|
|
|
@StoreWriterFactory.register("bin")
|
|
class BinWriter(StoreWriter):
|
|
def save(self, output_dir, domain, shard_idx, tensors):
|
|
shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
|
|
try:
|
|
save_bin(shard_path, tensors)
|
|
except Exception:
|
|
if os.path.exists(shard_path):
|
|
shutil.rmtree(shard_path, ignore_errors=True)
|
|
logger.error(
|
|
"Failed to write shard %s/%s_%04d, cleaned up partial output",
|
|
domain,
|
|
"shard",
|
|
shard_idx,
|
|
exc_info=True,
|
|
)
|
|
raise
|
|
|
|
|
|
@StoreWriterFactory.register("h5")
|
|
class H5Writer(StoreWriter):
|
|
def save(self, output_dir, domain, shard_idx, tensors):
|
|
chunk_dir = os.path.join(output_dir, domain)
|
|
file_path = os.path.join(chunk_dir, f"data_{shard_idx:04d}.h5")
|
|
try:
|
|
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)
|
|
except Exception:
|
|
if os.path.exists(file_path):
|
|
os.remove(file_path)
|
|
logger.error(
|
|
"Failed to write shard %s/data_%04d.h5, cleaned up partial output",
|
|
domain,
|
|
shard_idx,
|
|
exc_info=True,
|
|
)
|
|
raise
|