refactor: Config序列化统一BaseConfig基类

- 新增astrai/config/base.py,提供to_dict/from_dict基类
- 统一命名:load/save → from_file/to_file
- Checkpoint.meta合并训练配置到meta.json
- sys.stderr.warn → warnings.warn
- from_file改为classmethod
This commit is contained in:
ViperEkura 2026-05-16 22:06:39 +08:00
parent d7a7f570ed
commit f91bfee33e
11 changed files with 126 additions and 84 deletions

View File

@ -5,10 +5,15 @@
```mermaid ```mermaid
classDiagram classDiagram
namespace config { namespace config {
class BaseConfig {
+to_dict() Dict
+from_dict(d) Self
}
class BaseModelConfig { class BaseModelConfig {
+Optional[str] model_type +Optional[str] model_type
+load(config_path) Self +from_file(config_path) Self
+save(config_path) +to_file(config_path)
} }
class ModelConfig { class ModelConfig {
@ -147,6 +152,7 @@ classDiagram
+int epoch +int epoch
+int iteration +int iteration
+dict extra +dict extra
+dict meta
+save(save_dir) +save(save_dir)
+load(save_dir) Checkpoint +load(save_dir) Checkpoint
} }
@ -750,6 +756,8 @@ classDiagram
ParallelModel <|-- RowParallelLinear ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- Transformer AutoModel <|-- Transformer
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseModelConfig <|-- ModelConfig BaseModelConfig <|-- ModelConfig
BaseFactory <|-- AutoModel BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory BaseFactory <|-- AttnFactory
@ -838,7 +846,7 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management | | **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization | | **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |

View File

@ -157,12 +157,13 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Checkpoint ## Checkpoint
``` ```
Checkpoint(state_dict, epoch, iteration, extra) Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt ├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
└── load(save_dir) broadcasts metadata from rank-0 └── load(save_dir) broadcasts metadata from rank-0
``` ```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`. Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern) ## TrainContextBuilder (Builder Pattern)

77
astrai/config/base.py Normal file
View File

@ -0,0 +1,77 @@
import json
from dataclasses import MISSING, dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass
class BaseConfig:
def to_dict(self) -> Dict[str, Any]:
d = {}
for fld in fields(self):
v = getattr(self, fld.name)
if isinstance(v, (str, int, float, bool)):
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, dict):
try:
json.dumps(v)
d[fld.name] = v
except (TypeError, ValueError):
pass
return d
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
hints = get_type_hints(cls)
inst = cls.__new__(cls)
for fld in fields(cls):
if fld.name in d:
v = d[fld.name]
target = cls._unwrap_optional(hints.get(fld.name))
if target is not None:
try:
v = cls._coerce(v, target)
except (TypeError, ValueError):
pass
object.__setattr__(inst, fld.name, v)
elif fld.default is not MISSING:
object.__setattr__(inst, fld.name, fld.default)
elif fld.default_factory is not MISSING:
object.__setattr__(inst, fld.name, fld.default_factory())
else:
object.__setattr__(inst, fld.name, None)
return inst
@staticmethod
def _unwrap_optional(tp) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError

View File

@ -1,12 +1,14 @@
import json import json
import sys import warnings
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig
@dataclass @dataclass
class BaseModelConfig: class BaseModelConfig(BaseConfig):
"""Field-aware JSON load/save for dataclass configs. """Field-aware JSON from/to file for dataclass configs.
Subclass with additional fields. The base ``model_type`` field Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass. enables ``AutoModel`` to pick the correct subclass.
@ -14,76 +16,25 @@ class BaseModelConfig:
model_type: Optional[str] = None model_type: Optional[str] = None
def load(self, config_path: str) -> Self: @classmethod
raw: Dict[str, Any] = {} def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f: with open(config_path, "r") as f:
raw.update(json.load(f)) raw: Dict[str, Any] = json.load(f)
hints = get_type_hints(type(self)) valid = {fld.name for fld in fields(cls)}
valid = {fld.name for fld in fields(self)} for key in list(raw):
for key, value in raw.items():
if key not in valid: if key not in valid:
sys.stderr.write(f"WARNING: unknown config key '{key}'\n") warnings.warn(f"Unknown config key '{key}'")
continue del raw[key]
target_type = self._unwrap_optional(hints.get(key)) return cls.from_dict(raw)
if target_type is None:
continue
try: def to_file(self, config_path: str):
value = self._coerce(value, target_type) d = self.to_dict()
except (TypeError, ValueError): config_dict = {k: v for k, v in d.items() if v is not None}
sys.stderr.write(
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
)
continue
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict: Dict[str, Any] = {}
for fld in fields(self):
v = getattr(self, fld.name)
if v is not None:
config_dict[fld.name] = v
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4) json.dump(config_dict, f, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass @dataclass
class ModelConfig(BaseModelConfig): class ModelConfig(BaseModelConfig):

View File

@ -6,9 +6,11 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
@dataclass @dataclass
class TrainConfig: class TrainConfig(BaseConfig):
# basic setting # basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."}) model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."}) strategy: str = field(default=None, metadata={"help": "Training strategy."})

View File

@ -60,10 +60,9 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
model_path = Path(path) model_path = Path(path)
# Load config # Load config
config = ModelConfig()
config_path = model_path / "config.json" config_path = model_path / "config.json"
if config_path.exists(): if config_path.exists():
config.load(str(config_path)) config = ModelConfig.from_file(str(config_path))
else: else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
@ -89,7 +88,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)
# Save config # Save config
self.config.save(str(save_path / "config.json")) self.config.to_file(str(save_path / "config.json"))
# Save weights # Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors")) st.save_file(self.state_dict(), str(save_path / "model.safetensors"))

View File

@ -17,11 +17,13 @@ class Checkpoint:
epoch: int = 0, epoch: int = 0,
iteration: int = 0, iteration: int = 0,
extra: Optional[Dict[str, Any]] = None, extra: Optional[Dict[str, Any]] = None,
meta: Optional[Dict[str, Any]] = None,
): ):
self.state_dict = state_dict self.state_dict = state_dict
self.epoch = epoch self.epoch = epoch
self.iteration = iteration self.iteration = iteration
self.extra = extra or {} self.extra = extra or {}
self.meta = meta or {}
def save( def save(
self, self,
@ -38,6 +40,7 @@ class Checkpoint:
"iteration": self.iteration, "iteration": self.iteration,
"timestamp": time.time(), "timestamp": time.time(),
} }
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2) json.dump(meta, f, indent=2)

View File

@ -125,6 +125,7 @@ class CheckpointCallback(TrainCallback):
epoch=context.epoch, epoch=context.epoch,
iteration=context.iteration, iteration=context.iteration,
extra=extra, extra=extra,
meta=context.config.to_dict(),
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)

View File

@ -21,6 +21,7 @@ class TrainContext:
optimizer: Optimizer = field(default=None) optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None) scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
iteration: int = field(default=0) iteration: int = field(default=0)
@ -48,6 +49,7 @@ class TrainContextBuilder:
model=self.config.model, model=self.config.model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
config=self.config,
) )
device = get_current_device() device = get_current_device()

View File

@ -235,10 +235,8 @@ def train(
assert os.path.exists(param_path) assert os.path.exists(param_path)
# Load config # Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json") config_path = os.path.join(param_path, "config.json")
if os.path.exists(config_path): config = ModelConfig.from_file(config_path)
config.load(config_path)
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len

View File

@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig.from_file(config_path)
model = Transformer(config) model = Transformer(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig.from_file(config_path)
model = Transformer(config) model = Transformer(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig.from_file(config_path)
original_model = Transformer(config) original_model = Transformer(config)
st.save_file(original_model.state_dict(), model_path) st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig().load(config_path) loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config) model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))
@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
loaded_config = ModelConfig().load(config_path) loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config) model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))