feat : 数据流水拼接策略支持 position_ids 预计算
- OutputConfig.position_ids_mode 三种模式控制边界策略 - pipeline._flush() 按配置生成扁平 position_ids 数组 - SFTDataset 在 __getitem__ 中返回 position_ids - SFTStrategy 将 position_ids 传入 model.forward()
This commit is contained in:
parent
5e73ca20aa
commit
985d940db6
|
|
@ -45,6 +45,13 @@ class OutputConfig(BaseConfig):
|
|||
storage_format: str = "bin"
|
||||
max_tokens_per_shard: int = 100_000_000
|
||||
dtype: Dict[str, str] = field(default_factory=dict)
|
||||
position_ids_mode: Optional[str] = None
|
||||
"""How to compute position_ids in packed sequences.
|
||||
|
||||
- ``None`` / ``"none"``: do not generate (backward compatible).
|
||||
- ``"doc_reset"``: reset to 0 at each document boundary.
|
||||
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ class SFTDataset(BaseDataset):
|
|||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["sequence", "loss_mask"]
|
||||
return ["sequence", "loss_mask", "position_ids"]
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
|
|
@ -231,15 +231,17 @@ class SFTDataset(BaseDataset):
|
|||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence")
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
|
||||
position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
|
||||
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||
return {
|
||||
"input_ids": x.to(dtype=torch.long),
|
||||
"target_ids": y.to(dtype=torch.long),
|
||||
"position_ids": position_ids.to(dtype=torch.long),
|
||||
"loss_mask": loss_mask.to(dtype=torch.bool),
|
||||
}
|
||||
|
||||
|
||||
@DatasetFactory.register("dpo")
|
||||
|
|
|
|||
|
|
@ -144,6 +144,7 @@ class Pipeline:
|
|||
def _flush(self, domains, shard_idx):
|
||||
for domain, keys in domains.items():
|
||||
idx = shard_idx[domain]
|
||||
chunk_dir = os.path.join(self.output_dir, domain)
|
||||
tensors = {}
|
||||
for key, ids_list in keys.items():
|
||||
dt = _STR_TO_DTYPE.get(
|
||||
|
|
@ -152,10 +153,22 @@ class Pipeline:
|
|||
tensors[key] = [
|
||||
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||
]
|
||||
chunk_dir = os.path.join(self.output_dir, domain)
|
||||
|
||||
pid_mode = self.config.output.position_ids_mode
|
||||
if pid_mode and pid_mode != "none" and "sequence" in tensors:
|
||||
pos_ids = []
|
||||
if pid_mode == "doc_reset":
|
||||
for item in keys["sequence"]:
|
||||
pos_ids.extend(range(len(item)))
|
||||
else:
|
||||
total = sum(len(item) for item in keys["sequence"])
|
||||
pos_ids = list(range(total))
|
||||
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
|
||||
|
||||
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
|
||||
fmt = self.config.output.storage_format
|
||||
if fmt == "bin":
|
||||
save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
|
||||
save_bin(shard_path, tensors)
|
||||
else:
|
||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||
shard_idx[domain] = idx + 1
|
||||
|
|
|
|||
|
|
@ -180,14 +180,15 @@ class SFTStrategy(BaseStrategy):
|
|||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
input_ids, target_ids, loss_mask = (
|
||||
input_ids, target_ids, position_ids, loss_mask = (
|
||||
batch["input_ids"],
|
||||
batch["target_ids"],
|
||||
batch["position_ids"],
|
||||
batch["loss_mask"],
|
||||
)
|
||||
|
||||
ignore_index = -100
|
||||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"]
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
|
||||
loss = F.cross_entropy(
|
||||
|
|
|
|||
Loading…
Reference in New Issue