refactor : 清理工厂和配置系统中的死代码与冗余抽象

- 删除 Registry 中未使用的 category/priority 字段,_entries 简化为直接存储类引用
- 修正 __init_subclass__ 避免叶子类(AutoRegressiveLM 等)创建空注册表
- 删除 5 个工厂的薄 create() 覆写,统一使用 BaseFactory.create(name, *args, **kwargs)
- 删除 3 处零调用的 available_types/available_strategies 别名死代码
- 删除零调用的 BaseModelConfig.to_file 死代码
- 将 BaseConfig.from_json/to_json 重命名为 from_file/to_file,消除与子类重复
- 移除两个 inference builder 中总是被覆写的 prompt_tokens=0
This commit is contained in:
ViperEkura 2026-06-07 11:39:50 +08:00
parent e7b18b7c03
commit 6ae1828449
20 changed files with 31 additions and 114 deletions

View File

@ -89,10 +89,10 @@ class BaseConfig:
raise TypeError
@classmethod
def from_json(cls, path: Union[str, Path]) -> Self:
def from_file(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_json(self, path: Union[str, Path]):
def to_file(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -1,6 +1,5 @@
import json
from dataclasses import dataclass
from typing import Any, Dict, Optional, Self
from typing import Any, Dict, Optional
from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory
@ -22,18 +21,6 @@ class BaseModelConfig(BaseConfig):
model_type: Optional[str] = None
@classmethod
def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f)
return cls.from_dict(raw)
def to_file(self, config_path: str):
d = self.to_dict()
config_dict = {k: v for k, v in d.items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@dataclass
@ConfigFactory.register("autoregressive_lm")

View File

@ -136,20 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
dataset = DatasetFactory.create("custom", window_size, stride)
"""
@classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
"""Create a dataset instance.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
window_size: Window size for data sampling
stride: Stride between consecutive samples
Returns:
Dataset instance
"""
return super().create(train_type, window_size, stride)
@classmethod
def load(
cls,
@ -179,11 +165,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
return dataset
@classmethod
def available_types(cls) -> list:
"""Return list of registered dataset type names."""
return cls.list_registered()
@DatasetFactory.register("seq")
class SEQDataset(BaseDataset):

View File

@ -8,8 +8,6 @@ from typing import (
Dict,
ForwardRef,
Generic,
Optional,
Tuple,
Type,
TypeVar,
)
@ -56,21 +54,19 @@ class BaseFactory(ABC, Generic[T]):
unrelated parameters.
"""
_entries: Dict[str, Tuple[Type, Optional[str], int]]
_entries: Dict[str, Type[T]]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._entries = {}
for orig_base in getattr(cls, "__orig_bases__", ()):
if _get_origin(orig_base) is BaseFactory:
(arg,) = _get_args(orig_base)
cls._entries = {}
cls._component_base = _resolve_type(arg, cls)
return
@classmethod
def register(
cls, name: str, category: Optional[str] = None, priority: int = 0
) -> Callable[[Type[T]], Type[T]]:
def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
"""Decorator to register a component class.
Validates that the decorated class inherits from the generic
@ -81,7 +77,7 @@ class BaseFactory(ABC, Generic[T]):
cls._validate_component(component_cls)
if name in cls._entries:
raise ValueError(f"Component '{name}' is already registered")
cls._entries[name] = (component_cls, category, priority)
cls._entries[name] = component_cls
return component_cls
return decorator
@ -96,7 +92,7 @@ class BaseFactory(ABC, Generic[T]):
raise ValueError(
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
)
component_cls = entry[0]
component_cls = entry
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
@ -130,7 +126,7 @@ class BaseFactory(ABC, Generic[T]):
raise ValueError(
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
)
return entry[0]
return entry
@classmethod
def list_registered(cls) -> list:

View File

@ -42,7 +42,6 @@ class AnthropicResponseBuilder(ResponseBuilder):
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
created=int(time.time()),
model=request.model,
prompt_tokens=0,
)
stop_sequences = getattr(request, "stop_sequences", None) or []
return prompt, ctx, stop_sequences

View File

@ -86,7 +86,6 @@ class OpenAIResponseBuilder(ResponseBuilder):
resp_id=self._resp_id,
created=int(time.time()),
model=self._model,
prompt_tokens=0,
)
stop = request.stop
stop_sequences = (

View File

@ -35,7 +35,7 @@ class GenContext:
resp_id: str
created: int
model: str
prompt_tokens: int
prompt_tokens: int = 0
completion_tokens: int = 0

View File

@ -24,9 +24,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
class AttnFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, attn_type: str, **kwargs) -> nn.Module:
return super().create(attn_type, **kwargs)
pass
@AttnFactory.register("gqa")

View File

@ -8,9 +8,7 @@ from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
return super().create(ffn_type, dim, dim_ffn, **kwargs)
pass
@FFNFactory.register("mlp")

View File

@ -44,7 +44,7 @@ class Pipeline:
Usage::
config = PipelineConfig.from_json("sft_pipeline.json")
config = PipelineConfig.from_file("sft_pipeline.json")
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
"""

View File

@ -31,7 +31,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""Factory class for creating learning rate schedulers.
Supports decorator-based registration for extensible scheduler types.
Also supports creation from ScheduleConfig objects.
Example usage:
@SchedulerFactory.register("custom")
@ -41,27 +40,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
"""
@classmethod
def create(
cls, optimizer, schedule_type: str = "none", **kwargs
) -> "BaseScheduler":
"""Create a scheduler instance by type name.
Args:
optimizer: PyTorch optimizer
schedule_type: Type of scheduler ("cosine", "sgdr")
**kwargs: Arguments passed to the scheduler constructor
Returns:
Scheduler instance
"""
return super().create(schedule_type, optimizer, **kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of registered scheduler type names."""
return cls.list_registered()
# ----------- Scheduler implementations -----------

View File

@ -127,26 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
strategy = StrategyFactory.create("custom", model, device)
"""
@classmethod
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
"""Create a strategy instance based on training type.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
model: Model instance for the strategy
device: Device to run the strategy on
**kwargs: Additional arguments passed to strategy constructor
Returns:
Strategy instance
"""
return super().create(train_type, model, device, **kwargs)
@classmethod
def available_strategies(cls) -> list:
"""Return list of registered strategy names."""
return cls.list_registered()
# ============== Strategy Classes ==============
# All strategies are registered at class definition time using the decorator

View File

@ -172,8 +172,8 @@ class TrainContextBuilder:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create(
cfg.strategy,
model=context.model,
train_type=cfg.strategy,
device=device,
executor=executor,
model_fn=cfg.model_fn,

View File

@ -24,7 +24,7 @@ def main():
)
args = parser.parse_args()
config = PipelineConfig.from_json(args.config)
config = PipelineConfig.from_file(args.config)
Pipeline(
config=config,

View File

@ -231,7 +231,8 @@ def create_optimizer(model, **kwargs) -> optim.Optimizer:
def create_scheduler(
optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.create(optimizer, **kwargs)
schedule_type = kwargs.pop("schedule_type")
return SchedulerFactory.create(schedule_type, optimizer, **kwargs)
def compute_total_steps(

View File

@ -53,15 +53,15 @@ def test_to_dict_roundtrip():
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(temp_dir):
def test_to_file_from_file(temp_dir):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
config.to_file(path)
loaded = PipelineConfig.from_file(path)
assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.mask == {"text": "train"}
@ -69,8 +69,8 @@ def test_to_json_from_json(temp_dir):
def test_dpo_config_roundtrip(temp_dir):
config = make_dpo_chat_config()
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
config.to_file(path)
loaded = PipelineConfig.from_file(path)
assert loaded.input.sources is not None
assert "chosen" in loaded.input.sources
assert "rejected" in loaded.input.sources

View File

@ -65,7 +65,7 @@ def create_train_config(
def scheduler_fn(optim):
return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
return TrainConfig(

View File

@ -102,7 +102,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
def scheduler_fn(optim):
return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig(
@ -136,7 +136,7 @@ def test_callback_integration(base_test_env, random_dataset):
def scheduler_fn(optim):
return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig(

View File

@ -16,7 +16,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
def scheduler_fn(optim):
return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig(

View File

@ -36,8 +36,8 @@ def test_schedule_factory_random_configs():
min_rate = params["min_rate"]
lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create(
optimizer,
schedule_type,
optimizer,
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
@ -52,8 +52,8 @@ def test_schedule_factory_random_configs():
t_mult = params["t_mult"]
min_rate = params["min_rate"]
scheduler = SchedulerFactory.create(
optimizer,
schedule_type,
optimizer,
warmup_steps=warmup_steps,
cycle_length=cycle_length,
t_mult=t_mult,
@ -103,8 +103,8 @@ def test_schedule_factory_edge_cases():
min_rate = params["min_rate"]
lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create(
optimizer,
"cosine",
optimizer,
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
@ -129,8 +129,8 @@ def test_schedule_factory_state_persistence():
min_rate = 0.1
lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create(
optimizer,
"cosine",
optimizer,
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
@ -146,8 +146,8 @@ def test_schedule_factory_state_persistence():
# Create new scheduler with same parameters
new_scheduler = SchedulerFactory.create(
optimizer,
"cosine",
optimizer,
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,