fix: 修复 created 时间戳、bin 多 shard 覆盖与文档遗漏

- openai.py/anthropic.py: created 从 0 改为 int(time.time())
- openai.py: ChatCompletionRequest 不支持参数非默认值时 warning
- pipeline.py: bin 多 shard 使用子目录避免静默覆盖
- storage.py: MmapStore/detect_format 支持多 shard 聚合加载
- architecture.md: mermaid 类图新增 Pipeline 类
- preprocessing.md: 新增多 shard 输出布局与 Python API 示例
- protocol.py: docstring "6 methods" 改为 "5 methods"
This commit is contained in:
ViperEkura 2026-05-30 22:56:29 +08:00
parent 1c2ff05a6d
commit 2a65c3314c
8 changed files with 122 additions and 11 deletions

View File

@ -363,6 +363,16 @@ classDiagram
class TextMaskBuilder { class TextMaskBuilder {
+build(item, config, tokenizer) Optional[dict] +build(item, config, tokenizer) Optional[dict]
} }
class Pipeline {
+PipelineConfig config
+List[str] paths
+str output_dir
+str tokenizer_path
+BaseMaskBuilder mask_builder
+transform(item) Optional[dict]
+run()
}
} }
namespace tokenize { namespace tokenize {
@ -1092,6 +1102,8 @@ classDiagram
KvcacheView o-- Storage KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- Store BaseDataset o-- Store
Pipeline o-- PipelineConfig
Pipeline o-- BaseMaskBuilder
%% --- Dependency (uses temporarily) --- %% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects TrainConfig ..> BaseStrategy : selects

View File

@ -186,6 +186,8 @@ Pure tokenization. No `loss_mask` is produced. Used for pretraining.
## Output Layout ## Output Layout
### Single-Shard (`bin`)
``` ```
output_dir/ output_dir/
__default__/ # when domain_key is null __default__/ # when domain_key is null
@ -198,6 +200,59 @@ output_dir/
loss_mask.bin loss_mask.bin
``` ```
### Multi-Shard (`bin`)
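The pipeline writes bin output into numbered shard subdirectories (`shard_0000`, `shard_0001`, …), even when only a single shard is produced; a new shard is started each time `max_tokens_per_shard` is exceeded: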
```
output_dir/
__default__/
shard_0000/
meta.json
sequence.bin
loss_mask.bin
shard_0001/
meta.json
sequence.bin
loss_mask.bin
```
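`MmapStore.load` recursively discovers every `meta.json` under the path it is given and merges the tensors of all shards key by key, so pointing it at the domain directory loads every shard.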
### H5 Output
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
```
output_dir/
__default__/
data_0000.h5 # each H5 contains key→dataset groups
data_0001.h5
wiki/
data_0000.h5
```
## Python API Usage
```python
from astrai.preprocessing.pipeline import Pipeline
from astrai.config.preprocess_config import PipelineConfig
config = PipelineConfig.from_json("sft_pipeline.json")
Pipeline(
config,
["data_part1.jsonl", "data_part2.jsonl"],
output_dir="output/",
tokenizer_path="params"
).run()
```
Or from the CLI:
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
```
## Extension ## Extension
Register a custom builder for new formats: Register a custom builder for new formats:

View File

@ -117,8 +117,12 @@ def detect_format(load_path: str) -> str:
if h5_files: if h5_files:
return "h5" return "h5"
bin_files = list(root.rglob("*.bin")) bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists(): if bin_files:
return "bin" has_meta = (root / "meta.json").exists() or len(
list(root.rglob("meta.json"))
) > 0
if has_meta:
return "bin"
raise FileNotFoundError(f"No supported data files found at {load_path}") raise FileNotFoundError(f"No supported data files found at {load_path}")
@ -244,7 +248,17 @@ class MmapStore(Store):
def load(self, path: str): def load(self, path: str):
self._mmap_refs = [] self._mmap_refs = []
raw = load_bin(path) root = Path(path)
self._normalize(raw) all_raw: Dict[str, List[Tensor]] = {}
meta_paths = list(root.rglob("meta.json"))
for meta_path in meta_paths:
raw = load_bin(str(meta_path.parent))
for key, tensors in raw.items():
if key not in all_raw:
all_raw[key] = []
all_raw[key].extend(tensors)
if not meta_paths:
raise FileNotFoundError(f"No meta.json found under {path}")
self._normalize(all_raw)
for tensors in self._data.values(): for tensors in self._data.values():
self._mmap_refs.extend(tensors) self._mmap_refs.extend(tensors)

View File

@ -1,5 +1,6 @@
"""Anthropic message completion response builder.""" """Anthropic message completion response builder."""
import time
import uuid import uuid
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
@ -39,7 +40,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
ctx = GenContext( ctx = GenContext(
resp_id=f"msg_{uuid.uuid4().hex[:24]}", resp_id=f"msg_{uuid.uuid4().hex[:24]}",
created=0, created=int(time.time()),
model=request.model, model=request.model,
prompt_tokens=0, prompt_tokens=0,
) )

View File

@ -1,5 +1,7 @@
"""OpenAI chat completion response builder.""" """OpenAI chat completion response builder."""
import logging
import time
import uuid import uuid
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
@ -13,6 +15,16 @@ from astrai.inference.api.protocol import (
) )
from astrai.inference.engine import InferenceEngine from astrai.inference.engine import InferenceEngine
logger = logging.getLogger(__name__)
_UNSUPPORTED_PARAMS = (
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
)
class OpenAIResponseBuilder(ResponseBuilder): class OpenAIResponseBuilder(ResponseBuilder):
def prepare( def prepare(
@ -24,9 +36,26 @@ class OpenAIResponseBuilder(ResponseBuilder):
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
self._model = request.model self._model = request.model
# Emit exactly one warning per unsupported parameter that the caller set to
# a non-default value. (The check-and-warn statement used to appear twice in
# a row, which logged every offending parameter twice.)
#
# `model_fields` is read from the request's class once, outside the loop; an
# empty mapping is used when the class does not define it.
# NOTE(review): presumably this is pydantic's `model_fields` mapping of field
# name -> FieldInfo; confirm against ChatCompletionRequest.
fields = getattr(type(request), "model_fields", {})
for param in _UNSUPPORTED_PARAMS:
    value = getattr(request, param, None)
    # A parameter with no declared field is compared against None, so any
    # non-None value counts as "set by the caller".
    default = fields[param].default if param in fields else None
    if value is not None and value != default:
        logger.warning(
            "ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
            param,
            value,
        )
ctx = GenContext( ctx = GenContext(
resp_id=self._resp_id, resp_id=self._resp_id,
created=0, created=int(time.time()),
model=self._model, model=self._model,
prompt_tokens=0, prompt_tokens=0,
) )

View File

@ -64,7 +64,7 @@ class StopChecker:
class ResponseBuilder(ABC): class ResponseBuilder(ABC):
"""Interface for protocol-specific response formatting. """Interface for protocol-specific response formatting.
A new protocol requires one concrete builder implementing 6 methods. A new protocol requires one concrete builder implementing 5 methods.
""" """
@abstractmethod @abstractmethod

View File

@ -124,7 +124,7 @@ class Pipeline:
chunk_dir = os.path.join(self.output_dir, domain) chunk_dir = os.path.join(self.output_dir, domain)
fmt = self.config.output.storage_format fmt = self.config.output.storage_format
if fmt == "bin": if fmt == "bin":
save_bin(chunk_dir, tensors) save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
else: else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors) save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1 shard_idx[domain] = idx + 1

View File

@ -451,7 +451,7 @@ class TestPipeline:
tokenizer_path=tokenizer_dir, tokenizer_path=tokenizer_dir,
).run() ).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json") meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path) assert os.path.exists(meta_path)
with open(meta_path, "r") as f: with open(meta_path, "r") as f:
meta = json.load(f) meta = json.load(f)
@ -505,7 +505,7 @@ class TestPipeline:
tokenizer_path=tokenizer_dir, tokenizer_path=tokenizer_dir,
).run() ).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json") meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path) assert os.path.exists(meta_path)
with open(meta_path, "r") as f: with open(meta_path, "r") as f:
meta = json.load(f) meta = json.load(f)
@ -560,7 +560,7 @@ class TestPipeline:
tokenizer_path=tokenizer_dir, tokenizer_path=tokenizer_dir,
).run() ).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json") meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path) assert os.path.exists(meta_path)
with open(meta_path, "r") as f: with open(meta_path, "r") as f:
meta = json.load(f) meta = json.load(f)