93 lines
2.7 KiB
Python
93 lines
2.7 KiB
Python
import json
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional, Self
|
|
|
|
from astrai.config.base import BaseConfig
|
|
from astrai.factory import BaseFactory
|
|
|
|
|
|
class ConfigFactory(BaseFactory[BaseConfig]):
|
|
"""Factory that dispatches config classes by ``model_type``."""
|
|
|
|
@classmethod
|
|
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
|
|
model_type = raw.get("model_type") or "autoregressive_lm"
|
|
config_cls = cls.get_component_class(model_type)
|
|
return config_cls.from_dict(raw)
|
|
|
|
|
|
@dataclass
|
|
class BaseModelConfig(BaseConfig):
|
|
"""Base config with ``model_type`` dispatch and file I/O."""
|
|
|
|
model_type: Optional[str] = None
|
|
|
|
@classmethod
|
|
def from_file(cls, config_path: str) -> Self:
|
|
with open(config_path, "r") as f:
|
|
raw: Dict[str, Any] = json.load(f)
|
|
return cls.from_dict(raw)
|
|
|
|
def to_file(self, config_path: str):
|
|
d = self.to_dict()
|
|
config_dict = {k: v for k, v in d.items() if v is not None}
|
|
with open(config_path, "w") as f:
|
|
json.dump(config_dict, f, indent=4)
|
|
|
|
|
|
@dataclass
|
|
@ConfigFactory.register("autoregressive_lm")
|
|
class AutoRegressiveLMConfig(BaseModelConfig):
|
|
"""Configuration for autoregressive language model."""
|
|
|
|
vocab_size: Optional[int] = None
|
|
dim: Optional[int] = None
|
|
n_layers: Optional[int] = None
|
|
norm_eps: Optional[float] = None
|
|
dim_ffn: Optional[int] = None
|
|
tie_weight: Optional[bool] = None
|
|
|
|
max_len: Optional[int] = None
|
|
rope_theta: Optional[float] = None
|
|
rope_scaling: Optional[dict] = None
|
|
|
|
attn_type: str = "gqa"
|
|
n_heads: Optional[int] = None
|
|
n_kv_heads: Optional[int] = None
|
|
use_qk_norm: Optional[bool] = None
|
|
use_gated_attention: Optional[bool] = None
|
|
|
|
kv_lora_rank: Optional[int] = None
|
|
qk_nope_head_dim: Optional[int] = None
|
|
qk_rope_head_dim: Optional[int] = None
|
|
|
|
ffn_type: str = "mlp"
|
|
n_routed_experts: Optional[int] = None
|
|
n_shared_experts: Optional[int] = None
|
|
n_activated_experts: Optional[int] = None
|
|
topk_method: Optional[str] = None
|
|
|
|
|
|
@dataclass
|
|
@ConfigFactory.register("embedding")
|
|
class EncoderConfig(BaseModelConfig):
|
|
"""Configuration for embedding encoder model."""
|
|
|
|
vocab_size: Optional[int] = None
|
|
dim: Optional[int] = None
|
|
n_layers: Optional[int] = None
|
|
norm_eps: Optional[float] = None
|
|
dim_ffn: Optional[int] = None
|
|
|
|
max_len: Optional[int] = None
|
|
rope_theta: Optional[float] = None
|
|
rope_scaling: Optional[dict] = None
|
|
|
|
n_heads: Optional[int] = None
|
|
n_kv_heads: Optional[int] = None
|
|
use_qk_norm: Optional[bool] = None
|
|
use_gated_attention: Optional[bool] = None
|
|
|
|
pooling_type: Optional[str] = None
|
|
normalize_embeddings: Optional[bool] = None
|