refactor : pipeline 策略化拆分,消除 _flush if/else

- PackingStrategy / PositionIdStrategy / StoreWriter 独立文件 + Factory
- Pipeline._flush 零 if/else,纯编排
- SectionRenderer 从 SectionedMaskBuilder 分离
- OutputConfig.position_ids_mode 默认改为 ""none""
This commit is contained in:
ViperEkura 2026-06-06 00:45:33 +08:00
parent 3057741de9
commit 31bc7f5c2a
7 changed files with 424 additions and 306 deletions

View File

@ -87,7 +87,7 @@ class OutputConfig(BaseConfig):
position_ids_mode : Optional[str]
How to compute position_ids in packed sequences.
- ``None`` / ``"none"``: do not generate (backward compatible).
- ``"none"``: do not generate (default).
- ``"doc_reset"``: reset to 0 at each document boundary.
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
"""
@ -96,7 +96,7 @@ class OutputConfig(BaseConfig):
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
position_ids_mode: Optional[str] = None
position_ids_mode: str = "none"
@dataclass

View File

@ -3,12 +3,30 @@ from astrai.preprocessing.builder import (
MaskBuilderFactory,
SectionedMaskBuilder,
)
from astrai.preprocessing.packing import (
PackingStrategy,
PackingStrategyFactory,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
from astrai.preprocessing.position_id import (
PositionIdStrategy,
PositionIdStrategyFactory,
)
from astrai.preprocessing.writer import (
StoreWriter,
StoreWriterFactory,
)
__all__ = [
"BaseMaskBuilder",
"MaskBuilderFactory",
"SectionedMaskBuilder",
"PackingStrategy",
"PackingStrategyFactory",
"Pipeline",
"PositionIdStrategy",
"PositionIdStrategyFactory",
"SectionedMaskBuilder",
"StoreWriter",
"StoreWriterFactory",
"filter_by_length",
]

View File

@ -1,8 +1,8 @@
"""Mask building strategies for preprocessing pipeline.
"""Mask building for preprocessing pipeline.
The single :class:`SectionedMaskBuilder` handles all input formats
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
for single-output or ``input.sources`` for multi-output.
:class:`SectionRenderer` converts section specs into token ids and loss
masks (template / text / value extraction). :class:`SectionedMaskBuilder`
orchestrates single-output / multi-output (DPO / GRPO) assembly.
"""
from abc import ABC, abstractmethod
@ -11,27 +11,6 @@ from typing import Optional
from astrai.factory import BaseFactory
class BaseMaskBuilder(ABC):
"""Convert a JSONL item into token ids and optional loss_mask."""
@abstractmethod
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
Returns ``None`` to skip the item entirely.
"""
...
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
if not domain_key:
return "__default__"
@ -40,142 +19,15 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
def _resolve_action(action: str, role: str, config) -> str:
"""Resolve action to "train" or "mask".
- ``"train"`` / ``"mask"`` literal
- ``"$role"`` look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
"""
if action == "$role":
return config.mask.get(role, config.mask_default)
return action
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder supporting single and multi-output modes.
class SectionRenderer:
"""Render section specs into ``(ids, loss_mask)`` tuples."""
Single-output (backward-compatible)::
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
{"sequence": [...], "loss_mask": [...], "domain": "..."}
Multi-output (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [
{"field": "chosen", "action": "$role", "template": true}
]},
"rejected": {"sections": [
{"field": "rejected", "action": "$role", "template": true}
]}
}}}
{"chosen": [...], "chosen_mask": [...],
"rejected": [...], "rejected_mask": [...], "domain": "..."}
Output spec fields::
sections list of section specs (same format as single-output)
list_field True when the JSONL field holds a list of values to
tokenise individually and concatenate (GRPO responses)
mask_key explicit output key for the loss mask
(default: ``"{output_key}_mask"``)
dtype explicit tensor dtype for this output key
(default: "int32")
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
sources_spec = getattr(config.input, "sources", None)
if sources_spec:
return self._build_multi(item, sources_spec, config, tokenizer)
return self._build_single(item, config, tokenizer)
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
sections = config.input.sections
if not sections:
return None
ids, mask = self._process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
return None
result: dict = {
"sequence": ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in mask):
result["loss_mask"] = mask
return result
def _build_multi(
self, item: dict, sources_spec: dict, config, tokenizer
) -> Optional[dict]:
result: dict = {}
any_output = False
for output_key, spec in sources_spec.items():
sections = spec.get("sections", [])
if not sections:
continue
if self._is_value_section(sections):
ids = self._extract_raw_value(item, sections)
if ids is None:
continue
result[output_key] = ids
any_output = True
continue
list_field = spec.get("list_field", False)
mask_key = spec.get("mask_key", f"{output_key}_mask")
if list_field:
ids, mask = self._process_list_field(item, sections, config, tokenizer)
else:
ids, mask = self._process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
continue
result[output_key] = ids
if not all(m == 1 for m in mask):
result[mask_key] = mask
elif "mask_key" in spec:
result[mask_key] = mask
any_output = True
if not any_output:
return None
result["domain"] = _extract_domain(item, config.output.domain_key)
return result
@staticmethod
def _is_value_section(sections: list) -> bool:
return len(sections) == 1 and sections[0].get("action") == "value"
@staticmethod
def _extract_raw_value(item: dict, sections: list):
"""Extract a raw value from a JSONL field without tokenisation.
Used for GRPO rewards where the field contains float values.
"""
sec = sections[0]
field = sec["field"]
raw = item.get(field)
if raw is None:
return None
if isinstance(raw, list):
return [float(v) for v in raw]
return [float(raw)]
def _process_sections(
def process_sections(
self,
item: dict,
sections: list,
@ -184,10 +36,6 @@ class SectionedMaskBuilder(BaseMaskBuilder):
*,
is_top_level: bool = False,
):
"""Process a list of sections into ``(ids, loss_mask)``.
Returns ``(None, None)`` if the item should be skipped.
"""
all_ids: list[int] = []
loss_mask: list[int] = []
@ -210,13 +58,13 @@ class SectionedMaskBuilder(BaseMaskBuilder):
)
if use_template:
success = self._append_template_section(
success = self._append_template(
item, field, action, tokenizer, config, all_ids, loss_mask
)
if not success:
continue
else:
success = self._append_text_section(
success = self._append_text(
item,
field,
action,
@ -244,7 +92,70 @@ class SectionedMaskBuilder(BaseMaskBuilder):
return all_ids, loss_mask
def _append_template_section(
def process_list_field(self, item: dict, sections: list, config, tokenizer):
all_ids: list[int] = []
loss_mask: list[int] = []
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
values = item.get(field)
if not isinstance(values, list):
continue
for val in values:
if use_template:
if isinstance(val, list):
wrapper = {field: val}
self._append_template(
wrapper,
field,
action,
tokenizer,
config,
all_ids,
loss_mask,
)
else:
wrapper = {field: str(val)}
self._append_text(
wrapper,
field,
action,
tokenizer,
False,
False,
config,
all_ids,
loss_mask,
)
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None, None
return all_ids, loss_mask
@staticmethod
def is_value_section(sections: list) -> bool:
return len(sections) == 1 and sections[0].get("action") == "value"
@staticmethod
def extract_raw_value(item: dict, sections: list):
sec = sections[0]
field = sec["field"]
raw = item.get(field)
if raw is None:
return None
if isinstance(raw, list):
return [float(v) for v in raw]
return [float(raw)]
def _append_template(
self, item, field, action, tokenizer, config, all_ids, loss_mask
):
messages = item.get(field)
@ -262,7 +173,7 @@ class SectionedMaskBuilder(BaseMaskBuilder):
loss_mask.extend([val] * len(ids))
return True
def _append_text_section(
def _append_text(
self,
item,
field,
@ -289,50 +200,121 @@ class SectionedMaskBuilder(BaseMaskBuilder):
loss_mask.extend([val] * len(ids))
return True
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
all_ids: list[int] = []
loss_mask: list[int] = []
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
class BaseMaskBuilder(ABC):
"""Convert a JSONL item into token ids and optional loss_mask."""
values = item.get(field)
if not isinstance(values, list):
@abstractmethod
def build(self, item: dict, config, tokenizer) -> Optional[dict]: ...
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder supporting single and multi-output modes.
Single-output::
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
{"sequence": [...], "loss_mask": [...], "domain": "..."}
Multi-output (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [{"field": "chosen", "action": "$role", "template": true}]},
"rejected": {"sections": [{"field": "rejected", "action": "$role", "template": true}]},
}}}
{"chosen": [...], "chosen_mask": [...], "rejected": [...], "rejected_mask": [...], "domain": "..."}
Output spec fields::
sections list of section specs (same format as single-output)
list_field True when JSONL field holds a list (GRPO responses)
mask_key explicit loss-mask output key (default: ``"{output_key}_mask"``)
"""
def __init__(self):
self.renderer = SectionRenderer()
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
sources_spec = getattr(config.input, "sources", None)
if sources_spec:
return self._build_multi(item, sources_spec, config, tokenizer)
return self._build_single(item, config, tokenizer)
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
sections = config.input.sections
if not sections:
return None
ids, mask = self.renderer.process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
return None
result: dict = {
"sequence": ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in mask):
result["loss_mask"] = mask
return result
def _build_multi(
self, item: dict, sources_spec: dict, config, tokenizer
) -> Optional[dict]:
result: dict = {}
any_output = False
for output_key, spec in sources_spec.items():
sections = spec.get("sections", [])
if not sections:
continue
for val in values:
if use_template:
if isinstance(val, list):
wrapper = {field: val}
self._append_template_section(
wrapper,
field,
action,
tokenizer,
config,
all_ids,
loss_mask,
if self.renderer.is_value_section(sections):
ids = self.renderer.extract_raw_value(item, sections)
if ids is None:
continue
result[output_key] = ids
any_output = True
continue
list_field = spec.get("list_field", False)
mask_key = spec.get("mask_key", f"{output_key}_mask")
if list_field:
ids, mask = self.renderer.process_list_field(
item, sections, config, tokenizer
)
else:
wrapper = {field: str(val)}
self._append_text_section(
wrapper,
field,
action,
tokenizer,
False,
False,
config,
all_ids,
loss_mask,
ids, mask = self.renderer.process_sections(
item, sections, config, tokenizer, is_top_level=True
)
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if ids is None:
continue
if not all_ids:
return None, None
return all_ids, loss_mask
result[output_key] = ids
if not all(m == 1 for m in mask):
result[mask_key] = mask
elif "mask_key" in spec:
result[mask_key] = mask
any_output = True
if not any_output:
return None
result["domain"] = _extract_domain(item, config.output.domain_key)
return result

View File

@ -0,0 +1,95 @@
"""Sequence packing strategies for shard-level reordering and truncation.
Each strategy receives the accumulated ``{key: [list of token lists]}``
dict for a shard and returns a reordered / truncated version. The
pipeline later flattens the result into contiguous tensors.
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Tuple
from astrai.factory import BaseFactory
def _truncate(seq: list, max_len: int, mode: str) -> list:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
class PackingStrategy(ABC):
"""Reorder and truncate sequences within a shard."""
@abstractmethod
def apply(
self,
keys: Dict[str, List[list]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[list]]: ...
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, PackingStrategy):
raise TypeError(
f"{component_cls.__name__} must inherit from PackingStrategy"
)
@PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy):
def apply(self, keys, max_packed_len, truncation_mode):
return {
k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
for k, vals in keys.items()
}
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
def apply(self, keys, max_packed_len, truncation_mode):
sequences = keys.get("sequence", [])
if not sequences:
return keys
plan = self._plan(sequences, max_packed_len)
reordered: dict = defaultdict(list)
for orig_idx, _ in plan:
for k, vals in keys.items():
reordered[k].append(
_truncate(vals[orig_idx], max_packed_len, truncation_mode)
)
return dict(reordered)
@staticmethod
def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]:
n = len(sequences)
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []
bin_lengths: List[int] = []
for orig_idx in order:
seq_len = min(len(sequences[orig_idx]), max_packed_len)
best_bin = None
best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths):
remain = max_packed_len - bl
if seq_len <= remain < best_remain:
best_remain = remain
best_bin = i
if best_bin is not None:
bins[best_bin].append(orig_idx)
bin_lengths[best_bin] += seq_len
else:
bins.append([orig_idx])
bin_lengths.append(seq_len)
plan: List[Tuple[int, int]] = []
for bin_indices in bins:
for orig_idx in bin_indices:
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
return plan

View File

@ -1,21 +1,25 @@
"""Config-driven JSONL preprocessing pipeline.
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
sharding and flush to ``.h5`` / ``.bin`` storage.
sharding and flush to ``.h5`` / ``.bin`` storage. Packing, position-id
generation and storage writing are each delegated to pluggable strategies,
dispatched by configuration keys.
"""
import json
import os
from collections import defaultdict
from itertools import chain
from typing import List, Optional, Tuple
from typing import Dict, List, Optional
import torch
import tqdm
from astrai.config.preprocess_config import PipelineConfig
from astrai.dataset.storage import save_bin, save_h5
from astrai.preprocessing.builder import SectionedMaskBuilder
from astrai.preprocessing.builder import MaskBuilderFactory
from astrai.preprocessing.packing import PackingStrategyFactory
from astrai.preprocessing.position_id import PositionIdStrategyFactory
from astrai.preprocessing.writer import StoreWriterFactory
from astrai.tokenize import AutoTokenizer
_STR_TO_DTYPE: dict[str, torch.dtype] = {
@ -35,65 +39,6 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
return min_len <= len(text) <= max_len
def _truncate(seq: list, max_len: int, mode: str) -> list:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
def pack_sequences(
sequences: List[list],
max_packed_len: int,
strategy: str,
truncation_mode: str,
) -> List[Tuple[int, int]]:
"""Pack *sequences* into bins and return a reorder plan.
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
All keys (sequence, loss_mask, ) must be reordered and truncated
identically according to this plan.
Supported *strategy* values:
- ``"simple"``: sequential, no reordering.
- ``"bfd"``: best-fit decreasing bin packing.
"""
n = len(sequences)
if strategy == "simple":
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []
bin_lengths: List[int] = []
for orig_idx in order:
seq_len = min(len(sequences[orig_idx]), max_packed_len)
best_bin = None
best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths):
remain = max_packed_len - bl
if seq_len <= remain < best_remain:
best_remain = remain
best_bin = i
if best_bin is not None:
bins[best_bin].append(orig_idx)
bin_lengths[best_bin] += seq_len
else:
bins.append([orig_idx])
bin_lengths.append(seq_len)
plan: List[Tuple[int, int]] = []
for bin_indices in bins:
for orig_idx in bin_indices:
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
return plan
class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
@ -116,7 +61,14 @@ class Pipeline:
self.output_dir = output_dir
self.tokenizer_path = tokenizer_path
self.mask_builder = SectionedMaskBuilder()
self.mask_builder = MaskBuilderFactory.create("sectioned")
self._packer = PackingStrategyFactory.create(
config.preprocessing.packing_strategy
)
self._position_id = PositionIdStrategyFactory.create(
config.output.position_ids_mode
)
self._writer = StoreWriterFactory.create(config.output.storage_format)
def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer)
@ -168,8 +120,6 @@ class Pipeline:
if total_tokens > 0:
self._flush(domains, shard_idx)
print(f"Done. {count} documents tokenized.")
@staticmethod
def _primary_ids(result: dict) -> list:
"""Return the first list-valued entry in *result* as the primary id
@ -203,27 +153,11 @@ class Pipeline:
def _flush(self, domains, shard_idx):
for domain, keys in domains.items():
idx = shard_idx[domain]
chunk_dir = os.path.join(self.output_dir, domain)
pp = self.config.preprocessing
if pp.packing_strategy != "simple" and "sequence" in keys:
plan = pack_sequences(
keys["sequence"],
pp.max_packed_len,
pp.packing_strategy,
pp.truncation_mode,
)
reordered = defaultdict(list)
for orig_idx, truncated_len in plan:
for k, vals in keys.items():
reordered[k].append(
_truncate(
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
)
)
keys = reordered
keys = self._packer.apply(dict(keys), pp.max_packed_len, pp.truncation_mode)
tensors = {}
tensors: Dict[str, List[torch.Tensor]] = {}
for key, ids_list in keys.items():
dt = _STR_TO_DTYPE.get(
self.config.output.dtype.get(key, "int32"), torch.int32
@ -232,24 +166,13 @@ class Pipeline:
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
]
pid_mode = self.config.output.position_ids_mode
if pid_mode and pid_mode != "none" and "sequence" in tensors:
pos_ids = []
if pid_mode == "doc_reset":
for item in keys["sequence"]:
pos_ids.extend(range(len(item)))
else:
total = sum(len(item) for item in keys["sequence"])
pos_ids = list(range(total))
pos_ids = self._position_id.generate(keys.get("sequence", []))
if pos_ids:
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
fmt = self.config.output.storage_format
if fmt == "bin":
save_bin(shard_path, tensors)
else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
self._writer.save(self.output_dir, domain, idx, tensors)
shard_idx[domain] = idx + 1
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
tqdm.tqdm.write(
f" saved {domain}/shard_{idx:04d} "

View File

@ -0,0 +1,50 @@
"""Position-id generation strategies for packed sequences.
Each strategy takes the list of per-document token sequences after packing
and returns a flat list of position ids (same total length as all
sequences combined). The pipeline wraps the result into a tensor and
attaches it as ``position_ids``.
"""
from abc import ABC, abstractmethod
from typing import List
from astrai.factory import BaseFactory
class PositionIdStrategy(ABC):
"""Generate ``position_ids`` for packed sequences."""
@abstractmethod
def generate(self, sequences: List[list]) -> List[int]: ...
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, PositionIdStrategy):
raise TypeError(
f"{component_cls.__name__} must inherit from PositionIdStrategy"
)
@PositionIdStrategyFactory.register("none")
class NoPositionId(PositionIdStrategy):
def generate(self, sequences):
return []
@PositionIdStrategyFactory.register("doc_reset")
class DocResetPositionId(PositionIdStrategy):
def generate(self, sequences):
pos_ids = []
for seq in sequences:
pos_ids.extend(range(len(seq)))
return pos_ids
@PositionIdStrategyFactory.register("continuous")
class ContinuousPositionId(PositionIdStrategy):
def generate(self, sequences):
total = sum(len(seq) for seq in sequences)
return list(range(total))

View File

@ -0,0 +1,50 @@
"""Storage writer strategies for pipeline output.
The :class:`StoreWriter` abstraction decouples the pipeline from the
concrete storage format (bin / h5). The pipeline builds a ``{key:
List[Tensor]}`` dict and delegates the write to the writer selected
by ``output.storage_format``.
"""
import os
from abc import ABC, abstractmethod
from typing import Dict, List
import torch
from astrai.dataset.storage import save_bin, save_h5
from astrai.factory import BaseFactory
class StoreWriter(ABC):
"""Write pre-tokenized tensors to disk in a format-specific way."""
@abstractmethod
def save(
self,
output_dir: str,
domain: str,
shard_idx: int,
tensors: Dict[str, List[torch.Tensor]],
) -> None: ...
class StoreWriterFactory(BaseFactory["StoreWriter"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, StoreWriter):
raise TypeError(f"{component_cls.__name__} must inherit from StoreWriter")
@StoreWriterFactory.register("bin")
class BinWriter(StoreWriter):
def save(self, output_dir, domain, shard_idx, tensors):
shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
save_bin(shard_path, tensors)
@StoreWriterFactory.register("h5")
class H5Writer(StoreWriter):
def save(self, output_dir, domain, shard_idx, tensors):
chunk_dir = os.path.join(output_dir, domain)
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)