71 lines
1.9 KiB
Python
71 lines
1.9 KiB
Python
import json
|
|
import warnings
|
|
from dataclasses import dataclass, fields
|
|
from typing import Any, Dict, Optional, Self
|
|
|
|
from astrai.config.base import BaseConfig
|
|
|
|
|
|
@dataclass
class BaseModelConfig(BaseConfig):
    """Field-aware JSON from/to file for dataclass configs.

    Subclass with additional fields. The base ``model_type`` field
    enables ``AutoModel`` to pick the correct subclass.
    """

    model_type: Optional[str] = None

    @classmethod
    def from_file(cls, config_path: str) -> Self:
        """Load a config from the JSON file at ``config_path``.

        Top-level keys that do not name a dataclass field of ``cls`` are
        dropped, each with a warning; the remaining mapping is handed to
        ``cls.from_dict`` (inherited; not defined in this class).

        Args:
            config_path: Path of the JSON file to read.

        Returns:
            Whatever ``cls.from_dict`` builds from the filtered mapping.
        """
        # JSON text is UTF-8 by specification; do not depend on the
        # platform's locale encoding when decoding it.
        with open(config_path, "r", encoding="utf-8") as f:
            raw: Dict[str, Any] = json.load(f)

        valid = {fld.name for fld in fields(cls)}
        unknown = [key for key in raw if key not in valid]
        for key in unknown:
            # stacklevel=2 attributes the warning to the caller of
            # from_file rather than to this module.
            warnings.warn(f"Unknown config key '{key}'", stacklevel=2)
            del raw[key]

        return cls.from_dict(raw)

    def to_file(self, config_path: str) -> None:
        """Write this config as indented JSON to ``config_path``.

        Top-level entries of ``self.to_dict()`` whose value is ``None``
        are omitted from the output.

        Args:
            config_path: Path of the JSON file to create or overwrite.
        """
        d = self.to_dict()
        config_dict = {k: v for k, v in d.items() if v is not None}
        # Same explicit encoding as from_file so a written file always
        # reads back regardless of locale.
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=4)
|
|
|
|
|
|
@dataclass
class ModelConfig(BaseModelConfig):
    """Model hyperparameter fields on top of ``BaseModelConfig``.

    Every field defaults to ``None`` except ``attn_type`` (``"gqa"``) and
    ``ffn_type`` (``"mlp"``). Field order is part of the generated
    ``__init__`` signature, so new fields should be appended with care.
    """

    # general sizes — NOTE(review): meanings are inferred from the names
    # only; confirm against the model code that consumes this config
    vocab_size: Optional[int] = None
    dim: Optional[int] = None

    n_layers: Optional[int] = None
    norm_eps: Optional[float] = None
    dim_ffn: Optional[int] = None
    tie_weight: Optional[bool] = None

    # RoPE
    max_len: Optional[int] = None
    rope_theta: Optional[float] = None

    # attention
    attn_type: str = "gqa"
    n_heads: Optional[int] = None
    n_kv_heads: Optional[int] = None
    use_qk_norm: Optional[bool] = None
    use_gated_attention: Optional[bool] = None

    # MLA
    kv_lora_rank: Optional[int] = None
    qk_nope_head_dim: Optional[int] = None
    qk_rope_head_dim: Optional[int] = None

    # MoE (``ffn_type`` defaults to "mlp"; the remaining fields are
    # presumably only read for a MoE feed-forward — TODO confirm)
    ffn_type: str = "mlp"
    n_routed_experts: Optional[int] = None
    n_shared_experts: Optional[int] = None
    n_activated_experts: Optional[int] = None
    moe_topk_method: Optional[str] = None
|