fix : 修复策略相关文件的类型注解与抽象方法体

- 修复 strategy.py 单元素 Union 与缺失的参数/返回类型注解
- 修复 train_context.py 8 个 default=None 字段缺 Optional 标记
- 修复 sample.py/packing.py/position_id.py 方法缺参数及返回类型注解
- 修复 factory.py _resolve_type/list_registered 缺类型注解
- 修复 train_config.py 裸 dict/list 缺泛型参数
- abstractmethod body 从 ... 改为 raise NotImplementedError
- feat : checkpoint meta.json 保存 TrainConfig 超参供人工查阅
This commit is contained in:
ViperEkura 2026-06-14 16:20:10 +08:00
parent a2512f8a5a
commit fec376b0dd
8 changed files with 70 additions and 30 deletions

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field, fields
from typing import Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional
import torch.nn as nn
from torch.optim import Optimizer
@ -40,7 +40,7 @@ class TrainConfig(BaseConfig):
max_grad_norm: float = field(
default=1.0, metadata={"help": "Maximum gradient norm."}
)
gradient_checkpointing_modules: list = field(
gradient_checkpointing_modules: List[str] = field(
default_factory=list,
metadata={"help": "Module types to enable activation checkpointing for."},
)
@ -133,11 +133,11 @@ class TrainConfig(BaseConfig):
metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."},
)
executor_kwargs: dict = field(
executor_kwargs: Dict[str, Any] = field(
default_factory=dict,
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
)
extra_kwargs: dict = field(
extra_kwargs: Dict[str, Any] = field(
default_factory=dict, metadata={"help": "Other arguments."}
)

View File

@ -4,12 +4,16 @@ import inspect
import sys
from abc import ABC
from typing import (
Any,
Callable,
Dict,
ForwardRef,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
)
from typing import get_args as _get_args
from typing import get_origin as _get_origin
@ -17,7 +21,9 @@ from typing import get_origin as _get_origin
T = TypeVar("T")
def _resolve_type(arg, factory_cls: type):
def _resolve_type(
arg: Union[Type, str, ForwardRef], factory_cls: type
) -> Optional[Type]:
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
if not isinstance(arg, (str, ForwardRef)):
return arg
@ -129,7 +135,7 @@ class BaseFactory(ABC, Generic[T]):
return entry
@classmethod
def list_registered(cls) -> list:
def list_registered(cls) -> List[str]:
"""List all registered component names."""
return sorted(cls._entries)

View File

@ -29,6 +29,7 @@ class BaseSamplingStrategy(ABC):
Returns:
Transformed logits tensor.
"""
raise NotImplementedError
class TemperatureStrategy(BaseSamplingStrategy):
@ -41,7 +42,7 @@ class TemperatureStrategy(BaseSamplingStrategy):
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
t = self.temperature
if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1)
@ -63,7 +64,7 @@ class TopKStrategy(BaseSamplingStrategy):
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
tk = self.top_k
if isinstance(tk, Tensor):
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
@ -100,7 +101,9 @@ class TopPStrategy(BaseSamplingStrategy):
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
def _apply(
self, logits: Tensor, top_p: Union[float, Tensor], filter_value: float
) -> Tensor:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
@ -111,7 +114,7 @@ class TopPStrategy(BaseSamplingStrategy):
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")):
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
tp = self.top_p
if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True)
@ -142,7 +145,7 @@ class SamplingPipeline(BaseSamplingStrategy):
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits

View File

@ -12,7 +12,7 @@ from typing import Dict, List, Tuple
from astrai.factory import BaseFactory
def _truncate(seq: list, max_len: int, mode: str) -> list:
def _truncate(seq: List[int], max_len: int, mode: str) -> List[int]:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
@ -26,10 +26,11 @@ class PackingStrategy(ABC):
@abstractmethod
def apply(
self,
keys: Dict[str, List[list]],
keys: Dict[str, List[List[int]]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[list]]: ...
) -> Dict[str, List[List[int]]]:
raise NotImplementedError
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
@ -38,7 +39,12 @@ class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
@PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy):
def apply(self, keys, max_packed_len, truncation_mode):
def apply(
self,
keys: Dict[str, List[List[int]]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[List[int]]]:
return {
k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
for k, vals in keys.items()
@ -47,7 +53,12 @@ class SimplePacking(PackingStrategy):
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
def apply(self, keys, max_packed_len, truncation_mode):
def apply(
self,
keys: Dict[str, List[List[int]]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[List[int]]]:
sequences = keys.get("sequence", [])
if not sequences:
return keys
@ -61,7 +72,7 @@ class BFDPacking(PackingStrategy):
return dict(reordered)
@staticmethod
def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]:
def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]:
n = len(sequences)
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []

View File

@ -16,7 +16,8 @@ class PositionIdStrategy(ABC):
"""Generate ``position_ids`` for packed sequences."""
@abstractmethod
def generate(self, sequences: List[list]) -> List[int]: ...
def generate(self, sequences: List[List[int]]) -> List[int]:
raise NotImplementedError
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
@ -25,13 +26,13 @@ class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
@PositionIdStrategyFactory.register("none")
class NoPositionId(PositionIdStrategy):
def generate(self, sequences):
def generate(self, sequences: List[List[int]]) -> List[int]:
return []
@PositionIdStrategyFactory.register("doc_reset")
class DocResetPositionId(PositionIdStrategy):
def generate(self, sequences):
def generate(self, sequences: List[List[int]]) -> List[int]:
pos_ids = []
for seq in sequences:
pos_ids.extend(range(len(seq)))
@ -40,6 +41,6 @@ class DocResetPositionId(PositionIdStrategy):
@PositionIdStrategyFactory.register("continuous")
class ContinuousPositionId(PositionIdStrategy):
def generate(self, sequences):
def generate(self, sequences: List[List[int]]) -> List[int]:
total = sum(len(seq) for seq in sequences)
return list(range(total))

View File

@ -1,7 +1,7 @@
"""Training strategy implementations with factory pattern."""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
from typing import Callable, Dict, Union
import torch
import torch.nn as nn
@ -11,7 +11,9 @@ from torch import Tensor
from astrai.factory import BaseFactory
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
def create_ref_model(
model_fn: Callable[[], nn.Module], state_dict: Dict[str, Tensor]
) -> nn.Module:
"""Create a frozen reference model from model_fn + full state dict."""
ref_model = model_fn()
ref_model.load_state_dict(state_dict)
@ -20,7 +22,7 @@ def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
return ref_model
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
"""Move batch tensors to specified device with non-blocking transfer."""
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
@ -30,7 +32,7 @@ def get_logprobs(
input_ids: Tensor,
mask: Tensor,
reduction: str,
):
) -> Tensor:
"""Compute token-wise log probabilities from model outputs.
Args:
@ -88,7 +90,10 @@ class BaseStrategy(ABC):
"""Abstract base class for training strategies."""
def __init__(
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
**kwargs,
):
self.model = model
self.device = device
@ -139,7 +144,13 @@ class SEQStrategy(BaseStrategy):
Computes cross-entropy loss for next token prediction.
"""
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
def __init__(
self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
label_smoothing: float = 0.0,
**kwargs,
):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing
@ -164,7 +175,13 @@ class SFTStrategy(BaseStrategy):
Applies cross-entropy loss only to tokens where loss_mask is True.
"""
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
def __init__(
self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
label_smoothing: float = 0.0,
**kwargs,
):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing

View File

@ -154,11 +154,13 @@ class CheckpointCallback(TrainCallback):
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
extra = self.save_extra_fn(context)
meta = context.config.to_dict()
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
meta=meta,
config=context.model_config,
)
context.checkpoint.save(save_path)

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self
from typing import Any, Dict, Optional, Self
import torch
import torch.nn as nn
@ -36,7 +36,7 @@ class TrainContext:
world_size: int = field(default=1)
rank: int = field(default=0)
kwargs: dict = field(default_factory=dict)
kwargs: Dict[str, Any] = field(default_factory=dict)
class TrainContextBuilder: