refactor: Config序列化统一BaseConfig基类

- 新增astrai/config/base.py,提供to_dict/from_dict基类
- 统一命名:load/save → from_file/to_file
- Checkpoint.meta合并训练配置到meta.json
- sys.stderr.warn → warnings.warn
- from_file改为classmethod
This commit is contained in:
ViperEkura 2026-05-16 22:06:39 +08:00
parent d7a7f570ed
commit f91bfee33e
11 changed files with 126 additions and 84 deletions

View File

@ -5,10 +5,15 @@
```mermaid
classDiagram
namespace config {
class BaseConfig {
+to_dict() Dict
+from_dict(d) Self
}
class BaseModelConfig {
+Optional[str] model_type
+load(config_path) Self
+save(config_path)
+from_file(config_path) Self
+to_file(config_path)
}
class ModelConfig {
@ -147,6 +152,7 @@ classDiagram
+int epoch
+int iteration
+dict extra
+dict meta
+save(save_dir)
+load(save_dir) Checkpoint
}
@ -750,6 +756,8 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- Transformer
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseModelConfig <|-- ModelConfig
BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
@ -838,7 +846,7 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |

View File

@ -157,12 +157,13 @@ Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Checkpoint
```
Checkpoint(state_dict, epoch, iteration, extra)
├── save(save_dir) rank-0 only: meta.json + state_dict.safetensors + optional extra.pt
Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
└── load(save_dir) broadcasts metadata from rank-0
```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern)

77
astrai/config/base.py Normal file
View File

@ -0,0 +1,77 @@
import json
from dataclasses import MISSING, dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass
class BaseConfig:
def to_dict(self) -> Dict[str, Any]:
d = {}
for fld in fields(self):
v = getattr(self, fld.name)
if isinstance(v, (str, int, float, bool)):
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, dict):
try:
json.dumps(v)
d[fld.name] = v
except (TypeError, ValueError):
pass
return d
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
hints = get_type_hints(cls)
inst = cls.__new__(cls)
for fld in fields(cls):
if fld.name in d:
v = d[fld.name]
target = cls._unwrap_optional(hints.get(fld.name))
if target is not None:
try:
v = cls._coerce(v, target)
except (TypeError, ValueError):
pass
object.__setattr__(inst, fld.name, v)
elif fld.default is not MISSING:
object.__setattr__(inst, fld.name, fld.default)
elif fld.default_factory is not MISSING:
object.__setattr__(inst, fld.name, fld.default_factory())
else:
object.__setattr__(inst, fld.name, None)
return inst
@staticmethod
def _unwrap_optional(tp) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError

View File

@ -1,12 +1,14 @@
import json
import sys
import warnings
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig
@dataclass
class BaseModelConfig:
"""Field-aware JSON load/save for dataclass configs.
class BaseModelConfig(BaseConfig):
"""Field-aware JSON from/to file for dataclass configs.
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
@ -14,76 +16,25 @@ class BaseModelConfig:
model_type: Optional[str] = None
def load(self, config_path: str) -> Self:
raw: Dict[str, Any] = {}
@classmethod
def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f:
raw.update(json.load(f))
raw: Dict[str, Any] = json.load(f)
hints = get_type_hints(type(self))
valid = {fld.name for fld in fields(self)}
for key, value in raw.items():
valid = {fld.name for fld in fields(cls)}
for key in list(raw):
if key not in valid:
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
continue
warnings.warn(f"Unknown config key '{key}'")
del raw[key]
target_type = self._unwrap_optional(hints.get(key))
if target_type is None:
continue
return cls.from_dict(raw)
try:
value = self._coerce(value, target_type)
except (TypeError, ValueError):
sys.stderr.write(
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
)
continue
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict: Dict[str, Any] = {}
for fld in fields(self):
v = getattr(self, fld.name)
if v is not None:
config_dict[fld.name] = v
def to_file(self, config_path: str):
d = self.to_dict()
config_dict = {k: v for k, v in d.items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass
class ModelConfig(BaseModelConfig):

View File

@ -6,9 +6,11 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
@dataclass
class TrainConfig:
class TrainConfig(BaseConfig):
# basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."})

View File

@ -60,10 +60,9 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
model_path = Path(path)
# Load config
config = ModelConfig()
config_path = model_path / "config.json"
if config_path.exists():
config.load(str(config_path))
config = ModelConfig.from_file(str(config_path))
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
@ -89,7 +88,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
save_path.mkdir(parents=True, exist_ok=True)
# Save config
self.config.save(str(save_path / "config.json"))
self.config.to_file(str(save_path / "config.json"))
# Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))

View File

@ -17,11 +17,13 @@ class Checkpoint:
epoch: int = 0,
iteration: int = 0,
extra: Optional[Dict[str, Any]] = None,
meta: Optional[Dict[str, Any]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.extra = extra or {}
self.meta = meta or {}
def save(
self,
@ -38,6 +40,7 @@ class Checkpoint:
"iteration": self.iteration,
"timestamp": time.time(),
}
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)

View File

@ -125,6 +125,7 @@ class CheckpointCallback(TrainCallback):
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
meta=context.config.to_dict(),
)
context.checkpoint.save(save_path)

View File

@ -21,6 +21,7 @@ class TrainContext:
optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
epoch: int = field(default=0)
iteration: int = field(default=0)
@ -48,6 +49,7 @@ class TrainContextBuilder:
model=self.config.model,
world_size=get_world_size(),
rank=get_rank(),
config=self.config,
)
device = get_current_device()

View File

@ -235,10 +235,8 @@ def train(
assert os.path.exists(param_path)
# Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json")
if os.path.exists(config_path):
config.load(config_path)
config = ModelConfig.from_file(config_path)
if window_size is None:
window_size = config.max_len

View File

@ -50,7 +50,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
config = ModelConfig.from_file(config_path)
model = Transformer(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -68,7 +68,7 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
config = ModelConfig.from_file(config_path)
model = Transformer(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -94,12 +94,12 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
config = ModelConfig.from_file(config_path)
original_model = Transformer(config)
st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig().load(config_path)
loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))
@ -112,7 +112,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
loaded_config = ModelConfig().load(config_path)
loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))