refactor : 清理工厂和配置系统中的死代码与冗余抽象
- 删除 Registry 中未使用的 category/priority 字段,_entries 简化为直接存储类引用 - 修正 __init_subclass__ 避免叶子类(AutoRegressiveLM 等)创建空注册表 - 删除 5 个工厂的薄 create() 覆写,统一使用 BaseFactory.create(name, *args, **kwargs) - 删除 3 处零调用的 available_types/available_strategies 别名死代码 - 删除零调用的 BaseModelConfig.to_file 死代码 - 将 BaseConfig.from_json/to_json 重命名为 from_file/to_file,消除与子类重复 - 移除两个 inference builder 中总是被覆写的 prompt_tokens=0
This commit is contained in:
parent
e7b18b7c03
commit
6ae1828449
|
|
@ -89,10 +89,10 @@ class BaseConfig:
|
|||
raise TypeError
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: Union[str, Path]) -> Self:
|
||||
def from_file(cls, path: Union[str, Path]) -> Self:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return cls.from_dict(json.load(f))
|
||||
|
||||
def to_json(self, path: Union[str, Path]):
|
||||
def to_file(self, path: Union[str, Path]):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Self
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
from astrai.factory import BaseFactory
|
||||
|
|
@ -22,18 +21,6 @@ class BaseModelConfig(BaseConfig):
|
|||
|
||||
model_type: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str) -> Self:
|
||||
with open(config_path, "r") as f:
|
||||
raw: Dict[str, Any] = json.load(f)
|
||||
return cls.from_dict(raw)
|
||||
|
||||
def to_file(self, config_path: str):
|
||||
d = self.to_dict()
|
||||
config_dict = {k: v for k, v in d.items() if v is not None}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ConfigFactory.register("autoregressive_lm")
|
||||
|
|
|
|||
|
|
@ -136,20 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
||||
"""Create a dataset instance.
|
||||
|
||||
Args:
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples
|
||||
|
||||
Returns:
|
||||
Dataset instance
|
||||
"""
|
||||
return super().create(train_type, window_size, stride)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
|
@ -179,11 +165,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
|
||||
return dataset
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered dataset type names."""
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
@DatasetFactory.register("seq")
|
||||
class SEQDataset(BaseDataset):
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ from typing import (
|
|||
Dict,
|
||||
ForwardRef,
|
||||
Generic,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
|
@ -56,21 +54,19 @@ class BaseFactory(ABC, Generic[T]):
|
|||
unrelated parameters.
|
||||
"""
|
||||
|
||||
_entries: Dict[str, Tuple[Type, Optional[str], int]]
|
||||
_entries: Dict[str, Type[T]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls._entries = {}
|
||||
for orig_base in getattr(cls, "__orig_bases__", ()):
|
||||
if _get_origin(orig_base) is BaseFactory:
|
||||
(arg,) = _get_args(orig_base)
|
||||
cls._entries = {}
|
||||
cls._component_base = _resolve_type(arg, cls)
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, name: str, category: Optional[str] = None, priority: int = 0
|
||||
) -> Callable[[Type[T]], Type[T]]:
|
||||
def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Decorator to register a component class.
|
||||
|
||||
Validates that the decorated class inherits from the generic
|
||||
|
|
@ -81,7 +77,7 @@ class BaseFactory(ABC, Generic[T]):
|
|||
cls._validate_component(component_cls)
|
||||
if name in cls._entries:
|
||||
raise ValueError(f"Component '{name}' is already registered")
|
||||
cls._entries[name] = (component_cls, category, priority)
|
||||
cls._entries[name] = component_cls
|
||||
return component_cls
|
||||
|
||||
return decorator
|
||||
|
|
@ -96,7 +92,7 @@ class BaseFactory(ABC, Generic[T]):
|
|||
raise ValueError(
|
||||
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
||||
)
|
||||
component_cls = entry[0]
|
||||
component_cls = entry
|
||||
sig = inspect.signature(component_cls.__init__)
|
||||
has_var_kwargs = any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||
|
|
@ -130,7 +126,7 @@ class BaseFactory(ABC, Generic[T]):
|
|||
raise ValueError(
|
||||
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
||||
)
|
||||
return entry[0]
|
||||
return entry
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
|||
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
prompt_tokens=0,
|
||||
)
|
||||
stop_sequences = getattr(request, "stop_sequences", None) or []
|
||||
return prompt, ctx, stop_sequences
|
||||
|
|
|
|||
|
|
@ -86,7 +86,6 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
|||
resp_id=self._resp_id,
|
||||
created=int(time.time()),
|
||||
model=self._model,
|
||||
prompt_tokens=0,
|
||||
)
|
||||
stop = request.stop
|
||||
stop_sequences = (
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class GenContext:
|
|||
resp_id: str
|
||||
created: int
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,9 +24,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
|||
|
||||
|
||||
class AttnFactory(BaseFactory[nn.Module]):
|
||||
@classmethod
|
||||
def create(cls, attn_type: str, **kwargs) -> nn.Module:
|
||||
return super().create(attn_type, **kwargs)
|
||||
pass
|
||||
|
||||
|
||||
@AttnFactory.register("gqa")
|
||||
|
|
|
|||
|
|
@ -8,9 +8,7 @@ from astrai.model.components.linear import Linear
|
|||
|
||||
|
||||
class FFNFactory(BaseFactory[nn.Module]):
|
||||
@classmethod
|
||||
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
|
||||
return super().create(ffn_type, dim, dim_ffn, **kwargs)
|
||||
pass
|
||||
|
||||
|
||||
@FFNFactory.register("mlp")
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class Pipeline:
|
|||
|
||||
Usage::
|
||||
|
||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
||||
config = PipelineConfig.from_file("sft_pipeline.json")
|
||||
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
|||
"""Factory class for creating learning rate schedulers.
|
||||
|
||||
Supports decorator-based registration for extensible scheduler types.
|
||||
Also supports creation from ScheduleConfig objects.
|
||||
|
||||
Example usage:
|
||||
@SchedulerFactory.register("custom")
|
||||
|
|
@ -41,27 +40,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
|||
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, optimizer, schedule_type: str = "none", **kwargs
|
||||
) -> "BaseScheduler":
|
||||
"""Create a scheduler instance by type name.
|
||||
|
||||
Args:
|
||||
optimizer: PyTorch optimizer
|
||||
schedule_type: Type of scheduler ("cosine", "sgdr")
|
||||
**kwargs: Arguments passed to the scheduler constructor
|
||||
|
||||
Returns:
|
||||
Scheduler instance
|
||||
"""
|
||||
return super().create(schedule_type, optimizer, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered scheduler type names."""
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
# ----------- Scheduler implementations -----------
|
||||
|
||||
|
|
|
|||
|
|
@ -127,26 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
|||
strategy = StrategyFactory.create("custom", model, device)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
||||
"""Create a strategy instance based on training type.
|
||||
|
||||
Args:
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
model: Model instance for the strategy
|
||||
device: Device to run the strategy on
|
||||
**kwargs: Additional arguments passed to strategy constructor
|
||||
|
||||
Returns:
|
||||
Strategy instance
|
||||
"""
|
||||
return super().create(train_type, model, device, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_strategies(cls) -> list:
|
||||
"""Return list of registered strategy names."""
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
# ============== Strategy Classes ==============
|
||||
# All strategies are registered at class definition time using the decorator
|
||||
|
|
|
|||
|
|
@ -172,8 +172,8 @@ class TrainContextBuilder:
|
|||
obj.load_state_dict(extra[name])
|
||||
|
||||
context.strategy = StrategyFactory.create(
|
||||
cfg.strategy,
|
||||
model=context.model,
|
||||
train_type=cfg.strategy,
|
||||
device=device,
|
||||
executor=executor,
|
||||
model_fn=cfg.model_fn,
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def main():
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = PipelineConfig.from_json(args.config)
|
||||
config = PipelineConfig.from_file(args.config)
|
||||
|
||||
Pipeline(
|
||||
config=config,
|
||||
|
|
|
|||
|
|
@ -231,7 +231,8 @@ def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
|||
def create_scheduler(
|
||||
optimizer: optim.Optimizer, **kwargs
|
||||
) -> optim.lr_scheduler.LRScheduler:
|
||||
return SchedulerFactory.create(optimizer, **kwargs)
|
||||
schedule_type = kwargs.pop("schedule_type")
|
||||
return SchedulerFactory.create(schedule_type, optimizer, **kwargs)
|
||||
|
||||
|
||||
def compute_total_steps(
|
||||
|
|
|
|||
|
|
@ -53,15 +53,15 @@ def test_to_dict_roundtrip():
|
|||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
||||
|
||||
|
||||
def test_to_json_from_json(temp_dir):
|
||||
def test_to_file_from_file(temp_dir):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
mask={"text": "train"},
|
||||
mask_default="mask",
|
||||
)
|
||||
path = os.path.join(temp_dir, "config.json")
|
||||
config.to_json(path)
|
||||
loaded = PipelineConfig.from_json(path)
|
||||
config.to_file(path)
|
||||
loaded = PipelineConfig.from_file(path)
|
||||
assert loaded.input.sections == _TEXT_SECTIONS
|
||||
assert loaded.mask == {"text": "train"}
|
||||
|
||||
|
|
@ -69,8 +69,8 @@ def test_to_json_from_json(temp_dir):
|
|||
def test_dpo_config_roundtrip(temp_dir):
|
||||
config = make_dpo_chat_config()
|
||||
path = os.path.join(temp_dir, "config.json")
|
||||
config.to_json(path)
|
||||
loaded = PipelineConfig.from_json(path)
|
||||
config.to_file(path)
|
||||
loaded = PipelineConfig.from_file(path)
|
||||
assert loaded.input.sources is not None
|
||||
assert "chosen" in loaded.input.sources
|
||||
assert "rejected" in loaded.input.sources
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ def create_train_config(
|
|||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
return TrainConfig(
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
|
|||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
|
|
@ -136,7 +136,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ def test_schedule_factory_random_configs():
|
|||
min_rate = params["min_rate"]
|
||||
lr_decay_steps = total_steps - warmup_steps
|
||||
scheduler = SchedulerFactory.create(
|
||||
optimizer,
|
||||
schedule_type,
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
lr_decay_steps=lr_decay_steps,
|
||||
min_rate=min_rate,
|
||||
|
|
@ -52,8 +52,8 @@ def test_schedule_factory_random_configs():
|
|||
t_mult = params["t_mult"]
|
||||
min_rate = params["min_rate"]
|
||||
scheduler = SchedulerFactory.create(
|
||||
optimizer,
|
||||
schedule_type,
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
cycle_length=cycle_length,
|
||||
t_mult=t_mult,
|
||||
|
|
@ -103,8 +103,8 @@ def test_schedule_factory_edge_cases():
|
|||
min_rate = params["min_rate"]
|
||||
lr_decay_steps = total_steps - warmup_steps
|
||||
scheduler = SchedulerFactory.create(
|
||||
optimizer,
|
||||
"cosine",
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
lr_decay_steps=lr_decay_steps,
|
||||
min_rate=min_rate,
|
||||
|
|
@ -129,8 +129,8 @@ def test_schedule_factory_state_persistence():
|
|||
min_rate = 0.1
|
||||
lr_decay_steps = total_steps - warmup_steps
|
||||
scheduler = SchedulerFactory.create(
|
||||
optimizer,
|
||||
"cosine",
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
lr_decay_steps=lr_decay_steps,
|
||||
min_rate=min_rate,
|
||||
|
|
@ -146,8 +146,8 @@ def test_schedule_factory_state_persistence():
|
|||
|
||||
# Create new scheduler with same parameters
|
||||
new_scheduler = SchedulerFactory.create(
|
||||
optimizer,
|
||||
"cosine",
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
lr_decay_steps=lr_decay_steps,
|
||||
min_rate=min_rate,
|
||||
|
|
|
|||
Loading…
Reference in New Issue