fix: 修复 created 时间戳、bin 多 shard 覆盖与文档遗漏

- openai.py/anthropic.py: created 从 0 改为 int(time.time())
- openai.py: ChatCompletionRequest 不支持参数非默认值时 warning
- pipeline.py: bin 多 shard 使用子目录避免静默覆盖
- storage.py: MmapStore/detect_format 支持多 shard 聚合加载
- architecture.md: mermaid 类图新增 Pipeline 类
- preprocessing.md: 新增多 shard 输出布局与 Python API 示例
- protocol.py: docstring "6 methods" 改为 "5 methods"
This commit is contained in:
ViperEkura 2026-05-30 22:56:29 +08:00
parent 1c2ff05a6d
commit 2a65c3314c
8 changed files with 122 additions and 11 deletions

View File

@ -363,6 +363,16 @@ classDiagram
class TextMaskBuilder {
+build(item, config, tokenizer) Optional[dict]
}
class Pipeline {
+PipelineConfig config
+List[str] paths
+str output_dir
+str tokenizer_path
+BaseMaskBuilder mask_builder
+transform(item) Optional[dict]
+run()
}
}
namespace tokenize {
@ -1092,6 +1102,8 @@ classDiagram
KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- Store
Pipeline o-- PipelineConfig
Pipeline o-- BaseMaskBuilder
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects

View File

@ -186,6 +186,8 @@ Pure tokenization. No `loss_mask` is produced. Used for pretraining.
## Output Layout
### Single-Shard (`bin`)
```
output_dir/
__default__/ # when domain_key is null
@ -198,6 +200,59 @@ output_dir/
loss_mask.bin
```
### Multi-Shard (`bin`)
Bin shards are written to numbered subdirectories named `shard_NNNN`: the first shard is `shard_0000`, and an additional shard directory is created each time `max_tokens_per_shard` is exceeded:
```
output_dir/
__default__/
shard_0000/
meta.json
sequence.bin
loss_mask.bin
shard_0001/
meta.json
sequence.bin
loss_mask.bin
```
`MmapStore` automatically discovers and merges all shards under the domain directory.
### H5 Output
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
```
output_dir/
__default__/
data_0000.h5 # each H5 contains key→dataset groups
data_0001.h5
wiki/
data_0000.h5
```
## Python API Usage
```python
from astrai.preprocessing.pipeline import Pipeline
from astrai.config.preprocess_config import PipelineConfig
config = PipelineConfig.from_json("sft_pipeline.json")
Pipeline(
config,
["data_part1.jsonl", "data_part2.jsonl"],
output_dir="output/",
tokenizer_path="params"
).run()
```
Or from the CLI:
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
```
## Extension
Register a custom builder for new formats:

View File

@ -117,8 +117,12 @@ def detect_format(load_path: str) -> str:
if h5_files:
return "h5"
bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists():
return "bin"
if bin_files:
has_meta = (root / "meta.json").exists() or len(
list(root.rglob("meta.json"))
) > 0
if has_meta:
return "bin"
raise FileNotFoundError(f"No supported data files found at {load_path}")
@ -244,7 +248,17 @@ class MmapStore(Store):
def load(self, path: str):
self._mmap_refs = []
raw = load_bin(path)
self._normalize(raw)
# Gather every shard directory under `path`: any directory that holds a
# meta.json counts, whether it is `path` itself or a nested shard folder.
# The paths are sorted so shards are merged in a deterministic order
# (shard_0000, shard_0001, ...); rglob's own traversal order depends on
# the filesystem.
root = Path(path)
meta_paths = sorted(root.rglob("meta.json"))
if not meta_paths:
    raise FileNotFoundError(f"No meta.json found under {path}")
# Concatenate the per-key tensor lists of all shards into one mapping
# before normalizing.
all_raw: Dict[str, List[Tensor]] = {}
for meta_path in meta_paths:
    raw = load_bin(str(meta_path.parent))
    for key, tensors in raw.items():
        all_raw.setdefault(key, []).extend(tensors)
self._normalize(all_raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)

View File

@ -1,5 +1,6 @@
"""Anthropic message completion response builder."""
import time
import uuid
from typing import Any, Dict, List, Tuple, Union
@ -39,7 +40,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
ctx = GenContext(
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
created=0,
created=int(time.time()),
model=request.model,
prompt_tokens=0,
)

View File

@ -1,5 +1,7 @@
"""OpenAI chat completion response builder."""
import logging
import time
import uuid
from typing import Any, Dict, List, Tuple
@ -13,6 +15,16 @@ from astrai.inference.api.protocol import (
)
from astrai.inference.engine import InferenceEngine
logger = logging.getLogger(__name__)
_UNSUPPORTED_PARAMS = (
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
)
class OpenAIResponseBuilder(ResponseBuilder):
def prepare(
@ -24,9 +36,26 @@ class OpenAIResponseBuilder(ResponseBuilder):
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
self._model = request.model
# Warn exactly once for each unsupported request parameter that was set to
# a non-default value; the parameter itself is ignored.
# The request class's `model_fields` mapping (empty if it has none) is the
# same for every parameter, so it is looked up once outside the loop.
fields = getattr(type(request), "model_fields", {})
for param in _UNSUPPORTED_PARAMS:
    value = getattr(request, param, None)
    default = fields[param].default if param in fields else None
    if value is not None and value != default:
        logger.warning(
            "ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
            param,
            value,
        )
ctx = GenContext(
resp_id=self._resp_id,
created=0,
created=int(time.time()),
model=self._model,
prompt_tokens=0,
)

View File

@ -64,7 +64,7 @@ class StopChecker:
class ResponseBuilder(ABC):
"""Interface for protocol-specific response formatting.
A new protocol requires one concrete builder implementing 6 methods.
A new protocol requires one concrete builder implementing 5 methods.
"""
@abstractmethod

View File

@ -124,7 +124,7 @@ class Pipeline:
chunk_dir = os.path.join(self.output_dir, domain)
fmt = self.config.output.storage_format
if fmt == "bin":
save_bin(chunk_dir, tensors)
save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1

View File

@ -451,7 +451,7 @@ class TestPipeline:
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
@ -505,7 +505,7 @@ class TestPipeline:
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
@ -560,7 +560,7 @@ class TestPipeline:
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)