refactor: 清理工厂和配置系统中的死代码与冗余抽象
- 删除 Registry 中未使用的 category/priority 字段,_entries 简化为直接存储类引用
- 修正 __init_subclass__ 避免叶子类(AutoRegressiveLM 等)创建空注册表
- 删除 5 个工厂的薄 create() 覆写,统一使用 BaseFactory.create(name, *args, **kwargs)
- 删除 3 处零调用的 available_types/available_strategies 别名死代码
- 删除零调用的 BaseModelConfig.to_file 死代码
- 将 BaseConfig.from_json/to_json 重命名为 from_file/to_file,消除与子类重复
- 移除两个 inference builder 中总是被覆写的 prompt_tokens=0
This commit is contained in:
parent
e7b18b7c03
commit
6ae1828449
|
|
@ -89,10 +89,10 @@ class BaseConfig:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, path: Union[str, Path]) -> Self:
|
def from_file(cls, path: Union[str, Path]) -> Self:
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
return cls.from_dict(json.load(f))
|
return cls.from_dict(json.load(f))
|
||||||
|
|
||||||
def to_json(self, path: Union[str, Path]):
|
def to_file(self, path: Union[str, Path]):
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Self
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
from astrai.config.base import BaseConfig
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
@ -22,18 +21,6 @@ class BaseModelConfig(BaseConfig):
|
||||||
|
|
||||||
model_type: Optional[str] = None
|
model_type: Optional[str] = None
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_file(cls, config_path: str) -> Self:
|
|
||||||
with open(config_path, "r") as f:
|
|
||||||
raw: Dict[str, Any] = json.load(f)
|
|
||||||
return cls.from_dict(raw)
|
|
||||||
|
|
||||||
def to_file(self, config_path: str):
|
|
||||||
d = self.to_dict()
|
|
||||||
config_dict = {k: v for k, v in d.items() if v is not None}
|
|
||||||
with open(config_path, "w") as f:
|
|
||||||
json.dump(config_dict, f, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ConfigFactory.register("autoregressive_lm")
|
@ConfigFactory.register("autoregressive_lm")
|
||||||
|
|
|
||||||
|
|
@ -136,20 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
dataset = DatasetFactory.create("custom", window_size, stride)
|
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
|
||||||
"""Create a dataset instance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
|
||||||
window_size: Window size for data sampling
|
|
||||||
stride: Stride between consecutive samples
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dataset instance
|
|
||||||
"""
|
|
||||||
return super().create(train_type, window_size, stride)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
cls,
|
cls,
|
||||||
|
|
@ -179,11 +165,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def available_types(cls) -> list:
|
|
||||||
"""Return list of registered dataset type names."""
|
|
||||||
return cls.list_registered()
|
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("seq")
|
@DatasetFactory.register("seq")
|
||||||
class SEQDataset(BaseDataset):
|
class SEQDataset(BaseDataset):
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,6 @@ from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
ForwardRef,
|
ForwardRef,
|
||||||
Generic,
|
Generic,
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
|
@ -56,21 +54,19 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
unrelated parameters.
|
unrelated parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_entries: Dict[str, Tuple[Type, Optional[str], int]]
|
_entries: Dict[str, Type[T]]
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
cls._entries = {}
|
|
||||||
for orig_base in getattr(cls, "__orig_bases__", ()):
|
for orig_base in getattr(cls, "__orig_bases__", ()):
|
||||||
if _get_origin(orig_base) is BaseFactory:
|
if _get_origin(orig_base) is BaseFactory:
|
||||||
(arg,) = _get_args(orig_base)
|
(arg,) = _get_args(orig_base)
|
||||||
|
cls._entries = {}
|
||||||
cls._component_base = _resolve_type(arg, cls)
|
cls._component_base = _resolve_type(arg, cls)
|
||||||
return
|
return
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(
|
def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
|
||||||
cls, name: str, category: Optional[str] = None, priority: int = 0
|
|
||||||
) -> Callable[[Type[T]], Type[T]]:
|
|
||||||
"""Decorator to register a component class.
|
"""Decorator to register a component class.
|
||||||
|
|
||||||
Validates that the decorated class inherits from the generic
|
Validates that the decorated class inherits from the generic
|
||||||
|
|
@ -81,7 +77,7 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
cls._validate_component(component_cls)
|
cls._validate_component(component_cls)
|
||||||
if name in cls._entries:
|
if name in cls._entries:
|
||||||
raise ValueError(f"Component '{name}' is already registered")
|
raise ValueError(f"Component '{name}' is already registered")
|
||||||
cls._entries[name] = (component_cls, category, priority)
|
cls._entries[name] = component_cls
|
||||||
return component_cls
|
return component_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
@ -96,7 +92,7 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
||||||
)
|
)
|
||||||
component_cls = entry[0]
|
component_cls = entry
|
||||||
sig = inspect.signature(component_cls.__init__)
|
sig = inspect.signature(component_cls.__init__)
|
||||||
has_var_kwargs = any(
|
has_var_kwargs = any(
|
||||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||||
|
|
@ -130,7 +126,7 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
||||||
)
|
)
|
||||||
return entry[0]
|
return entry
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def list_registered(cls) -> list:
|
def list_registered(cls) -> list:
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,6 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
||||||
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
model=request.model,
|
model=request.model,
|
||||||
prompt_tokens=0,
|
|
||||||
)
|
)
|
||||||
stop_sequences = getattr(request, "stop_sequences", None) or []
|
stop_sequences = getattr(request, "stop_sequences", None) or []
|
||||||
return prompt, ctx, stop_sequences
|
return prompt, ctx, stop_sequences
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,6 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
resp_id=self._resp_id,
|
resp_id=self._resp_id,
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
model=self._model,
|
model=self._model,
|
||||||
prompt_tokens=0,
|
|
||||||
)
|
)
|
||||||
stop = request.stop
|
stop = request.stop
|
||||||
stop_sequences = (
|
stop_sequences = (
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class GenContext:
|
||||||
resp_id: str
|
resp_id: str
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int = 0
|
||||||
completion_tokens: int = 0
|
completion_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
|
|
||||||
|
|
||||||
class AttnFactory(BaseFactory[nn.Module]):
|
class AttnFactory(BaseFactory[nn.Module]):
|
||||||
@classmethod
|
pass
|
||||||
def create(cls, attn_type: str, **kwargs) -> nn.Module:
|
|
||||||
return super().create(attn_type, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@AttnFactory.register("gqa")
|
@AttnFactory.register("gqa")
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,7 @@ from astrai.model.components.linear import Linear
|
||||||
|
|
||||||
|
|
||||||
class FFNFactory(BaseFactory[nn.Module]):
|
class FFNFactory(BaseFactory[nn.Module]):
|
||||||
@classmethod
|
pass
|
||||||
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
|
|
||||||
return super().create(ffn_type, dim, dim_ffn, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@FFNFactory.register("mlp")
|
@FFNFactory.register("mlp")
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ class Pipeline:
|
||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
config = PipelineConfig.from_file("sft_pipeline.json")
|
||||||
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
"""Factory class for creating learning rate schedulers.
|
"""Factory class for creating learning rate schedulers.
|
||||||
|
|
||||||
Supports decorator-based registration for extensible scheduler types.
|
Supports decorator-based registration for extensible scheduler types.
|
||||||
Also supports creation from ScheduleConfig objects.
|
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
@SchedulerFactory.register("custom")
|
@SchedulerFactory.register("custom")
|
||||||
|
|
@ -41,27 +40,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls, optimizer, schedule_type: str = "none", **kwargs
|
|
||||||
) -> "BaseScheduler":
|
|
||||||
"""Create a scheduler instance by type name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
optimizer: PyTorch optimizer
|
|
||||||
schedule_type: Type of scheduler ("cosine", "sgdr")
|
|
||||||
**kwargs: Arguments passed to the scheduler constructor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Scheduler instance
|
|
||||||
"""
|
|
||||||
return super().create(schedule_type, optimizer, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def available_types(cls) -> list:
|
|
||||||
"""Return list of registered scheduler type names."""
|
|
||||||
return cls.list_registered()
|
|
||||||
|
|
||||||
|
|
||||||
# ----------- Scheduler implementations -----------
|
# ----------- Scheduler implementations -----------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -127,26 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||||
strategy = StrategyFactory.create("custom", model, device)
|
strategy = StrategyFactory.create("custom", model, device)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
|
||||||
"""Create a strategy instance based on training type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
|
||||||
model: Model instance for the strategy
|
|
||||||
device: Device to run the strategy on
|
|
||||||
**kwargs: Additional arguments passed to strategy constructor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Strategy instance
|
|
||||||
"""
|
|
||||||
return super().create(train_type, model, device, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def available_strategies(cls) -> list:
|
|
||||||
"""Return list of registered strategy names."""
|
|
||||||
return cls.list_registered()
|
|
||||||
|
|
||||||
|
|
||||||
# ============== Strategy Classes ==============
|
# ============== Strategy Classes ==============
|
||||||
# All strategies are registered at class definition time using the decorator
|
# All strategies are registered at class definition time using the decorator
|
||||||
|
|
|
||||||
|
|
@ -172,8 +172,8 @@ class TrainContextBuilder:
|
||||||
obj.load_state_dict(extra[name])
|
obj.load_state_dict(extra[name])
|
||||||
|
|
||||||
context.strategy = StrategyFactory.create(
|
context.strategy = StrategyFactory.create(
|
||||||
|
cfg.strategy,
|
||||||
model=context.model,
|
model=context.model,
|
||||||
train_type=cfg.strategy,
|
|
||||||
device=device,
|
device=device,
|
||||||
executor=executor,
|
executor=executor,
|
||||||
model_fn=cfg.model_fn,
|
model_fn=cfg.model_fn,
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ def main():
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config = PipelineConfig.from_json(args.config)
|
config = PipelineConfig.from_file(args.config)
|
||||||
|
|
||||||
Pipeline(
|
Pipeline(
|
||||||
config=config,
|
config=config,
|
||||||
|
|
|
||||||
|
|
@ -231,7 +231,8 @@ def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
optimizer: optim.Optimizer, **kwargs
|
optimizer: optim.Optimizer, **kwargs
|
||||||
) -> optim.lr_scheduler.LRScheduler:
|
) -> optim.lr_scheduler.LRScheduler:
|
||||||
return SchedulerFactory.create(optimizer, **kwargs)
|
schedule_type = kwargs.pop("schedule_type")
|
||||||
|
return SchedulerFactory.create(schedule_type, optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def compute_total_steps(
|
def compute_total_steps(
|
||||||
|
|
|
||||||
|
|
@ -53,15 +53,15 @@ def test_to_dict_roundtrip():
|
||||||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
assert config2.mask == {"prompt": "mask", "response": "train"}
|
||||||
|
|
||||||
|
|
||||||
def test_to_json_from_json(temp_dir):
|
def test_to_file_from_file(temp_dir):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
mask={"text": "train"},
|
mask={"text": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
)
|
)
|
||||||
path = os.path.join(temp_dir, "config.json")
|
path = os.path.join(temp_dir, "config.json")
|
||||||
config.to_json(path)
|
config.to_file(path)
|
||||||
loaded = PipelineConfig.from_json(path)
|
loaded = PipelineConfig.from_file(path)
|
||||||
assert loaded.input.sections == _TEXT_SECTIONS
|
assert loaded.input.sections == _TEXT_SECTIONS
|
||||||
assert loaded.mask == {"text": "train"}
|
assert loaded.mask == {"text": "train"}
|
||||||
|
|
||||||
|
|
@ -69,8 +69,8 @@ def test_to_json_from_json(temp_dir):
|
||||||
def test_dpo_config_roundtrip(temp_dir):
|
def test_dpo_config_roundtrip(temp_dir):
|
||||||
config = make_dpo_chat_config()
|
config = make_dpo_chat_config()
|
||||||
path = os.path.join(temp_dir, "config.json")
|
path = os.path.join(temp_dir, "config.json")
|
||||||
config.to_json(path)
|
config.to_file(path)
|
||||||
loaded = PipelineConfig.from_json(path)
|
loaded = PipelineConfig.from_file(path)
|
||||||
assert loaded.input.sources is not None
|
assert loaded.input.sources is not None
|
||||||
assert "chosen" in loaded.input.sources
|
assert "chosen" in loaded.input.sources
|
||||||
assert "rejected" in loaded.input.sources
|
assert "rejected" in loaded.input.sources
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ def create_train_config(
|
||||||
|
|
||||||
def scheduler_fn(optim):
|
def scheduler_fn(optim):
|
||||||
return SchedulerFactory.create(
|
return SchedulerFactory.create(
|
||||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
)
|
)
|
||||||
|
|
||||||
return TrainConfig(
|
return TrainConfig(
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
|
||||||
|
|
||||||
def scheduler_fn(optim):
|
def scheduler_fn(optim):
|
||||||
return SchedulerFactory.create(
|
return SchedulerFactory.create(
|
||||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
|
|
@ -136,7 +136,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
|
|
||||||
def scheduler_fn(optim):
|
def scheduler_fn(optim):
|
||||||
return SchedulerFactory.create(
|
return SchedulerFactory.create(
|
||||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
|
|
||||||
def scheduler_fn(optim):
|
def scheduler_fn(optim):
|
||||||
return SchedulerFactory.create(
|
return SchedulerFactory.create(
|
||||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
"cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,8 @@ def test_schedule_factory_random_configs():
|
||||||
min_rate = params["min_rate"]
|
min_rate = params["min_rate"]
|
||||||
lr_decay_steps = total_steps - warmup_steps
|
lr_decay_steps = total_steps - warmup_steps
|
||||||
scheduler = SchedulerFactory.create(
|
scheduler = SchedulerFactory.create(
|
||||||
optimizer,
|
|
||||||
schedule_type,
|
schedule_type,
|
||||||
|
optimizer,
|
||||||
warmup_steps=warmup_steps,
|
warmup_steps=warmup_steps,
|
||||||
lr_decay_steps=lr_decay_steps,
|
lr_decay_steps=lr_decay_steps,
|
||||||
min_rate=min_rate,
|
min_rate=min_rate,
|
||||||
|
|
@ -52,8 +52,8 @@ def test_schedule_factory_random_configs():
|
||||||
t_mult = params["t_mult"]
|
t_mult = params["t_mult"]
|
||||||
min_rate = params["min_rate"]
|
min_rate = params["min_rate"]
|
||||||
scheduler = SchedulerFactory.create(
|
scheduler = SchedulerFactory.create(
|
||||||
optimizer,
|
|
||||||
schedule_type,
|
schedule_type,
|
||||||
|
optimizer,
|
||||||
warmup_steps=warmup_steps,
|
warmup_steps=warmup_steps,
|
||||||
cycle_length=cycle_length,
|
cycle_length=cycle_length,
|
||||||
t_mult=t_mult,
|
t_mult=t_mult,
|
||||||
|
|
@ -103,8 +103,8 @@ def test_schedule_factory_edge_cases():
|
||||||
min_rate = params["min_rate"]
|
min_rate = params["min_rate"]
|
||||||
lr_decay_steps = total_steps - warmup_steps
|
lr_decay_steps = total_steps - warmup_steps
|
||||||
scheduler = SchedulerFactory.create(
|
scheduler = SchedulerFactory.create(
|
||||||
optimizer,
|
|
||||||
"cosine",
|
"cosine",
|
||||||
|
optimizer,
|
||||||
warmup_steps=warmup_steps,
|
warmup_steps=warmup_steps,
|
||||||
lr_decay_steps=lr_decay_steps,
|
lr_decay_steps=lr_decay_steps,
|
||||||
min_rate=min_rate,
|
min_rate=min_rate,
|
||||||
|
|
@ -129,8 +129,8 @@ def test_schedule_factory_state_persistence():
|
||||||
min_rate = 0.1
|
min_rate = 0.1
|
||||||
lr_decay_steps = total_steps - warmup_steps
|
lr_decay_steps = total_steps - warmup_steps
|
||||||
scheduler = SchedulerFactory.create(
|
scheduler = SchedulerFactory.create(
|
||||||
optimizer,
|
|
||||||
"cosine",
|
"cosine",
|
||||||
|
optimizer,
|
||||||
warmup_steps=warmup_steps,
|
warmup_steps=warmup_steps,
|
||||||
lr_decay_steps=lr_decay_steps,
|
lr_decay_steps=lr_decay_steps,
|
||||||
min_rate=min_rate,
|
min_rate=min_rate,
|
||||||
|
|
@ -146,8 +146,8 @@ def test_schedule_factory_state_persistence():
|
||||||
|
|
||||||
# Create new scheduler with same parameters
|
# Create new scheduler with same parameters
|
||||||
new_scheduler = SchedulerFactory.create(
|
new_scheduler = SchedulerFactory.create(
|
||||||
optimizer,
|
|
||||||
"cosine",
|
"cosine",
|
||||||
|
optimizer,
|
||||||
warmup_steps=warmup_steps,
|
warmup_steps=warmup_steps,
|
||||||
lr_decay_steps=lr_decay_steps,
|
lr_decay_steps=lr_decay_steps,
|
||||||
min_rate=min_rate,
|
min_rate=min_rate,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue