diff --git a/astrai/config/preprocess_config.py b/astrai/config/preprocess_config.py
index 5deac30..8d08bd9 100644
--- a/astrai/config/preprocess_config.py
+++ b/astrai/config/preprocess_config.py
@@ -45,6 +45,17 @@ class OutputConfig(BaseConfig):
     storage_format: str = "bin"
     max_tokens_per_shard: int = 100_000_000
     dtype: Dict[str, str] = field(default_factory=dict)
+    position_ids_mode: Optional[str] = None
+    """How to compute position_ids in packed sequences.
+
+    - ``None`` / ``"none"``: do not generate (backward compatible).
+    - ``"doc_reset"``: reset to 0 at each document boundary.
+    - ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
+
+    Any other non-empty value raises ``ValueError`` when a shard is flushed.
+    ``SFTDataset`` fetches ``position_ids`` unconditionally, so shards meant
+    for SFT need a mode other than ``None`` / ``"none"``.
+    """
 
 
 @dataclass
diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py
index 0251e07..fedf1e5 100644
--- a/astrai/dataset/dataset.py
+++ b/astrai/dataset/dataset.py
@@ -223,7 +223,7 @@ class SFTDataset(BaseDataset):
 
     @property
     def required_keys(self) -> List[str]:
-        return ["sequence", "loss_mask"]
+        return ["sequence", "loss_mask", "position_ids"]
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.storage.fetch(begin_idx, end_idx, key)
@@ -231,15 +231,18 @@ class SFTDataset(BaseDataset):
     def __getitem__(self, index):
         begin_idx, end_idx = self.get_index(index)
 
-        x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
-        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
-            dtype=torch.long
-        )
-        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
-            dtype=torch.bool
-        )
+        x = self._fetch_data(begin_idx, end_idx, "sequence")
+        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
+        # position_ids describe the input tokens, so they use x's window.
+        position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
+        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
 
-        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
+        return {
+            "input_ids": x.to(dtype=torch.long),
+            "target_ids": y.to(dtype=torch.long),
+            "position_ids": position_ids.to(dtype=torch.long),
+            "loss_mask": loss_mask.to(dtype=torch.bool),
+        }
 
 
 @DatasetFactory.register("dpo")
diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py
index 5e16541..9ebf926 100644
--- a/astrai/preprocessing/pipeline.py
+++ b/astrai/preprocessing/pipeline.py
@@ -144,6 +144,7 @@ class Pipeline:
     def _flush(self, domains, shard_idx):
         for domain, keys in domains.items():
             idx = shard_idx[domain]
+            chunk_dir = os.path.join(self.output_dir, domain)
             tensors = {}
             for key, ids_list in keys.items():
                 dt = _STR_TO_DTYPE.get(
@@ -152,10 +153,34 @@ class Pipeline:
                 tensors[key] = [
                     torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
                 ]
-            chunk_dir = os.path.join(self.output_dir, domain)
+
+            # Optionally derive a "position_ids" tensor from the token documents.
+            pid_mode = self.config.output.position_ids_mode
+            pid_enabled = bool(pid_mode) and pid_mode != "none"
+            if pid_enabled and pid_mode not in ("doc_reset", "continuous"):
+                # Fail loudly on a typo instead of silently treating it as
+                # "continuous".
+                raise ValueError(f"Unknown position_ids_mode: {pid_mode!r}")
+            if pid_enabled and "sequence" in tensors:
+                docs = keys["sequence"]
+                if pid_mode == "doc_reset":
+                    # Positions restart at 0 for every document. Seed with an
+                    # empty tensor so torch.cat also works with no documents.
+                    parts = [torch.empty(0, dtype=torch.int32)]
+                    parts.extend(
+                        torch.arange(len(doc), dtype=torch.int32) for doc in docs
+                    )
+                    pos_ids = torch.cat(parts)
+                else:
+                    # "continuous": one running counter over the whole shard.
+                    total = sum(len(doc) for doc in docs)
+                    pos_ids = torch.arange(total, dtype=torch.int32)
+                tensors["position_ids"] = [pos_ids]
+
+            shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
             fmt = self.config.output.storage_format
             if fmt == "bin":
-                save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
+                save_bin(shard_path, tensors)
             else:
                 save_h5(chunk_dir, f"data_{idx:04d}", tensors)
             shard_idx[domain] = idx + 1
diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py
index 529edd1..63472e4 100644
--- a/astrai/trainer/strategy.py
+++ b/astrai/trainer/strategy.py
@@ -180,14 +180,15 @@ class SFTStrategy(BaseStrategy):
 
     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
         batch = move_to_device(batch, self.device)
-        input_ids, target_ids, loss_mask = (
+        input_ids, target_ids, position_ids, loss_mask = (
             batch["input_ids"],
             batch["target_ids"],
+            batch["position_ids"],
             batch["loss_mask"],
         )
 
         ignore_index = -100
-        logits = self.model(input_ids=input_ids)["logits"]
+        logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"]
 
         target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
         loss = F.cross_entropy(