AstrAI/astrai/config/model_config.py

71 lines
1.9 KiB
Python

import json
import warnings
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig
@dataclass
class BaseModelConfig(BaseConfig):
    """Field-aware JSON from/to file for dataclass configs.

    Subclass with additional fields. The base ``model_type`` field
    enables ``AutoModel`` to pick the correct subclass.

    Loading drops (with a warning) every JSON key that is not a declared
    dataclass field of the class being built; saving omits every entry
    whose value is ``None``.
    """

    model_type: Optional[str] = None

    @classmethod
    def from_file(cls, config_path: str) -> Self:
        """Build a config instance from a JSON file.

        Args:
            config_path: Path of the JSON file to read.

        Returns:
            Whatever ``cls.from_dict`` returns for the recognised keys
            (``from_dict`` is not defined here; presumably inherited from
            ``BaseConfig``).

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file does not contain valid JSON.
            ValueError: If the top-level JSON value is not an object.
        """
        # JSON text is UTF-8 by specification; do not rely on the locale's
        # default encoding, which differs between platforms.
        with open(config_path, "r", encoding="utf-8") as f:
            raw: Dict[str, Any] = json.load(f)

        # Fail clearly instead of letting the key filter below misbehave on
        # a list or scalar.
        if not isinstance(raw, dict):
            raise ValueError(
                f"Expected a JSON object in '{config_path}', "
                f"got {type(raw).__name__}"
            )

        valid = {fld.name for fld in fields(cls)}
        # Collect first so the dict is never mutated while being iterated.
        unknown = [key for key in raw if key not in valid]
        for key in unknown:
            # stacklevel=2 attributes the warning to the caller of from_file.
            warnings.warn(f"Unknown config key '{key}'", stacklevel=2)
            del raw[key]

        return cls.from_dict(raw)

    def to_file(self, config_path: str) -> None:
        """Write this config to a JSON file, skipping ``None`` values.

        The data comes from ``self.to_dict()`` (not defined here; presumably
        inherited from ``BaseConfig``). Only top-level entries whose value
        is ``None`` are removed.

        Args:
            config_path: Destination path; an existing file is overwritten.

        Raises:
            OSError: If the file cannot be opened for writing.
        """
        config_dict = {
            key: value
            for key, value in self.to_dict().items()
            if value is not None
        }

        # Write with the same explicit encoding that from_file reads with.
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=4)
@dataclass
class ModelConfig(BaseModelConfig):
    """Model hyperparameters, layered on top of ``BaseModelConfig``.

    Every field defaults to ``None`` except ``attn_type`` (``"gqa"``) and
    ``ffn_type`` (``"mlp"``). Field order feeds the dataclass-generated
    ``__init__`` signature, so it is kept stable.
    """

    # General model sizes and numerics.
    # NOTE(review): exact meaning of each field is defined by the model
    # code that consumes this config -- confirm there.
    vocab_size: Optional[int] = None
    dim: Optional[int] = None
    n_layers: Optional[int] = None
    norm_eps: Optional[float] = None
    dim_ffn: Optional[int] = None
    tie_weight: Optional[bool] = None

    # RoPE
    max_len: Optional[int] = None
    rope_theta: Optional[float] = None

    # attention
    attn_type: str = "gqa"
    n_heads: Optional[int] = None
    n_kv_heads: Optional[int] = None
    use_qk_norm: Optional[bool] = None
    use_gated_attention: Optional[bool] = None

    # MLA (presumably only read for an MLA-style ``attn_type`` -- verify
    # against the attention implementation)
    kv_lora_rank: Optional[int] = None
    qk_nope_head_dim: Optional[int] = None
    qk_rope_head_dim: Optional[int] = None

    # MoE (``ffn_type`` selects the feed-forward variant; the expert fields
    # are presumably only read for a MoE-style value -- verify)
    ffn_type: str = "mlp"
    n_routed_experts: Optional[int] = None
    n_shared_experts: Optional[int] = None
    n_activated_experts: Optional[int] = None
    topk_method: Optional[str] = None