refactor: 清理工厂和配置系统中的死代码与冗余抽象

- 删除 BaseFactory 注册表(Registry)中未使用的 category/priority 字段,_entries 简化为直接存储类引用
- 修正 __init_subclass__,避免叶子类(AutoRegressiveLM 等)创建空注册表
- 删除 5 个工厂的薄 create() 覆写,统一使用 BaseFactory.create(name, *args, **kwargs)
- 删除 3 处零调用的 available_types/available_strategies 别名死代码
- 删除零调用的 BaseModelConfig.to_file 死代码
- 将 BaseConfig.from_json/to_json 重命名为 from_file/to_file,消除与子类重复
- 移除两个 inference builder 中总是被覆写的 prompt_tokens=0
This commit is contained in:
ViperEkura 2026-06-07 11:39:50 +08:00
parent e7b18b7c03
commit 6ae1828449
20 changed files with 31 additions and 114 deletions

View File

@ -89,10 +89,10 @@ class BaseConfig:
raise TypeError raise TypeError
@classmethod @classmethod
def from_json(cls, path: Union[str, Path]) -> Self: def from_file(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f)) return cls.from_dict(json.load(f))
def to_json(self, path: Union[str, Path]): def to_file(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f: with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -1,6 +1,5 @@
import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Self from typing import Any, Dict, Optional
from astrai.config.base import BaseConfig from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -22,18 +21,6 @@ class BaseModelConfig(BaseConfig):
model_type: Optional[str] = None model_type: Optional[str] = None
@classmethod
def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f)
return cls.from_dict(raw)
def to_file(self, config_path: str):
d = self.to_dict()
config_dict = {k: v for k, v in d.items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@dataclass @dataclass
@ConfigFactory.register("autoregressive_lm") @ConfigFactory.register("autoregressive_lm")

View File

@ -136,20 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
dataset = DatasetFactory.create("custom", window_size, stride) dataset = DatasetFactory.create("custom", window_size, stride)
""" """
@classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
"""Create a dataset instance.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
window_size: Window size for data sampling
stride: Stride between consecutive samples
Returns:
Dataset instance
"""
return super().create(train_type, window_size, stride)
@classmethod @classmethod
def load( def load(
cls, cls,
@ -179,11 +165,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
return dataset return dataset
@classmethod
def available_types(cls) -> list:
"""Return list of registered dataset type names."""
return cls.list_registered()
@DatasetFactory.register("seq") @DatasetFactory.register("seq")
class SEQDataset(BaseDataset): class SEQDataset(BaseDataset):

View File

@ -8,8 +8,6 @@ from typing import (
Dict, Dict,
ForwardRef, ForwardRef,
Generic, Generic,
Optional,
Tuple,
Type, Type,
TypeVar, TypeVar,
) )
@ -56,21 +54,19 @@ class BaseFactory(ABC, Generic[T]):
unrelated parameters. unrelated parameters.
""" """
_entries: Dict[str, Tuple[Type, Optional[str], int]] _entries: Dict[str, Type[T]]
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
cls._entries = {}
for orig_base in getattr(cls, "__orig_bases__", ()): for orig_base in getattr(cls, "__orig_bases__", ()):
if _get_origin(orig_base) is BaseFactory: if _get_origin(orig_base) is BaseFactory:
(arg,) = _get_args(orig_base) (arg,) = _get_args(orig_base)
cls._entries = {}
cls._component_base = _resolve_type(arg, cls) cls._component_base = _resolve_type(arg, cls)
return return
@classmethod @classmethod
def register( def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
cls, name: str, category: Optional[str] = None, priority: int = 0
) -> Callable[[Type[T]], Type[T]]:
"""Decorator to register a component class. """Decorator to register a component class.
Validates that the decorated class inherits from the generic Validates that the decorated class inherits from the generic
@ -81,7 +77,7 @@ class BaseFactory(ABC, Generic[T]):
cls._validate_component(component_cls) cls._validate_component(component_cls)
if name in cls._entries: if name in cls._entries:
raise ValueError(f"Component '{name}' is already registered") raise ValueError(f"Component '{name}' is already registered")
cls._entries[name] = (component_cls, category, priority) cls._entries[name] = component_cls
return component_cls return component_cls
return decorator return decorator
@ -96,7 +92,7 @@ class BaseFactory(ABC, Generic[T]):
raise ValueError( raise ValueError(
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}" f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
) )
component_cls = entry[0] component_cls = entry
sig = inspect.signature(component_cls.__init__) sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any( has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
@ -130,7 +126,7 @@ class BaseFactory(ABC, Generic[T]):
raise ValueError( raise ValueError(
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}" f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
) )
return entry[0] return entry
@classmethod @classmethod
def list_registered(cls) -> list: def list_registered(cls) -> list:

View File

@ -42,7 +42,6 @@ class AnthropicResponseBuilder(ResponseBuilder):
resp_id=f"msg_{uuid.uuid4().hex[:24]}", resp_id=f"msg_{uuid.uuid4().hex[:24]}",
created=int(time.time()), created=int(time.time()),
model=request.model, model=request.model,
prompt_tokens=0,
) )
stop_sequences = getattr(request, "stop_sequences", None) or [] stop_sequences = getattr(request, "stop_sequences", None) or []
return prompt, ctx, stop_sequences return prompt, ctx, stop_sequences

View File

@ -86,7 +86,6 @@ class OpenAIResponseBuilder(ResponseBuilder):
resp_id=self._resp_id, resp_id=self._resp_id,
created=int(time.time()), created=int(time.time()),
model=self._model, model=self._model,
prompt_tokens=0,
) )
stop = request.stop stop = request.stop
stop_sequences = ( stop_sequences = (

View File

@ -35,7 +35,7 @@ class GenContext:
resp_id: str resp_id: str
created: int created: int
model: str model: str
prompt_tokens: int prompt_tokens: int = 0
completion_tokens: int = 0 completion_tokens: int = 0

View File

@ -24,9 +24,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
class AttnFactory(BaseFactory[nn.Module]): class AttnFactory(BaseFactory[nn.Module]):
@classmethod pass
def create(cls, attn_type: str, **kwargs) -> nn.Module:
return super().create(attn_type, **kwargs)
@AttnFactory.register("gqa") @AttnFactory.register("gqa")

View File

@ -8,9 +8,7 @@ from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]): class FFNFactory(BaseFactory[nn.Module]):
@classmethod pass
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
return super().create(ffn_type, dim, dim_ffn, **kwargs)
@FFNFactory.register("mlp") @FFNFactory.register("mlp")

View File

@ -44,7 +44,7 @@ class Pipeline:
Usage:: Usage::
config = PipelineConfig.from_json("sft_pipeline.json") config = PipelineConfig.from_file("sft_pipeline.json")
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run() Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
""" """

View File

@ -31,7 +31,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""Factory class for creating learning rate schedulers. """Factory class for creating learning rate schedulers.
Supports decorator-based registration for extensible scheduler types. Supports decorator-based registration for extensible scheduler types.
Also supports creation from ScheduleConfig objects.
Example usage: Example usage:
@SchedulerFactory.register("custom") @SchedulerFactory.register("custom")
@ -41,27 +40,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs) scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
""" """
@classmethod
def create(
cls, optimizer, schedule_type: str = "none", **kwargs
) -> "BaseScheduler":
"""Create a scheduler instance by type name.
Args:
optimizer: PyTorch optimizer
schedule_type: Type of scheduler ("cosine", "sgdr")
**kwargs: Arguments passed to the scheduler constructor
Returns:
Scheduler instance
"""
return super().create(schedule_type, optimizer, **kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of registered scheduler type names."""
return cls.list_registered()
# ----------- Scheduler implementations ----------- # ----------- Scheduler implementations -----------

View File

@ -127,26 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
strategy = StrategyFactory.create("custom", model, device) strategy = StrategyFactory.create("custom", model, device)
""" """
@classmethod
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
"""Create a strategy instance based on training type.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
model: Model instance for the strategy
device: Device to run the strategy on
**kwargs: Additional arguments passed to strategy constructor
Returns:
Strategy instance
"""
return super().create(train_type, model, device, **kwargs)
@classmethod
def available_strategies(cls) -> list:
"""Return list of registered strategy names."""
return cls.list_registered()
# ============== Strategy Classes ============== # ============== Strategy Classes ==============
# All strategies are registered at class definition time using the decorator # All strategies are registered at class definition time using the decorator

View File

@ -172,8 +172,8 @@ class TrainContextBuilder:
obj.load_state_dict(extra[name]) obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create( context.strategy = StrategyFactory.create(
cfg.strategy,
model=context.model, model=context.model,
train_type=cfg.strategy,
device=device, device=device,
executor=executor, executor=executor,
model_fn=cfg.model_fn, model_fn=cfg.model_fn,

View File

@ -24,7 +24,7 @@ def main():
) )
args = parser.parse_args() args = parser.parse_args()
config = PipelineConfig.from_json(args.config) config = PipelineConfig.from_file(args.config)
Pipeline( Pipeline(
config=config, config=config,

View File

@ -231,7 +231,8 @@ def create_optimizer(model, **kwargs) -> optim.Optimizer:
def create_scheduler( def create_scheduler(
optimizer: optim.Optimizer, **kwargs optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler: ) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.create(optimizer, **kwargs) schedule_type = kwargs.pop("schedule_type")
return SchedulerFactory.create(schedule_type, optimizer, **kwargs)
def compute_total_steps( def compute_total_steps(

View File

@ -53,15 +53,15 @@ def test_to_dict_roundtrip():
assert config2.mask == {"prompt": "mask", "response": "train"} assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(temp_dir): def test_to_file_from_file(temp_dir):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS), input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"}, mask={"text": "train"},
mask_default="mask", mask_default="mask",
) )
path = os.path.join(temp_dir, "config.json") path = os.path.join(temp_dir, "config.json")
config.to_json(path) config.to_file(path)
loaded = PipelineConfig.from_json(path) loaded = PipelineConfig.from_file(path)
assert loaded.input.sections == _TEXT_SECTIONS assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.mask == {"text": "train"} assert loaded.mask == {"text": "train"}
@ -69,8 +69,8 @@ def test_to_json_from_json(temp_dir):
def test_dpo_config_roundtrip(temp_dir): def test_dpo_config_roundtrip(temp_dir):
config = make_dpo_chat_config() config = make_dpo_chat_config()
path = os.path.join(temp_dir, "config.json") path = os.path.join(temp_dir, "config.json")
config.to_json(path) config.to_file(path)
loaded = PipelineConfig.from_json(path) loaded = PipelineConfig.from_file(path)
assert loaded.input.sources is not None assert loaded.input.sources is not None
assert "chosen" in loaded.input.sources assert "chosen" in loaded.input.sources
assert "rejected" in loaded.input.sources assert "rejected" in loaded.input.sources

View File

@ -65,7 +65,7 @@ def create_train_config(
def scheduler_fn(optim): def scheduler_fn(optim):
return SchedulerFactory.create( return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 "cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
) )
return TrainConfig( return TrainConfig(

View File

@ -102,7 +102,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
def scheduler_fn(optim): def scheduler_fn(optim):
return SchedulerFactory.create( return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 "cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
) )
train_config = TrainConfig( train_config = TrainConfig(
@ -136,7 +136,7 @@ def test_callback_integration(base_test_env, random_dataset):
def scheduler_fn(optim): def scheduler_fn(optim):
return SchedulerFactory.create( return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 "cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
) )
train_config = TrainConfig( train_config = TrainConfig(

View File

@ -16,7 +16,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
def scheduler_fn(optim): def scheduler_fn(optim):
return SchedulerFactory.create( return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 "cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
) )
train_config = TrainConfig( train_config = TrainConfig(

View File

@ -36,8 +36,8 @@ def test_schedule_factory_random_configs():
min_rate = params["min_rate"] min_rate = params["min_rate"]
lr_decay_steps = total_steps - warmup_steps lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create( scheduler = SchedulerFactory.create(
optimizer,
schedule_type, schedule_type,
optimizer,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps, lr_decay_steps=lr_decay_steps,
min_rate=min_rate, min_rate=min_rate,
@ -52,8 +52,8 @@ def test_schedule_factory_random_configs():
t_mult = params["t_mult"] t_mult = params["t_mult"]
min_rate = params["min_rate"] min_rate = params["min_rate"]
scheduler = SchedulerFactory.create( scheduler = SchedulerFactory.create(
optimizer,
schedule_type, schedule_type,
optimizer,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
cycle_length=cycle_length, cycle_length=cycle_length,
t_mult=t_mult, t_mult=t_mult,
@ -103,8 +103,8 @@ def test_schedule_factory_edge_cases():
min_rate = params["min_rate"] min_rate = params["min_rate"]
lr_decay_steps = total_steps - warmup_steps lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create( scheduler = SchedulerFactory.create(
optimizer,
"cosine", "cosine",
optimizer,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps, lr_decay_steps=lr_decay_steps,
min_rate=min_rate, min_rate=min_rate,
@ -129,8 +129,8 @@ def test_schedule_factory_state_persistence():
min_rate = 0.1 min_rate = 0.1
lr_decay_steps = total_steps - warmup_steps lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create( scheduler = SchedulerFactory.create(
optimizer,
"cosine", "cosine",
optimizer,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps, lr_decay_steps=lr_decay_steps,
min_rate=min_rate, min_rate=min_rate,
@ -146,8 +146,8 @@ def test_schedule_factory_state_persistence():
# Create new scheduler with same parameters # Create new scheduler with same parameters
new_scheduler = SchedulerFactory.create( new_scheduler = SchedulerFactory.create(
optimizer,
"cosine", "cosine",
optimizer,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps, lr_decay_steps=lr_decay_steps,
min_rate=min_rate, min_rate=min_rate,