diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index e42ac2e..02ba923 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -363,6 +363,16 @@ classDiagram class TextMaskBuilder { +build(item, config, tokenizer) Optional[dict] } + + class Pipeline { + +PipelineConfig config + +List[str] paths + +str output_dir + +str tokenizer_path + +BaseMaskBuilder mask_builder + +transform(item) Optional[dict] + +run() + } } namespace tokenize { @@ -1092,6 +1102,8 @@ classDiagram KvcacheView o-- Storage SamplingPipeline o-- BaseSamplingStrategy BaseDataset o-- Store + Pipeline o-- PipelineConfig + Pipeline o-- BaseMaskBuilder %% --- Dependency (uses temporarily) --- TrainConfig ..> BaseStrategy : selects diff --git a/assets/docs/preprocessing.md b/assets/docs/preprocessing.md index ff983a9..995574e 100644 --- a/assets/docs/preprocessing.md +++ b/assets/docs/preprocessing.md @@ -186,6 +186,8 @@ Pure tokenization. No `loss_mask` is produced. Used for pretraining. ## Output Layout +### Single-Shard (`bin`) + ``` output_dir/ __default__/ # when domain_key is null @@ -198,6 +200,59 @@ output_dir/ loss_mask.bin ``` +### Multi-Shard (`bin`) + +The pipeline writes `bin` output into numbered shard subdirectories, starting at `shard_0000` even when only one shard is produced; a new shard is started each time `max_tokens_per_shard` is exceeded: + +``` +output_dir/ + __default__/ + shard_0000/ + meta.json + sequence.bin + loss_mask.bin + shard_0001/ + meta.json + sequence.bin + loss_mask.bin +``` + +`MmapStore` automatically discovers and merges all shards under the domain directory. 
+ +### H5 Output + +HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`: + +``` +output_dir/ + __default__/ + data_0000.h5 # each H5 contains key→dataset groups + data_0001.h5 + wiki/ + data_0000.h5 +``` + +## Python API Usage + +```python +from astrai.preprocessing.pipeline import Pipeline +from astrai.config.preprocess_config import PipelineConfig + +config = PipelineConfig.from_json("sft_pipeline.json") +Pipeline( + config, + ["data_part1.jsonl", "data_part2.jsonl"], + output_dir="output/", + tokenizer_path="params" +).run() +``` + +Or from the CLI: + +```bash +python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json +``` + ## Extension Register a custom builder for new formats: diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index 73fc74f..cf9a7ed 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -117,8 +117,12 @@ def detect_format(load_path: str) -> str: if h5_files: return "h5" bin_files = list(root.rglob("*.bin")) - if bin_files and (root / "meta.json").exists(): - return "bin" + if bin_files: + has_meta = (root / "meta.json").exists() or len( + list(root.rglob("meta.json")) + ) > 0 + if has_meta: + return "bin" raise FileNotFoundError(f"No supported data files found at {load_path}") @@ -244,7 +248,17 @@ class MmapStore(Store): def load(self, path: str): self._mmap_refs = [] - raw = load_bin(path) - self._normalize(raw) + root = Path(path) + all_raw: Dict[str, List[Tensor]] = {} + meta_paths = sorted(root.rglob("meta.json"))  # sorted: rglob order is filesystem-dependent; keep shard merge order deterministic + for meta_path in meta_paths: + raw = load_bin(str(meta_path.parent)) + for key, tensors in raw.items(): + if key not in all_raw: + all_raw[key] = [] + all_raw[key].extend(tensors) + if not meta_paths: + raise FileNotFoundError(f"No meta.json found under {path}") + self._normalize(all_raw) for tensors in self._data.values(): self._mmap_refs.extend(tensors) diff --git a/astrai/inference/api/anthropic.py 
b/astrai/inference/api/anthropic.py index 526554a..9507bd7 100644 --- a/astrai/inference/api/anthropic.py +++ b/astrai/inference/api/anthropic.py @@ -1,5 +1,6 @@ """Anthropic message completion response builder.""" +import time import uuid from typing import Any, Dict, List, Tuple, Union @@ -39,7 +40,7 @@ class AnthropicResponseBuilder(ResponseBuilder): prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) ctx = GenContext( resp_id=f"msg_{uuid.uuid4().hex[:24]}", - created=0, + created=int(time.time()), model=request.model, prompt_tokens=0, ) diff --git a/astrai/inference/api/openai.py b/astrai/inference/api/openai.py index 5e86437..a8ca51d 100644 --- a/astrai/inference/api/openai.py +++ b/astrai/inference/api/openai.py @@ -1,5 +1,7 @@ """OpenAI chat completion response builder.""" +import logging +import time import uuid from typing import Any, Dict, List, Tuple @@ -13,6 +15,16 @@ from astrai.inference.api.protocol import ( ) from astrai.inference.engine import InferenceEngine +logger = logging.getLogger(__name__) + +_UNSUPPORTED_PARAMS = ( + "n", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", +) + class OpenAIResponseBuilder(ResponseBuilder): def prepare( @@ -24,9 +36,20 @@ class OpenAIResponseBuilder(ResponseBuilder): self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" self._model = request.model + for param in _UNSUPPORTED_PARAMS: + value = getattr(request, param, None) + fields = getattr(type(request), "model_fields", {}) + default = fields[param].default if param in fields else None + if value is not None and value != default: + logger.warning( + "ChatCompletionRequest param '%s'=%r is not supported and will be ignored", + param, + value, + ) + ctx = GenContext( resp_id=self._resp_id, - created=0, + created=int(time.time()), model=self._model, 
prompt_tokens=0, ) diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index d6c3769..dd7a405 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -64,7 +64,7 @@ class StopChecker: class ResponseBuilder(ABC): """Interface for protocol-specific response formatting. - A new protocol requires one concrete builder implementing 6 methods. + A new protocol requires one concrete builder implementing 5 methods. """ @abstractmethod diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index b7f1554..24fc209 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -124,7 +124,7 @@ class Pipeline: chunk_dir = os.path.join(self.output_dir, domain) fmt = self.config.output.storage_format if fmt == "bin": - save_bin(chunk_dir, tensors) + save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors) else: save_h5(chunk_dir, f"data_{idx:04d}", tensors) shard_idx[domain] = idx + 1 diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py index 7785110..85a1368 100644 --- a/tests/data/test_preprocess.py +++ b/tests/data/test_preprocess.py @@ -451,7 +451,7 @@ class TestPipeline: tokenizer_path=tokenizer_dir, ).run() - meta_path = os.path.join(out_dir, "__default__", "meta.json") + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") assert os.path.exists(meta_path) with open(meta_path, "r") as f: meta = json.load(f) @@ -505,7 +505,7 @@ class TestPipeline: tokenizer_path=tokenizer_dir, ).run() - meta_path = os.path.join(out_dir, "__default__", "meta.json") + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") assert os.path.exists(meta_path) with open(meta_path, "r") as f: meta = json.load(f) @@ -560,7 +560,7 @@ class TestPipeline: tokenizer_path=tokenizer_dir, ).run() - meta_path = os.path.join(out_dir, "__default__", "meta.json") + meta_path = os.path.join(out_dir, "__default__", "shard_0000", 
"meta.json") assert os.path.exists(meta_path) with open(meta_path, "r") as f: meta = json.load(f)