22 lines
598 B
Python
22 lines
598 B
Python
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
|
|
|
from typing import Any, Protocol, runtime_checkable
|
|
|
|
|
|
@runtime_checkable
class OptimizerProtocol(Protocol):
    """Structural interface for optimizer-like objects.

    Any object exposing ``step``, ``zero_grad``, ``param_groups``,
    ``state_dict`` and ``load_state_dict`` satisfies this protocol; no
    inheritance is required.

    Because the class is ``@runtime_checkable``, ``isinstance(obj,
    OptimizerProtocol)`` works, but it only verifies that the members exist.
    Signatures and return types are not checked at runtime.
    """

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional argument, ``None`` by default.
                NOTE(review): presumably a callable that re-evaluates the
                loss, as in common optimizer APIs -- confirm with callers.
        """
        ...

    def zero_grad(self):
        """Reset the gradients tracked by the optimizer."""
        ...

    @property
    def param_groups(self) -> Any:
        """Parameter groups managed by the optimizer (read-only here)."""
        ...

    def state_dict(self) -> dict:
        """Return the optimizer state as a ``dict``."""
        ...

    def load_state_dict(self, d: dict):
        """Restore optimizer state from ``d``.

        Args:
            d: A state mapping; presumably one produced by ``state_dict``
                -- confirm with callers.
        """
        ...
|
|
|
|
|
|
@runtime_checkable
class SchedulerProtocol(Protocol):
    """Structural interface for learning-rate-scheduler-like objects.

    Any object exposing ``step``, ``state_dict``, ``load_state_dict`` and
    ``get_last_lr`` satisfies this protocol; no inheritance is required.

    Because the class is ``@runtime_checkable``, ``isinstance(obj,
    SchedulerProtocol)`` works, but it only verifies that the members exist.
    Signatures and return types are not checked at runtime.
    """

    def step(self):
        """Advance the scheduler by one step."""
        ...

    def state_dict(self) -> dict:
        """Return the scheduler state as a ``dict``."""
        ...

    def load_state_dict(self, d: dict):
        """Restore scheduler state from ``d``.

        Args:
            d: A state mapping; presumably one produced by ``state_dict``
                -- confirm with callers.
        """
        ...

    def get_last_lr(self):
        """Return the most recent learning rate value(s).

        NOTE(review): the return shape is not specified here; presumably a
        list with one entry per parameter group -- confirm with implementers.
        """
        ...
|