feat : 数据流水拼接策略支持 position_ids 预计算
- OutputConfig.position_ids_mode 三种模式控制边界策略 - pipeline._flush() 按配置生成扁平 position_ids 数组 - SFTDataset 在 __getitem__ 中返回 position_ids - SFTStrategy 将 position_ids 传入 model.forward()
This commit is contained in:
parent
5e73ca20aa
commit
985d940db6
|
|
@ -45,6 +45,13 @@ class OutputConfig(BaseConfig):
|
||||||
storage_format: str = "bin"
|
storage_format: str = "bin"
|
||||||
max_tokens_per_shard: int = 100_000_000
|
max_tokens_per_shard: int = 100_000_000
|
||||||
dtype: Dict[str, str] = field(default_factory=dict)
|
dtype: Dict[str, str] = field(default_factory=dict)
|
||||||
|
position_ids_mode: Optional[str] = None
|
||||||
|
"""How to compute position_ids in packed sequences.
|
||||||
|
|
||||||
|
- ``None`` / ``"none"``: do not generate (backward compatible).
|
||||||
|
- ``"doc_reset"``: reset to 0 at each document boundary.
|
||||||
|
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -223,7 +223,7 @@ class SFTDataset(BaseDataset):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_keys(self) -> List[str]:
|
def required_keys(self) -> List[str]:
|
||||||
return ["sequence", "loss_mask"]
|
return ["sequence", "loss_mask", "position_ids"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
@ -231,15 +231,17 @@ class SFTDataset(BaseDataset):
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx, "sequence")
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
|
||||||
dtype=torch.long
|
position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
|
||||||
)
|
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
|
||||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
|
||||||
dtype=torch.bool
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
return {
|
||||||
|
"input_ids": x.to(dtype=torch.long),
|
||||||
|
"target_ids": y.to(dtype=torch.long),
|
||||||
|
"position_ids": position_ids.to(dtype=torch.long),
|
||||||
|
"loss_mask": loss_mask.to(dtype=torch.bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("dpo")
|
@DatasetFactory.register("dpo")
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,7 @@ class Pipeline:
|
||||||
def _flush(self, domains, shard_idx):
|
def _flush(self, domains, shard_idx):
|
||||||
for domain, keys in domains.items():
|
for domain, keys in domains.items():
|
||||||
idx = shard_idx[domain]
|
idx = shard_idx[domain]
|
||||||
|
chunk_dir = os.path.join(self.output_dir, domain)
|
||||||
tensors = {}
|
tensors = {}
|
||||||
for key, ids_list in keys.items():
|
for key, ids_list in keys.items():
|
||||||
dt = _STR_TO_DTYPE.get(
|
dt = _STR_TO_DTYPE.get(
|
||||||
|
|
@ -152,10 +153,22 @@ class Pipeline:
|
||||||
tensors[key] = [
|
tensors[key] = [
|
||||||
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||||
]
|
]
|
||||||
chunk_dir = os.path.join(self.output_dir, domain)
|
|
||||||
|
pid_mode = self.config.output.position_ids_mode
|
||||||
|
if pid_mode and pid_mode != "none" and "sequence" in tensors:
|
||||||
|
pos_ids = []
|
||||||
|
if pid_mode == "doc_reset":
|
||||||
|
for item in keys["sequence"]:
|
||||||
|
pos_ids.extend(range(len(item)))
|
||||||
|
else:
|
||||||
|
total = sum(len(item) for item in keys["sequence"])
|
||||||
|
pos_ids = list(range(total))
|
||||||
|
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
|
||||||
|
|
||||||
|
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
|
||||||
fmt = self.config.output.storage_format
|
fmt = self.config.output.storage_format
|
||||||
if fmt == "bin":
|
if fmt == "bin":
|
||||||
save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
|
save_bin(shard_path, tensors)
|
||||||
else:
|
else:
|
||||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||||
shard_idx[domain] = idx + 1
|
shard_idx[domain] = idx + 1
|
||||||
|
|
|
||||||
|
|
@ -180,14 +180,15 @@ class SFTStrategy(BaseStrategy):
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
input_ids, target_ids, loss_mask = (
|
input_ids, target_ids, position_ids, loss_mask = (
|
||||||
batch["input_ids"],
|
batch["input_ids"],
|
||||||
batch["target_ids"],
|
batch["target_ids"],
|
||||||
|
batch["position_ids"],
|
||||||
batch["loss_mask"],
|
batch["loss_mask"],
|
||||||
)
|
)
|
||||||
|
|
||||||
ignore_index = -100
|
ignore_index = -100
|
||||||
logits = self.model(input_ids=input_ids)["logits"]
|
logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"]
|
||||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue