From 2a65c3314cce981c639e6605beaedc63578199e7 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 30 May 2026 22:56:29 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20=E4=BF=AE=E5=A4=8D=20created=20?= =?UTF-8?q?=E6=97=B6=E9=97=B4=E6=88=B3=E3=80=81bin=20=E5=A4=9A=20shard=20?= =?UTF-8?q?=E8=A6=86=E7=9B=96=E4=B8=8E=E6=96=87=E6=A1=A3=E9=81=97=E6=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - openai.py/anthropic.py: created 从 0 改为 int(time.time()) - openai.py: ChatCompletionRequest 不支持参数非默认值时 warning - pipeline.py: bin 多 shard 使用子目录避免静默覆盖 - storage.py: MmapStore/detect_format 支持多 shard 聚合加载 - architecture.md: mermaid 类图新增 Pipeline 类 - preprocessing.md: 新增多 shard 输出布局与 Python API 示例 - protocol.py: docstring "6 methods" 改为 "5 methods" --- assets/docs/architecture.md | 12 +++++++ assets/docs/preprocessing.md | 55 +++++++++++++++++++++++++++++++ astrai/dataset/storage.py | 22 ++++++++++--- astrai/inference/api/anthropic.py | 3 +- astrai/inference/api/openai.py | 31 ++++++++++++++++- astrai/inference/api/protocol.py | 2 +- astrai/preprocessing/pipeline.py | 2 +- tests/data/test_preprocess.py | 6 ++-- 8 files changed, 122 insertions(+), 11 deletions(-) diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index e42ac2e..02ba923 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -363,6 +363,16 @@ classDiagram class TextMaskBuilder { +build(item, config, tokenizer) Optional[dict] } + + class Pipeline { + +PipelineConfig config + +List[str] paths + +str output_dir + +str tokenizer_path + +BaseMaskBuilder mask_builder + +transform(item) Optional[dict] + +run() + } } namespace tokenize { @@ -1092,6 +1102,8 @@ classDiagram KvcacheView o-- Storage SamplingPipeline o-- BaseSamplingStrategy BaseDataset o-- Store + Pipeline o-- PipelineConfig + Pipeline o-- BaseMaskBuilder %% --- Dependency (uses temporarily) --- TrainConfig ..> BaseStrategy : selects diff --git a/assets/docs/preprocessing.md b/assets/docs/preprocessing.md index ff983a9..995574e 100644 --- a/assets/docs/preprocessing.md +++ b/assets/docs/preprocessing.md @@ -186,6 +186,8 @@ Pure tokenization. No `loss_mask` is produced. Used for pretraining. ## Output Layout +### Single-Shard (`bin`) + ``` output_dir/ __default__/ # when domain_key is null @@ -198,6 +200,59 @@ output_dir/ loss_mask.bin ``` +### Multi-Shard (`bin`) + +When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories: + +``` +output_dir/ + __default__/ + shard_0000/ + meta.json + sequence.bin + loss_mask.bin + shard_0001/ + meta.json + sequence.bin + loss_mask.bin +``` + +`MmapStore` automatically discovers and merges all shards under the domain directory. + +### H5 Output + +HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`: + +``` +output_dir/ + __default__/ + data_0000.h5 # each H5 contains key→dataset groups + data_0001.h5 + wiki/ + data_0000.h5 +``` + +## Python API Usage + +```python +from astrai.preprocessing.pipeline import Pipeline +from astrai.config.preprocess_config import PipelineConfig + +config = PipelineConfig.from_json("sft_pipeline.json") +Pipeline( + config, + ["data_part1.jsonl", "data_part2.jsonl"], + output_dir="output/", + tokenizer_path="params" +).run() +``` + +Or from the CLI: + +```bash +python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json +``` + ## Extension Register a custom builder for new formats: diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index 73fc74f..cf9a7ed 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -117,8 +117,12 @@ def detect_format(load_path: str) -> str: if h5_files: return "h5" bin_files = list(root.rglob("*.bin")) - if bin_files and (root / "meta.json").exists(): - return "bin" + if bin_files: + has_meta = (root / "meta.json").exists() or len( + list(root.rglob("meta.json")) + ) > 0 + if has_meta: + return "bin" raise FileNotFoundError(f"No supported data files found at {load_path}") @@ -244,7 +248,17 @@ class MmapStore(Store): def load(self, path: str): self._mmap_refs = [] - raw = load_bin(path) - self._normalize(raw) + root = Path(path) + all_raw: Dict[str, List[Tensor]] = {} + meta_paths = list(root.rglob("meta.json")) + for meta_path in meta_paths: + raw = load_bin(str(meta_path.parent)) + for key, tensors in raw.items(): + if key not in all_raw: + all_raw[key] = [] + all_raw[key].extend(tensors) + if not meta_paths: + raise FileNotFoundError(f"No meta.json found under {path}") + self._normalize(all_raw) for tensors in self._data.values(): self._mmap_refs.extend(tensors) diff --git a/astrai/inference/api/anthropic.py b/astrai/inference/api/anthropic.py index 526554a..9507bd7 100644 --- a/astrai/inference/api/anthropic.py +++ b/astrai/inference/api/anthropic.py @@ -1,5 +1,6 @@ """Anthropic message completion response builder.""" +import time import uuid from typing import Any, Dict, List, Tuple, Union @@ -39,7 +40,7 @@ class AnthropicResponseBuilder(ResponseBuilder): prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) ctx = GenContext( resp_id=f"msg_{uuid.uuid4().hex[:24]}", - created=0, + created=int(time.time()), model=request.model, prompt_tokens=0, ) diff --git a/astrai/inference/api/openai.py b/astrai/inference/api/openai.py index 5e86437..a8ca51d 100644 --- a/astrai/inference/api/openai.py +++ b/astrai/inference/api/openai.py @@ -1,5 +1,7 @@ """OpenAI chat completion response builder.""" +import logging +import time import uuid from typing import Any, Dict, List, Tuple @@ -13,6 +15,16 @@ from astrai.inference.api.protocol import ( ) from astrai.inference.engine import InferenceEngine +logger = logging.getLogger(__name__) + +_UNSUPPORTED_PARAMS = ( + "n", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", +) + class OpenAIResponseBuilder(ResponseBuilder): def prepare( @@ -24,9 +36,26 @@ class OpenAIResponseBuilder(ResponseBuilder): self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" self._model = request.model + for param in _UNSUPPORTED_PARAMS: + value = getattr(request, param, None) + fields = getattr(type(request), "model_fields", {}) + default = fields[param].default if param in fields else None + if value is not None and value != default: + logger.warning( + "ChatCompletionRequest param '%s'=%r is not supported and will be ignored", + param, + value, + ) + if value is not None and value != default: + logger.warning( + "ChatCompletionRequest param '%s'=%r is not supported and will be ignored", + param, + value, + ) + ctx = GenContext( resp_id=self._resp_id, - created=0, + created=int(time.time()), model=self._model, prompt_tokens=0, ) diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index d6c3769..dd7a405 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -64,7 +64,7 @@ class StopChecker: class ResponseBuilder(ABC): """Interface for protocol-specific response formatting. - A new protocol requires one concrete builder implementing 6 methods. + A new protocol requires one concrete builder implementing 5 methods. """ @abstractmethod diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index b7f1554..24fc209 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -124,7 +124,7 @@ class Pipeline: chunk_dir = os.path.join(self.output_dir, domain) fmt = self.config.output.storage_format if fmt == "bin": - save_bin(chunk_dir, tensors) + save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors) else: save_h5(chunk_dir, f"data_{idx:04d}", tensors) shard_idx[domain] = idx + 1 diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py index 7785110..85a1368 100644 --- a/tests/data/test_preprocess.py +++ b/tests/data/test_preprocess.py @@ -451,7 +451,7 @@ class TestPipeline: tokenizer_path=tokenizer_dir, ).run() - meta_path = os.path.join(out_dir, "__default__", "meta.json") + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") assert os.path.exists(meta_path) with open(meta_path, "r") as f: meta = json.load(f) @@ -505,7 +505,7 @@ class TestPipeline: tokenizer_path=tokenizer_dir, ).run() - meta_path = os.path.join(out_dir, "__default__", "meta.json") + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") assert os.path.exists(meta_path) with open(meta_path, "r") as f: meta = json.load(f) @@ -560,7 +560,7 @@ class TestPipeline: tokenizer_path=tokenizer_dir, ).run() - meta_path = os.path.join(out_dir, "__default__", "meta.json") + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") assert os.path.exists(meta_path) with open(meta_path, "r") as f: meta = json.load(f)