refactor: Config序列化统一BaseConfig基类
- 新增astrai/config/base.py,提供to_dict/from_dict基类 - 统一命名:load/save → from_file/to_file - Checkpoint.meta合并训练配置到meta.json - sys.stderr.warn → warnings.warn - from_file改为classmethod
This commit is contained in:
parent
d7a7f570ed
commit
f91bfee33e
|
|
@ -5,10 +5,15 @@
|
|||
```mermaid
|
||||
classDiagram
|
||||
namespace config {
|
||||
class BaseConfig {
|
||||
+to_dict() Dict
|
||||
+from_dict(d) Self
|
||||
}
|
||||
|
||||
class BaseModelConfig {
|
||||
+Optional[str] model_type
|
||||
+load(config_path) Self
|
||||
+save(config_path)
|
||||
+from_file(config_path) Self
|
||||
+to_file(config_path)
|
||||
}
|
||||
|
||||
class ModelConfig {
|
||||
|
|
@ -147,6 +152,7 @@ classDiagram
|
|||
+int epoch
|
||||
+int iteration
|
||||
+dict extra
|
||||
+dict meta
|
||||
+save(save_dir)
|
||||
+load(save_dir) Checkpoint
|
||||
}
|
||||
|
|
@ -750,6 +756,8 @@ classDiagram
|
|||
ParallelModel <|-- RowParallelLinear
|
||||
ParallelModel <|-- ColumnParallelLinear
|
||||
AutoModel <|-- Transformer
|
||||
BaseConfig <|-- BaseModelConfig
|
||||
BaseConfig <|-- TrainConfig
|
||||
BaseModelConfig <|-- ModelConfig
|
||||
BaseFactory <|-- AutoModel
|
||||
BaseFactory <|-- AttnFactory
|
||||
|
|
@ -838,7 +846,7 @@ classDiagram
|
|||
|
||||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
|
|
|
|||
|
|
@ -157,12 +157,13 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
|||
## Checkpoint
|
||||
|
||||
```
|
||||
Checkpoint(state_dict, epoch, iteration, extra)
|
||||
├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt
|
||||
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
||||
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
|
||||
└── load(save_dir) broadcasts metadata from rank-0
|
||||
```
|
||||
|
||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
|
||||
|
||||
## TrainContextBuilder (Builder Pattern)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,77 @@
|
|||
import json
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = {}
|
||||
for fld in fields(self):
|
||||
v = getattr(self, fld.name)
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
d[fld.name] = v
|
||||
elif v is None:
|
||||
d[fld.name] = None
|
||||
elif isinstance(v, dict):
|
||||
try:
|
||||
json.dumps(v)
|
||||
d[fld.name] = v
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> Self:
|
||||
hints = get_type_hints(cls)
|
||||
inst = cls.__new__(cls)
|
||||
for fld in fields(cls):
|
||||
if fld.name in d:
|
||||
v = d[fld.name]
|
||||
target = cls._unwrap_optional(hints.get(fld.name))
|
||||
if target is not None:
|
||||
try:
|
||||
v = cls._coerce(v, target)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
object.__setattr__(inst, fld.name, v)
|
||||
elif fld.default is not MISSING:
|
||||
object.__setattr__(inst, fld.name, fld.default)
|
||||
elif fld.default_factory is not MISSING:
|
||||
object.__setattr__(inst, fld.name, fld.default_factory())
|
||||
else:
|
||||
object.__setattr__(inst, fld.name, None)
|
||||
return inst
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_optional(tp) -> Optional[type]:
|
||||
if tp is None:
|
||||
return None
|
||||
origin = getattr(tp, "__origin__", None)
|
||||
if origin is not None:
|
||||
args = getattr(tp, "__args__", ())
|
||||
non_none = [a for a in args if a is not type(None)]
|
||||
return non_none[0] if non_none else None
|
||||
return tp
|
||||
|
||||
@staticmethod
|
||||
def _coerce(value: Any, target_type: type) -> Any:
|
||||
if target_type is bool and isinstance(value, bool):
|
||||
return value
|
||||
if (
|
||||
target_type is int
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return int(value)
|
||||
if (
|
||||
target_type is float
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return float(value)
|
||||
if target_type is str and isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
raise TypeError
|
||||
|
|
@ -1,12 +1,14 @@
|
|||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||
from typing import Any, Dict, Optional, Self
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
"""Field-aware JSON load/save for dataclass configs.
|
||||
class BaseModelConfig(BaseConfig):
|
||||
"""Field-aware JSON from/to file for dataclass configs.
|
||||
|
||||
Subclass with additional fields. The base ``model_type`` field
|
||||
enables ``AutoModel`` to pick the correct subclass.
|
||||
|
|
@ -14,76 +16,25 @@ class BaseModelConfig:
|
|||
|
||||
model_type: Optional[str] = None
|
||||
|
||||
def load(self, config_path: str) -> Self:
|
||||
raw: Dict[str, Any] = {}
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str) -> Self:
|
||||
with open(config_path, "r") as f:
|
||||
raw.update(json.load(f))
|
||||
raw: Dict[str, Any] = json.load(f)
|
||||
|
||||
hints = get_type_hints(type(self))
|
||||
valid = {fld.name for fld in fields(self)}
|
||||
for key, value in raw.items():
|
||||
valid = {fld.name for fld in fields(cls)}
|
||||
for key in list(raw):
|
||||
if key not in valid:
|
||||
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
|
||||
continue
|
||||
warnings.warn(f"Unknown config key '{key}'")
|
||||
del raw[key]
|
||||
|
||||
target_type = self._unwrap_optional(hints.get(key))
|
||||
if target_type is None:
|
||||
continue
|
||||
return cls.from_dict(raw)
|
||||
|
||||
try:
|
||||
value = self._coerce(value, target_type)
|
||||
except (TypeError, ValueError):
|
||||
sys.stderr.write(
|
||||
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
|
||||
)
|
||||
continue
|
||||
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
def save(self, config_path: str):
|
||||
config_dict: Dict[str, Any] = {}
|
||||
for fld in fields(self):
|
||||
v = getattr(self, fld.name)
|
||||
if v is not None:
|
||||
config_dict[fld.name] = v
|
||||
def to_file(self, config_path: str):
|
||||
d = self.to_dict()
|
||||
config_dict = {k: v for k, v in d.items() if v is not None}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_optional(tp: type) -> Optional[type]:
|
||||
if tp is None:
|
||||
return None
|
||||
origin = getattr(tp, "__origin__", None)
|
||||
if origin is not None:
|
||||
args = getattr(tp, "__args__", ())
|
||||
non_none = [a for a in args if a is not type(None)]
|
||||
return non_none[0] if non_none else None
|
||||
return tp
|
||||
|
||||
@staticmethod
|
||||
def _coerce(value: Any, target_type: type) -> Any:
|
||||
if target_type is bool and isinstance(value, bool):
|
||||
return value
|
||||
if (
|
||||
target_type is int
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return int(value)
|
||||
if (
|
||||
target_type is float
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return float(value)
|
||||
if target_type is str and isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
raise TypeError
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig(BaseModelConfig):
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
||||
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
||||
|
|
|
|||
|
|
@ -60,10 +60,9 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
model_path = Path(path)
|
||||
|
||||
# Load config
|
||||
config = ModelConfig()
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
config.load(str(config_path))
|
||||
config = ModelConfig.from_file(str(config_path))
|
||||
else:
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
|
|
@ -89,7 +88,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
self.config.save(str(save_path / "config.json"))
|
||||
self.config.to_file(str(save_path / "config.json"))
|
||||
|
||||
# Save weights
|
||||
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||
|
|
|
|||
|
|
@ -17,11 +17,13 @@ class Checkpoint:
|
|||
epoch: int = 0,
|
||||
iteration: int = 0,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.state_dict = state_dict
|
||||
self.epoch = epoch
|
||||
self.iteration = iteration
|
||||
self.extra = extra or {}
|
||||
self.meta = meta or {}
|
||||
|
||||
def save(
|
||||
self,
|
||||
|
|
@ -38,6 +40,7 @@ class Checkpoint:
|
|||
"iteration": self.iteration,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
meta.update(self.meta)
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
|
|
|
|||
|
|
@ -125,6 +125,7 @@ class CheckpointCallback(TrainCallback):
|
|||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
meta=context.config.to_dict(),
|
||||
)
|
||||
|
||||
context.checkpoint.save(save_path)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class TrainContext:
|
|||
optimizer: Optimizer = field(default=None)
|
||||
scheduler: LRScheduler = field(default=None)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
config: TrainConfig = field(default=None)
|
||||
|
||||
epoch: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
|
|
@ -48,6 +49,7 @@ class TrainContextBuilder:
|
|||
model=self.config.model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
device = get_current_device()
|
||||
|
|
|
|||
|
|
@ -235,10 +235,8 @@ def train(
|
|||
assert os.path.exists(param_path)
|
||||
|
||||
# Load config
|
||||
config = ModelConfig()
|
||||
config_path = os.path.join(param_path, "config.json")
|
||||
if os.path.exists(config_path):
|
||||
config.load(config_path)
|
||||
config = ModelConfig.from_file(config_path)
|
||||
|
||||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(config)
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
|
|
@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(config)
|
||||
|
||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
|
|
@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
config = ModelConfig.from_file(config_path)
|
||||
original_model = Transformer(config)
|
||||
|
||||
st.save_file(original_model.state_dict(), model_path)
|
||||
|
||||
loaded_config = ModelConfig().load(config_path)
|
||||
loaded_config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
|
|
@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
loaded_config = ModelConfig().load(config_path)
|
||||
loaded_config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue