fix: 修复 created 时间戳、bin 多 shard 覆盖与文档遗漏
- openai.py/anthropic.py: created 从 0 改为 int(time.time())
- openai.py: ChatCompletionRequest 中不支持的参数取非默认值时输出 warning
- pipeline.py: bin 多 shard 使用子目录避免静默覆盖
- storage.py: MmapStore/detect_format 支持多 shard 聚合加载
- architecture.md: mermaid 类图新增 Pipeline 类
- preprocessing.md: 新增多 shard 输出布局与 Python API 示例
- protocol.py: docstring "6 methods" 改为 "5 methods"
This commit is contained in:
parent
1c2ff05a6d
commit
2a65c3314c
|
|
@ -363,6 +363,16 @@ classDiagram
|
||||||
class TextMaskBuilder {
|
class TextMaskBuilder {
|
||||||
+build(item, config, tokenizer) Optional[dict]
|
+build(item, config, tokenizer) Optional[dict]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class Pipeline {
|
||||||
|
+PipelineConfig config
|
||||||
|
+List[str] paths
|
||||||
|
+str output_dir
|
||||||
|
+str tokenizer_path
|
||||||
|
+BaseMaskBuilder mask_builder
|
||||||
|
+transform(item) Optional[dict]
|
||||||
|
+run()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
|
|
@ -1092,6 +1102,8 @@ classDiagram
|
||||||
KvcacheView o-- Storage
|
KvcacheView o-- Storage
|
||||||
SamplingPipeline o-- BaseSamplingStrategy
|
SamplingPipeline o-- BaseSamplingStrategy
|
||||||
BaseDataset o-- Store
|
BaseDataset o-- Store
|
||||||
|
Pipeline o-- PipelineConfig
|
||||||
|
Pipeline o-- BaseMaskBuilder
|
||||||
|
|
||||||
%% --- Dependency (uses temporarily) ---
|
%% --- Dependency (uses temporarily) ---
|
||||||
TrainConfig ..> BaseStrategy : selects
|
TrainConfig ..> BaseStrategy : selects
|
||||||
|
|
|
||||||
|
|
@ -186,6 +186,8 @@ Pure tokenization. No `loss_mask` is produced. Used for pretraining.
|
||||||
|
|
||||||
## Output Layout
|
## Output Layout
|
||||||
|
|
||||||
|
### Single-Shard (`bin`)
|
||||||
|
|
||||||
```
|
```
|
||||||
output_dir/
|
output_dir/
|
||||||
__default__/ # when domain_key is null
|
__default__/ # when domain_key is null
|
||||||
|
|
@ -198,6 +200,59 @@ output_dir/
|
||||||
loss_mask.bin
|
loss_mask.bin
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Multi-Shard (`bin`)
|
||||||
|
|
||||||
|
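When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories (`shard_0000`, `shard_0001`, ...), one per shard, so that a later shard never silently overwrites an earlier one: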
|
||||||
|
|
||||||
|
```
|
||||||
|
output_dir/
|
||||||
|
__default__/
|
||||||
|
shard_0000/
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
shard_0001/
|
||||||
|
meta.json
|
||||||
|
sequence.bin
|
||||||
|
loss_mask.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
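`MmapStore` automatically discovers every `meta.json` under the domain directory (searching subdirectories recursively) and merges the tensors of all shards key by key, so no per-shard loading code is needed.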
|
||||||
|
|
||||||
|
### H5 Output
|
||||||
|
|
||||||
|
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
|
||||||
|
|
||||||
|
```
|
||||||
|
output_dir/
|
||||||
|
__default__/
|
||||||
|
data_0000.h5 # each H5 contains key→dataset groups
|
||||||
|
data_0001.h5
|
||||||
|
wiki/
|
||||||
|
data_0000.h5
|
||||||
|
```
|
||||||
|
|
||||||
|
## Python API Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from astrai.preprocessing.pipeline import Pipeline
|
||||||
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
|
|
||||||
|
config = PipelineConfig.from_json("sft_pipeline.json")
|
||||||
|
Pipeline(
|
||||||
|
config,
|
||||||
|
["data_part1.jsonl", "data_part2.jsonl"],
|
||||||
|
output_dir="output/",
|
||||||
|
tokenizer_path="params"
|
||||||
|
).run()
|
||||||
|
```
|
||||||
|
|
||||||
|
Or from the CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
||||||
|
```
|
||||||
|
|
||||||
## Extension
|
## Extension
|
||||||
|
|
||||||
Register a custom builder for new formats:
|
Register a custom builder for new formats:
|
||||||
|
|
|
||||||
|
|
@ -117,7 +117,11 @@ def detect_format(load_path: str) -> str:
|
||||||
if h5_files:
|
if h5_files:
|
||||||
return "h5"
|
return "h5"
|
||||||
bin_files = list(root.rglob("*.bin"))
|
bin_files = list(root.rglob("*.bin"))
|
||||||
if bin_files and (root / "meta.json").exists():
|
if bin_files:
|
||||||
|
has_meta = (root / "meta.json").exists() or len(
|
||||||
|
list(root.rglob("meta.json"))
|
||||||
|
) > 0
|
||||||
|
if has_meta:
|
||||||
return "bin"
|
return "bin"
|
||||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
@ -244,7 +248,17 @@ class MmapStore(Store):
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(self, path: str):
|
||||||
self._mmap_refs = []
|
self._mmap_refs = []
|
||||||
raw = load_bin(path)
|
root = Path(path)
|
||||||
self._normalize(raw)
|
all_raw: Dict[str, List[Tensor]] = {}
|
||||||
|
meta_paths = list(root.rglob("meta.json"))
|
||||||
|
for meta_path in meta_paths:
|
||||||
|
raw = load_bin(str(meta_path.parent))
|
||||||
|
for key, tensors in raw.items():
|
||||||
|
if key not in all_raw:
|
||||||
|
all_raw[key] = []
|
||||||
|
all_raw[key].extend(tensors)
|
||||||
|
if not meta_paths:
|
||||||
|
raise FileNotFoundError(f"No meta.json found under {path}")
|
||||||
|
self._normalize(all_raw)
|
||||||
for tensors in self._data.values():
|
for tensors in self._data.values():
|
||||||
self._mmap_refs.extend(tensors)
|
self._mmap_refs.extend(tensors)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Anthropic message completion response builder."""
|
"""Anthropic message completion response builder."""
|
||||||
|
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
|
@ -39,7 +40,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
||||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
ctx = GenContext(
|
ctx = GenContext(
|
||||||
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||||
created=0,
|
created=int(time.time()),
|
||||||
model=request.model,
|
model=request.model,
|
||||||
prompt_tokens=0,
|
prompt_tokens=0,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
"""OpenAI chat completion response builder."""
|
"""OpenAI chat completion response builder."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
|
@ -13,6 +15,16 @@ from astrai.inference.api.protocol import (
|
||||||
)
|
)
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_UNSUPPORTED_PARAMS = (
|
||||||
|
"n",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
"logit_bias",
|
||||||
|
"user",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIResponseBuilder(ResponseBuilder):
|
class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
def prepare(
|
def prepare(
|
||||||
|
|
@ -24,9 +36,26 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
self._model = request.model
|
self._model = request.model
|
||||||
|
|
||||||
|
for param in _UNSUPPORTED_PARAMS:
|
||||||
|
value = getattr(request, param, None)
|
||||||
|
fields = getattr(type(request), "model_fields", {})
|
||||||
|
default = fields[param].default if param in fields else None
|
||||||
|
if value is not None and value != default:
|
||||||
|
logger.warning(
|
||||||
|
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||||
|
param,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
if value is not None and value != default:
|
||||||
|
logger.warning(
|
||||||
|
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||||
|
param,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
ctx = GenContext(
|
ctx = GenContext(
|
||||||
resp_id=self._resp_id,
|
resp_id=self._resp_id,
|
||||||
created=0,
|
created=int(time.time()),
|
||||||
model=self._model,
|
model=self._model,
|
||||||
prompt_tokens=0,
|
prompt_tokens=0,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ class StopChecker:
|
||||||
class ResponseBuilder(ABC):
|
class ResponseBuilder(ABC):
|
||||||
"""Interface for protocol-specific response formatting.
|
"""Interface for protocol-specific response formatting.
|
||||||
|
|
||||||
A new protocol requires one concrete builder implementing 6 methods.
|
A new protocol requires one concrete builder implementing 5 methods.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,7 @@ class Pipeline:
|
||||||
chunk_dir = os.path.join(self.output_dir, domain)
|
chunk_dir = os.path.join(self.output_dir, domain)
|
||||||
fmt = self.config.output.storage_format
|
fmt = self.config.output.storage_format
|
||||||
if fmt == "bin":
|
if fmt == "bin":
|
||||||
save_bin(chunk_dir, tensors)
|
save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
|
||||||
else:
|
else:
|
||||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||||
shard_idx[domain] = idx + 1
|
shard_idx[domain] = idx + 1
|
||||||
|
|
|
||||||
|
|
@ -451,7 +451,7 @@ class TestPipeline:
|
||||||
tokenizer_path=tokenizer_dir,
|
tokenizer_path=tokenizer_dir,
|
||||||
).run()
|
).run()
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
assert os.path.exists(meta_path)
|
assert os.path.exists(meta_path)
|
||||||
with open(meta_path, "r") as f:
|
with open(meta_path, "r") as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
@ -505,7 +505,7 @@ class TestPipeline:
|
||||||
tokenizer_path=tokenizer_dir,
|
tokenizer_path=tokenizer_dir,
|
||||||
).run()
|
).run()
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
assert os.path.exists(meta_path)
|
assert os.path.exists(meta_path)
|
||||||
with open(meta_path, "r") as f:
|
with open(meta_path, "r") as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
@ -560,7 +560,7 @@ class TestPipeline:
|
||||||
tokenizer_path=tokenizer_dir,
|
tokenizer_path=tokenizer_dir,
|
||||||
).run()
|
).run()
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
assert os.path.exists(meta_path)
|
assert os.path.exists(meta_path)
|
||||||
with open(meta_path, "r") as f:
|
with open(meta_path, "r") as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue