feat : 数据流水拼接策略支持 position_ids 预计算

- OutputConfig.position_ids_mode 三种模式控制边界策略
- pipeline._flush() 按配置生成扁平 position_ids 数组
- SFTDataset 在 __getitem__ 中返回 position_ids
- SFTStrategy 将 position_ids 传入 model.forward()
This commit is contained in:
ViperEkura 2026-06-04 13:56:19 +08:00
parent 5e73ca20aa
commit 985d940db6
4 changed files with 36 additions and 13 deletions

View File

@ -45,6 +45,13 @@ class OutputConfig(BaseConfig):
storage_format: str = "bin" storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000 max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict) dtype: Dict[str, str] = field(default_factory=dict)
position_ids_mode: Optional[str] = None
"""How to compute position_ids in packed sequences.
- ``None`` / ``"none"``: do not generate (backward compatible).
- ``"doc_reset"``: reset to 0 at each document boundary.
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
"""
@dataclass @dataclass

View File

@ -223,7 +223,7 @@ class SFTDataset(BaseDataset):
@property @property
def required_keys(self) -> List[str]: def required_keys(self) -> List[str]:
return ["sequence", "loss_mask"] return ["sequence", "loss_mask", "position_ids"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key) return self.storage.fetch(begin_idx, end_idx, key)
@ -231,15 +231,17 @@ class SFTDataset(BaseDataset):
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx, "sequence")
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to( y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
dtype=torch.long position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
) loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask} return {
"input_ids": x.to(dtype=torch.long),
"target_ids": y.to(dtype=torch.long),
"position_ids": position_ids.to(dtype=torch.long),
"loss_mask": loss_mask.to(dtype=torch.bool),
}
@DatasetFactory.register("dpo") @DatasetFactory.register("dpo")

View File

@ -144,6 +144,7 @@ class Pipeline:
def _flush(self, domains, shard_idx): def _flush(self, domains, shard_idx):
for domain, keys in domains.items(): for domain, keys in domains.items():
idx = shard_idx[domain] idx = shard_idx[domain]
chunk_dir = os.path.join(self.output_dir, domain)
tensors = {} tensors = {}
for key, ids_list in keys.items(): for key, ids_list in keys.items():
dt = _STR_TO_DTYPE.get( dt = _STR_TO_DTYPE.get(
@ -152,10 +153,22 @@ class Pipeline:
tensors[key] = [ tensors[key] = [
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt) torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
] ]
chunk_dir = os.path.join(self.output_dir, domain)
pid_mode = self.config.output.position_ids_mode
if pid_mode and pid_mode != "none" and "sequence" in tensors:
pos_ids = []
if pid_mode == "doc_reset":
for item in keys["sequence"]:
pos_ids.extend(range(len(item)))
else:
total = sum(len(item) for item in keys["sequence"])
pos_ids = list(range(total))
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
fmt = self.config.output.storage_format fmt = self.config.output.storage_format
if fmt == "bin": if fmt == "bin":
save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors) save_bin(shard_path, tensors)
else: else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors) save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1 shard_idx[domain] = idx + 1

View File

@ -180,14 +180,15 @@ class SFTStrategy(BaseStrategy):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
input_ids, target_ids, loss_mask = ( input_ids, target_ids, position_ids, loss_mask = (
batch["input_ids"], batch["input_ids"],
batch["target_ids"], batch["target_ids"],
batch["position_ids"],
batch["loss_mask"], batch["loss_mask"],
) )
ignore_index = -100 ignore_index = -100
logits = self.model(input_ids=input_ids)["logits"] logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy( loss = F.cross_entropy(