AstrAI/astrai/preprocessing/writer.py

76 lines
2.1 KiB
Python

"""Storage writer strategies for pipeline output.
The :class:`StoreWriter` abstraction decouples the pipeline from the
concrete storage format (bin / h5). The pipeline builds a ``{key:
List[Tensor]}`` dict and delegates the write to the writer selected
by ``output.storage_format``.
"""
import logging
import os
import shutil
from abc import ABC, abstractmethod
from typing import Dict, List
import torch
from astrai.dataset.storage import save_bin, save_h5
from astrai.factory import BaseFactory
# Module-level logger; the writers in this module report failed shard writes through it.
logger = logging.getLogger(__name__)
class StoreWriter(ABC):
    """Abstract strategy that persists one shard of pre-tokenized tensors.

    Each concrete subclass owns the on-disk layout (directory structure,
    file names, container format) for a single storage format.
    """

    @abstractmethod
    def save(
        self,
        output_dir: str,
        domain: str,
        shard_idx: int,
        tensors: Dict[str, List[torch.Tensor]],
    ) -> None:
        """Persist ``tensors`` as shard number ``shard_idx`` of ``domain``.

        Args:
            output_dir: Root directory of the pipeline output.
            domain: Domain name; the writers in this module use it as a
                sub-directory of ``output_dir``.
            shard_idx: Index of the shard being written.
            tensors: Mapping from a key to the list of tensors stored under it.
        """
class StoreWriterFactory(BaseFactory["StoreWriter"]):
    """Registry of :class:`StoreWriter` implementations keyed by storage-format name.

    Writers register themselves with ``@StoreWriterFactory.register("<format>")``;
    this module registers ``"bin"`` and ``"h5"``.
    """
@StoreWriterFactory.register("bin")
class BinWriter(StoreWriter):
    """Write each shard to ``<output_dir>/<domain>/shard_<NNNN>`` using ``save_bin``."""

    def save(self, output_dir, domain, shard_idx, tensors):
        """Write ``tensors`` as bin shard ``shard_idx`` of ``domain``.

        If ``save_bin`` fails for any reason, including an interrupt such as
        ``KeyboardInterrupt``, whatever exists at the shard path is removed
        on a best-effort basis. The failure is logged and the original
        exception is re-raised unchanged.

        Args:
            output_dir: Root output directory.
            domain: Sub-directory of ``output_dir`` that receives the shard.
            shard_idx: Shard index, zero-padded to 4 digits in the path.
            tensors: Mapping passed through to ``save_bin``.
        """
        shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
        try:
            save_bin(shard_path, tensors)
        except BaseException:
            # BaseException (not just Exception) so that an interrupt during a
            # long write still removes the truncated shard; we always re-raise.
            # presumably save_bin produces a directory here (hence rmtree) — TODO confirm.
            # ignore_errors=True makes this a no-op when nothing was created and
            # keeps a cleanup failure from masking the original exception.
            # Note: this removes anything at shard_path, including a pre-existing shard.
            shutil.rmtree(shard_path, ignore_errors=True)
            logger.error(
                "Failed to write shard %s/shard_%04d, cleaned up partial output",
                domain,
                shard_idx,
                exc_info=True,
            )
            raise
@StoreWriterFactory.register("h5")
class H5Writer(StoreWriter):
    """Write each shard to ``<output_dir>/<domain>/data_<NNNN>.h5`` using ``save_h5``."""

    def save(self, output_dir, domain, shard_idx, tensors):
        """Write ``tensors`` as HDF5 shard ``shard_idx`` of ``domain``.

        If ``save_h5`` fails for any reason, including an interrupt such as
        ``KeyboardInterrupt``, the partially written file is removed on a
        best-effort basis. The failure is logged and the original exception
        is re-raised. A failure to remove the file is logged as a warning
        and never replaces the original exception.

        Args:
            output_dir: Root output directory.
            domain: Sub-directory of ``output_dir`` that receives the file.
            shard_idx: Shard index, zero-padded to 4 digits in the file name.
            tensors: Mapping passed through to ``save_h5``.
        """
        chunk_dir = os.path.join(output_dir, domain)
        stem = f"data_{shard_idx:04d}"
        # assumes save_h5 writes "<chunk_dir>/<stem>.h5" — TODO confirm against save_h5
        file_path = os.path.join(chunk_dir, f"{stem}.h5")
        try:
            save_h5(chunk_dir, stem, tensors)
        except BaseException:
            # BaseException so an interrupt mid-write still removes the partial
            # file; the exception is always re-raised below.
            self._remove_partial_file(file_path)
            logger.error(
                "Failed to write shard %s/data_%04d.h5, cleaned up partial output",
                domain,
                shard_idx,
                exc_info=True,
            )
            raise

    @staticmethod
    def _remove_partial_file(file_path):
        """Delete ``file_path`` if present; log instead of raising on OSError."""
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass  # nothing was written before the failure
        except OSError:
            # Swallowed on purpose: raising here would hide the original
            # save error that the caller is about to re-raise.
            logger.warning(
                "Could not remove partial output %s", file_path, exc_info=True
            )