110 lines
3.6 KiB
Python
110 lines
3.6 KiB
Python
"""Pipeline configuration for JSONL preprocessing.
|
|
|
|
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
modes, both driven declaratively: single-output through
``input.sections`` and multi-output through ``input.sources``.
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional
|
|
|
|
from astrai.config.base import BaseConfig
|
|
|
|
|
|
@dataclass
class InputConfig(BaseConfig):
    """Declarative mapping from input records to pipeline outputs.

    This class only carries the two mapping fields; both default to
    ``None``.

    Single-output mode (backward-compatible) uses ``sections``::

        {"input": {"sections": [{"field": "messages", ...}]}}

    Multi-output mode (DPO / GRPO) uses ``sources``, keyed by name, where
    each entry carries its own ``sections`` list::

        {"input": {"sources": {
            "chosen": {"sections": [{"field": "chosen", ...}]},
            "rejected": {"sections": [{"field": "rejected", ...}]},
        }}}

    Parameters
    ----------
    sections : Optional[List[Dict]]
        Section mappings for single-output mode (default: None).
    sources : Optional[Dict[str, Dict]]
        Named source mappings for multi-output mode (default: None).
    """

    # Single-output mode: a flat list of section dicts.
    sections: Optional[List[Dict]] = None

    # Multi-output mode: name -> dict holding that output's own sections.
    sources: Optional[Dict[str, Dict]] = None
|
|
|
|
|
|
@dataclass
class ProcessingConfig(BaseConfig):
    """Settings for filtering, packing and truncating sequences.

    Parameters
    ----------
    max_seq_len : int
        Upper bound on sequence length (default: 2048).
    min_chars : int
        Minimum character count to keep (default: 50).
    max_chars : int
        Maximum character count to keep (default: 2_000_000).
    max_items : Optional[int]
        Cap on the number of items processed; ``None`` means unlimited
        (default: None).
    packing_strategy : str
        Method used to pack sequences into a contiguous stream.

        - ``"simple"``: sequential concatenation (default, backward
          compatible).
        - ``"bfd"``: best-fit decreasing bin packing, minimises wasted
          tokens.
        - ``"bfd_split"``: BFD with over-length sequences split into
          chunks.
    max_packed_len : int
        Maximum length of a packed bin. Longer sequences are truncated or
        split depending on ``packing_strategy`` (default: 8192).
    truncation_mode : str
        Which part of a sequence longer than ``max_packed_len`` is kept.

        - ``"keep_start"``: keep the first ``max_packed_len`` tokens
          (default).
        - ``"keep_end"``: keep the last ``max_packed_len`` tokens.
    """

    # Length and size limits.
    max_seq_len: int = 2048
    min_chars: int = 50
    max_chars: int = 2_000_000
    max_items: Optional[int] = None

    # Packing and truncation behaviour.
    packing_strategy: str = "simple"
    max_packed_len: int = 8192
    truncation_mode: str = "keep_start"
|
|
|
|
|
|
@dataclass
class OutputConfig(BaseConfig):
    """Settings describing how processed data is stored.

    Parameters
    ----------
    domain_key : Optional[str]
        Domain key for the output store (default: None).
    storage_format : str
        Storage format, either ``"bin"`` or ``"jsonl"`` (default:
        ``"bin"``).
    max_tokens_per_shard : int
        Token count at which a shard is split (default: 100_000_000).
    dtype : Dict[str, str]
        Per-key dtype overrides, e.g. ``{"input_ids": "int32"}``
        (default: {}).
    position_ids_mode : Optional[str]
        How position_ids are computed in packed sequences.

        - ``None`` / ``"none"``: do not generate (backward compatible).
        - ``"doc_reset"``: reset to 0 at each document boundary.
        - ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single
          doc).
    """

    # Store location and layout.
    domain_key: Optional[str] = None
    storage_format: str = "bin"
    max_tokens_per_shard: int = 100_000_000

    # A factory gives every instance its own dict rather than a shared one.
    dtype: Dict[str, str] = field(default_factory=dict)

    position_ids_mode: Optional[str] = None
|
|
|
|
|
|
@dataclass
class PipelineConfig(BaseConfig):
    """Top-level configuration grouping the pipeline's sub-configs.

    Parameters
    ----------
    version : int
        Version number of this configuration (default: 1).
        NOTE(review): presumably the config schema version -- confirm.
    input : InputConfig
        Input mapping (default: a fresh ``InputConfig()``).
    mask : Dict[str, str]
        String-to-string mapping (default: {}).
        NOTE(review): presumably maps a name to a masking action --
        confirm against the code that consumes it.
    mask_default : str
        Default mask value (default: ``"mask"``).
        NOTE(review): presumably used when ``mask`` has no matching entry.
    preprocessing : ProcessingConfig
        Processing settings (default: a fresh ``ProcessingConfig()``).
    output : OutputConfig
        Output settings (default: a fresh ``OutputConfig()``).
    """

    version: int = 1

    # Factories give each PipelineConfig its own sub-config instances.
    input: InputConfig = field(default_factory=InputConfig)

    mask: Dict[str, str] = field(default_factory=dict)
    mask_default: str = "mask"

    preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
    output: OutputConfig = field(default_factory=OutputConfig)
|