From f91bfee33e05a96ce462ff7c3c9f05531f7f1819 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 16 May 2026 22:06:39 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20Config=E5=BA=8F=E5=88=97=E5=8C=96?= =?UTF-8?q?=E7=BB=9F=E4=B8=80BaseConfig=E5=9F=BA=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增astrai/config/base.py,提供to_dict/from_dict基类 - 统一命名:load/save → from_file/to_file - Checkpoint.meta合并训练配置到meta.json - sys.stderr.warn → warnings.warn - from_file改为classmethod --- assets/docs/architecture.md | 14 ++++-- assets/docs/training.md | 7 +-- astrai/config/base.py | 77 +++++++++++++++++++++++++++++ astrai/config/model_config.py | 83 +++++++------------------------- astrai/config/train_config.py | 4 +- astrai/model/automodel.py | 5 +- astrai/serialization.py | 3 ++ astrai/trainer/train_callback.py | 1 + astrai/trainer/train_context.py | 2 + scripts/tools/train.py | 4 +- tests/module/test_tie_weight.py | 10 ++-- 11 files changed, 126 insertions(+), 84 deletions(-) create mode 100644 astrai/config/base.py diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index b1208ea..6c955ae 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -5,10 +5,15 @@ ```mermaid classDiagram namespace config { + class BaseConfig { + +to_dict() Dict + +from_dict(d) Self + } + class BaseModelConfig { +Optional[str] model_type - +load(config_path) Self - +save(config_path) + +from_file(config_path) Self + +to_file(config_path) } class ModelConfig { @@ -147,6 +152,7 @@ classDiagram +int epoch +int iteration +dict extra + +dict meta +save(save_dir) +load(save_dir) Checkpoint } @@ -750,6 +756,8 @@ classDiagram ParallelModel <|-- RowParallelLinear ParallelModel <|-- ColumnParallelLinear AutoModel <|-- Transformer + BaseConfig <|-- BaseModelConfig + BaseConfig <|-- TrainConfig BaseModelConfig <|-- ModelConfig BaseFactory <|-- AutoModel BaseFactory <|-- AttnFactory @@ -838,7 +846,7 @@ classDiagram | Module | Components | Description | |--------|------------|-------------| -| **astrai.config** | ModelConfig, TrainConfig | Configuration management | +| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | | **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.serialization** | Checkpoint | Model serialization | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | diff --git a/assets/docs/training.md b/assets/docs/training.md index a7aa1b0..ef979a0 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -157,12 +157,13 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. ## Checkpoint ``` -Checkpoint(state_dict, epoch, iteration, extra) - ├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt +Checkpoint(state_dict, epoch, iteration, extra, meta) + ├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt └── load(save_dir) broadcasts metadata from rank-0 ``` -Optimizer/scheduler state persisted by default via `Checkpoint.extra`. +Optimizer/scheduler state persisted by default via `Checkpoint.extra`. +Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`. ## TrainContextBuilder (Builder Pattern) diff --git a/astrai/config/base.py b/astrai/config/base.py new file mode 100644 index 0000000..0c6182c --- /dev/null +++ b/astrai/config/base.py @@ -0,0 +1,77 @@ +import json +from dataclasses import MISSING, dataclass, fields +from typing import Any, Dict, Optional, Self, get_type_hints + + +@dataclass +class BaseConfig: + def to_dict(self) -> Dict[str, Any]: + d = {} + for fld in fields(self): + v = getattr(self, fld.name) + if isinstance(v, (str, int, float, bool)): + d[fld.name] = v + elif v is None: + d[fld.name] = None + elif isinstance(v, dict): + try: + json.dumps(v) + d[fld.name] = v + except (TypeError, ValueError): + pass + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> Self: + hints = get_type_hints(cls) + inst = cls.__new__(cls) + for fld in fields(cls): + if fld.name in d: + v = d[fld.name] + target = cls._unwrap_optional(hints.get(fld.name)) + if target is not None: + try: + v = cls._coerce(v, target) + except (TypeError, ValueError): + pass + object.__setattr__(inst, fld.name, v) + elif fld.default is not MISSING: + object.__setattr__(inst, fld.name, fld.default) + elif fld.default_factory is not MISSING: + object.__setattr__(inst, fld.name, fld.default_factory()) + else: + object.__setattr__(inst, fld.name, None) + return inst + + @staticmethod + def _unwrap_optional(tp) -> Optional[type]: + if tp is None: + return None + origin = getattr(tp, "__origin__", None) + if origin is not None: + args = getattr(tp, "__args__", ()) + non_none = [a for a in args if a is not type(None)] + return non_none[0] if non_none else None + return tp + + @staticmethod + def _coerce(value: Any, target_type: type) -> Any: + if target_type is bool and isinstance(value, bool): + return value + if ( + target_type is int + and isinstance(value, (int, float)) + and not isinstance(value, bool) + ): + return int(value) + if ( + target_type is float + and isinstance(value, (int, float)) + and not isinstance(value, bool) + ): + return float(value) + if target_type is str and isinstance(value, str): + return value + if isinstance(value, target_type): + return value + raise TypeError diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py index 0e02428..3c13920 100644 --- a/astrai/config/model_config.py +++ b/astrai/config/model_config.py @@ -1,12 +1,14 @@ import json -import sys +import warnings from dataclasses import dataclass, fields -from typing import Any, Dict, Optional, Self, get_type_hints +from typing import Any, Dict, Optional, Self + +from astrai.config.base import BaseConfig @dataclass -class BaseModelConfig: - """Field-aware JSON load/save for dataclass configs. +class BaseModelConfig(BaseConfig): + """Field-aware JSON from/to file for dataclass configs. Subclass with additional fields. The base ``model_type`` field enables ``AutoModel`` to pick the correct subclass. @@ -14,76 +16,25 @@ class BaseModelConfig: model_type: Optional[str] = None - def load(self, config_path: str) -> Self: - raw: Dict[str, Any] = {} + @classmethod + def from_file(cls, config_path: str) -> Self: with open(config_path, "r") as f: - raw.update(json.load(f)) + raw: Dict[str, Any] = json.load(f) - hints = get_type_hints(type(self)) - valid = {fld.name for fld in fields(self)} - for key, value in raw.items(): + valid = {fld.name for fld in fields(cls)} + for key in list(raw): if key not in valid: - sys.stderr.write(f"WARNING: unknown config key '{key}'\n") - continue + warnings.warn(f"Unknown config key '{key}'") + del raw[key] - target_type = self._unwrap_optional(hints.get(key)) - if target_type is None: - continue + return cls.from_dict(raw) - try: - value = self._coerce(value, target_type) - except (TypeError, ValueError): - sys.stderr.write( - f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n" - ) - continue - - setattr(self, key, value) - - return self - - def save(self, config_path: str): - config_dict: Dict[str, Any] = {} - for fld in fields(self): - v = getattr(self, fld.name) - if v is not None: - config_dict[fld.name] = v + def to_file(self, config_path: str): + d = self.to_dict() + config_dict = {k: v for k, v in d.items() if v is not None} with open(config_path, "w") as f: json.dump(config_dict, f, indent=4) - @staticmethod - def _unwrap_optional(tp: type) -> Optional[type]: - if tp is None: - return None - origin = getattr(tp, "__origin__", None) - if origin is not None: - args = getattr(tp, "__args__", ()) - non_none = [a for a in args if a is not type(None)] - return non_none[0] if non_none else None - return tp - - @staticmethod - def _coerce(value: Any, target_type: type) -> Any: - if target_type is bool and isinstance(value, bool): - return value - if ( - target_type is int - and isinstance(value, (int, float)) - and not isinstance(value, bool) - ): - return int(value) - if ( - target_type is float - and isinstance(value, (int, float)) - and not isinstance(value, bool) - ): - return float(value) - if target_type is str and isinstance(value, str): - return value - if isinstance(value, target_type): - return value - raise TypeError - @dataclass class ModelConfig(BaseModelConfig): diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index fdfdb9b..801edd6 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -6,9 +6,11 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import Dataset +from astrai.config.base import BaseConfig + @dataclass -class TrainConfig: +class TrainConfig(BaseConfig): # basic setting model: nn.Module = field(default=None, metadata={"help": "Model for training."}) strategy: str = field(default=None, metadata={"help": "Training strategy."}) diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index d86a523..22f9555 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -60,10 +60,9 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module): model_path = Path(path) # Load config - config = ModelConfig() config_path = model_path / "config.json" if config_path.exists(): - config.load(str(config_path)) + config = ModelConfig.from_file(str(config_path)) else: raise FileNotFoundError(f"Config file not found: {config_path}") @@ -89,7 +88,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module): save_path.mkdir(parents=True, exist_ok=True) # Save config - self.config.save(str(save_path / "config.json")) + self.config.to_file(str(save_path / "config.json")) # Save weights st.save_file(self.state_dict(), str(save_path / "model.safetensors")) diff --git a/astrai/serialization.py b/astrai/serialization.py index dfd9cf1..103fada 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -17,11 +17,13 @@ class Checkpoint: epoch: int = 0, iteration: int = 0, extra: Optional[Dict[str, Any]] = None, + meta: Optional[Dict[str, Any]] = None, ): self.state_dict = state_dict self.epoch = epoch self.iteration = iteration self.extra = extra or {} + self.meta = meta or {} def save( self, @@ -38,6 +40,7 @@ class Checkpoint: "iteration": self.iteration, "timestamp": time.time(), } + meta.update(self.meta) with open(save_path / "meta.json", "w") as f: json.dump(meta, f, indent=2) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 0fcef47..8c654d4 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -125,6 +125,7 @@ class CheckpointCallback(TrainCallback): epoch=context.epoch, iteration=context.iteration, extra=extra, + meta=context.config.to_dict(), ) context.checkpoint.save(save_path) diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 0350144..a74b952 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -21,6 +21,7 @@ class TrainContext: optimizer: Optimizer = field(default=None) scheduler: LRScheduler = field(default=None) checkpoint: Checkpoint = field(default=None) + config: TrainConfig = field(default=None) epoch: int = field(default=0) iteration: int = field(default=0) @@ -48,6 +49,7 @@ class TrainContextBuilder: model=self.config.model, world_size=get_world_size(), rank=get_rank(), + config=self.config, ) device = get_current_device() diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 795aa0e..4a3412f 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -235,10 +235,8 @@ def train( assert os.path.exists(param_path) # Load config - config = ModelConfig() config_path = os.path.join(param_path, "config.json") - if os.path.exists(config_path): - config.load(config_path) + config = ModelConfig.from_file(config_path) if window_size is None: window_size = config.max_len diff --git a/tests/module/test_tie_weight.py b/tests/module/test_tie_weight.py index 73b763b..c0d0aa8 100644 --- a/tests/module/test_tie_weight.py +++ b/tests/module/test_tie_weight.py @@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env): with open(config_path, "w") as f: json.dump(config_data, f) - config = ModelConfig().load(config_path) + config = ModelConfig.from_file(config_path) model = Transformer(config) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) @@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env): with open(config_path, "w") as f: json.dump(config_data, f) - config = ModelConfig().load(config_path) + config = ModelConfig.from_file(config_path) model = Transformer(config) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) @@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env): with open(config_path, "w") as f: json.dump(config_data, f) - config = ModelConfig().load(config_path) + config = ModelConfig.from_file(config_path) original_model = Transformer(config) st.save_file(original_model.state_dict(), model_path) - loaded_config = ModelConfig().load(config_path) + loaded_config = ModelConfig.from_file(config_path) model = Transformer(loaded_config) model.load_state_dict(st.load_file(model_path)) @@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env): with open(config_path, "w") as f: json.dump(config_data, f) - loaded_config = ModelConfig().load(config_path) + loaded_config = ModelConfig.from_file(config_path) model = Transformer(loaded_config) model.load_state_dict(st.load_file(model_path))