fix: 修复策略相关文件的类型注解与抽象方法体

- 修复 strategy.py 单元素 Union 与缺失的参数/返回类型注解
- 修复 train_context.py 8 个 default=None 字段缺 Optional 标记
- 修复 sample.py/packing.py/position_id.py 方法缺参数及返回类型注解
- 修复 factory.py _resolve_type/list_registered 缺类型注解
- 修复 train_config.py 裸 dict/list 缺泛型参数
- abstractmethod 方法体从 `...`(Ellipsis 占位)改为 `raise NotImplementedError`
- feat: checkpoint meta.json 保存 TrainConfig 超参供人工查阅
This commit is contained in:
ViperEkura 2026-06-14 16:20:10 +08:00
parent a2512f8a5a
commit fec376b0dd
8 changed files with 70 additions and 30 deletions

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Callable, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -40,7 +40,7 @@ class TrainConfig(BaseConfig):
max_grad_norm: float = field( max_grad_norm: float = field(
default=1.0, metadata={"help": "Maximum gradient norm."} default=1.0, metadata={"help": "Maximum gradient norm."}
) )
gradient_checkpointing_modules: list = field( gradient_checkpointing_modules: List[str] = field(
default_factory=list, default_factory=list,
metadata={"help": "Module types to enable activation checkpointing for."}, metadata={"help": "Module types to enable activation checkpointing for."},
) )
@ -133,11 +133,11 @@ class TrainConfig(BaseConfig):
metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."}, metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."},
) )
executor_kwargs: dict = field( executor_kwargs: Dict[str, Any] = field(
default_factory=dict, default_factory=dict,
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."}, metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
) )
extra_kwargs: dict = field( extra_kwargs: Dict[str, Any] = field(
default_factory=dict, metadata={"help": "Other arguments."} default_factory=dict, metadata={"help": "Other arguments."}
) )

View File

@ -4,12 +4,16 @@ import inspect
import sys import sys
from abc import ABC from abc import ABC
from typing import ( from typing import (
Any,
Callable, Callable,
Dict, Dict,
ForwardRef, ForwardRef,
Generic, Generic,
List,
Optional,
Type, Type,
TypeVar, TypeVar,
Union,
) )
from typing import get_args as _get_args from typing import get_args as _get_args
from typing import get_origin as _get_origin from typing import get_origin as _get_origin
@ -17,7 +21,9 @@ from typing import get_origin as _get_origin
T = TypeVar("T") T = TypeVar("T")
def _resolve_type(arg, factory_cls: type): def _resolve_type(
arg: Union[Type, str, ForwardRef], factory_cls: type
) -> Optional[Type]:
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class).""" """Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
if not isinstance(arg, (str, ForwardRef)): if not isinstance(arg, (str, ForwardRef)):
return arg return arg
@ -129,7 +135,7 @@ class BaseFactory(ABC, Generic[T]):
return entry return entry
@classmethod @classmethod
def list_registered(cls) -> list: def list_registered(cls) -> List[str]:
"""List all registered component names.""" """List all registered component names."""
return sorted(cls._entries) return sorted(cls._entries)

View File

@ -29,6 +29,7 @@ class BaseSamplingStrategy(ABC):
Returns: Returns:
Transformed logits tensor. Transformed logits tensor.
""" """
raise NotImplementedError
class TemperatureStrategy(BaseSamplingStrategy): class TemperatureStrategy(BaseSamplingStrategy):
@ -41,7 +42,7 @@ class TemperatureStrategy(BaseSamplingStrategy):
def __init__(self, temperature: Union[float, Tensor] = 1.0): def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
t = self.temperature t = self.temperature
if isinstance(t, Tensor): if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1) t = t.to(logits.device, non_blocking=True).view(-1, 1)
@ -63,7 +64,7 @@ class TopKStrategy(BaseSamplingStrategy):
def __init__(self, top_k: Union[int, Tensor] = 0): def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
tk = self.top_k tk = self.top_k
if isinstance(tk, Tensor): if isinstance(tk, Tensor):
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0) tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
@ -100,7 +101,9 @@ class TopPStrategy(BaseSamplingStrategy):
def __init__(self, top_p: Union[float, Tensor] = 1.0): def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p self.top_p = top_p
def _apply(self, logits, top_p, filter_value): def _apply(
self, logits: Tensor, top_p: Union[float, Tensor], filter_value: float
) -> Tensor:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p remove = cum_probs > top_p
@ -111,7 +114,7 @@ class TopPStrategy(BaseSamplingStrategy):
logits[mask] = filter_value logits[mask] = filter_value
return logits return logits
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
tp = self.top_p tp = self.top_p
if isinstance(tp, Tensor): if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True) tp = tp.to(logits.device, non_blocking=True)
@ -142,7 +145,7 @@ class SamplingPipeline(BaseSamplingStrategy):
def __init__(self, strategies: List[BaseSamplingStrategy]): def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
for strategy in self.strategies: for strategy in self.strategies:
logits = strategy.apply(logits, filter_value) logits = strategy.apply(logits, filter_value)
return logits return logits

View File

@ -12,7 +12,7 @@ from typing import Dict, List, Tuple
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
def _truncate(seq: list, max_len: int, mode: str) -> list: def _truncate(seq: List[int], max_len: int, mode: str) -> List[int]:
if len(seq) <= max_len: if len(seq) <= max_len:
return seq return seq
if mode == "keep_end": if mode == "keep_end":
@ -26,10 +26,11 @@ class PackingStrategy(ABC):
@abstractmethod @abstractmethod
def apply( def apply(
self, self,
keys: Dict[str, List[list]], keys: Dict[str, List[List[int]]],
max_packed_len: int, max_packed_len: int,
truncation_mode: str, truncation_mode: str,
) -> Dict[str, List[list]]: ... ) -> Dict[str, List[List[int]]]:
raise NotImplementedError
class PackingStrategyFactory(BaseFactory["PackingStrategy"]): class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
@ -38,7 +39,12 @@ class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
@PackingStrategyFactory.register("simple") @PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy): class SimplePacking(PackingStrategy):
def apply(self, keys, max_packed_len, truncation_mode): def apply(
self,
keys: Dict[str, List[List[int]]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[List[int]]]:
return { return {
k: [_truncate(v, max_packed_len, truncation_mode) for v in vals] k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
for k, vals in keys.items() for k, vals in keys.items()
@ -47,7 +53,12 @@ class SimplePacking(PackingStrategy):
@PackingStrategyFactory.register("bfd") @PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy): class BFDPacking(PackingStrategy):
def apply(self, keys, max_packed_len, truncation_mode): def apply(
self,
keys: Dict[str, List[List[int]]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[List[int]]]:
sequences = keys.get("sequence", []) sequences = keys.get("sequence", [])
if not sequences: if not sequences:
return keys return keys
@ -61,7 +72,7 @@ class BFDPacking(PackingStrategy):
return dict(reordered) return dict(reordered)
@staticmethod @staticmethod
def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]: def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]:
n = len(sequences) n = len(sequences)
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True) order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = [] bins: List[List[int]] = []

View File

@ -16,7 +16,8 @@ class PositionIdStrategy(ABC):
"""Generate ``position_ids`` for packed sequences.""" """Generate ``position_ids`` for packed sequences."""
@abstractmethod @abstractmethod
def generate(self, sequences: List[list]) -> List[int]: ... def generate(self, sequences: List[List[int]]) -> List[int]:
raise NotImplementedError
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]): class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
@ -25,13 +26,13 @@ class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
@PositionIdStrategyFactory.register("none") @PositionIdStrategyFactory.register("none")
class NoPositionId(PositionIdStrategy): class NoPositionId(PositionIdStrategy):
def generate(self, sequences): def generate(self, sequences: List[List[int]]) -> List[int]:
return [] return []
@PositionIdStrategyFactory.register("doc_reset") @PositionIdStrategyFactory.register("doc_reset")
class DocResetPositionId(PositionIdStrategy): class DocResetPositionId(PositionIdStrategy):
def generate(self, sequences): def generate(self, sequences: List[List[int]]) -> List[int]:
pos_ids = [] pos_ids = []
for seq in sequences: for seq in sequences:
pos_ids.extend(range(len(seq))) pos_ids.extend(range(len(seq)))
@ -40,6 +41,6 @@ class DocResetPositionId(PositionIdStrategy):
@PositionIdStrategyFactory.register("continuous") @PositionIdStrategyFactory.register("continuous")
class ContinuousPositionId(PositionIdStrategy): class ContinuousPositionId(PositionIdStrategy):
def generate(self, sequences): def generate(self, sequences: List[List[int]]) -> List[int]:
total = sum(len(seq) for seq in sequences) total = sum(len(seq) for seq in sequences)
return list(range(total)) return list(range(total))

View File

@ -1,7 +1,7 @@
"""Training strategy implementations with factory pattern.""" """Training strategy implementations with factory pattern."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union from typing import Callable, Dict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -11,7 +11,9 @@ from torch import Tensor
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
def create_ref_model(model_fn, state_dict: dict) -> nn.Module: def create_ref_model(
model_fn: Callable[[], nn.Module], state_dict: Dict[str, Tensor]
) -> nn.Module:
"""Create a frozen reference model from model_fn + full state dict.""" """Create a frozen reference model from model_fn + full state dict."""
ref_model = model_fn() ref_model = model_fn()
ref_model.load_state_dict(state_dict) ref_model.load_state_dict(state_dict)
@ -20,7 +22,7 @@ def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
return ref_model return ref_model
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any: def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
"""Move batch tensors to specified device with non-blocking transfer.""" """Move batch tensors to specified device with non-blocking transfer."""
return {key: value.to(device, non_blocking=True) for key, value in batch.items()} return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
@ -30,7 +32,7 @@ def get_logprobs(
input_ids: Tensor, input_ids: Tensor,
mask: Tensor, mask: Tensor,
reduction: str, reduction: str,
): ) -> Tensor:
"""Compute token-wise log probabilities from model outputs. """Compute token-wise log probabilities from model outputs.
Args: Args:
@ -88,7 +90,10 @@ class BaseStrategy(ABC):
"""Abstract base class for training strategies.""" """Abstract base class for training strategies."""
def __init__( def __init__(
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
**kwargs,
): ):
self.model = model self.model = model
self.device = device self.device = device
@ -139,7 +144,13 @@ class SEQStrategy(BaseStrategy):
Computes cross-entropy loss for next token prediction. Computes cross-entropy loss for next token prediction.
""" """
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs): def __init__(
self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
label_smoothing: float = 0.0,
**kwargs,
):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing self.label_smoothing = label_smoothing
@ -164,7 +175,13 @@ class SFTStrategy(BaseStrategy):
Applies cross-entropy loss only to tokens where loss_mask is True. Applies cross-entropy loss only to tokens where loss_mask is True.
""" """
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs): def __init__(
self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
label_smoothing: float = 0.0,
**kwargs,
):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing self.label_smoothing = label_smoothing

View File

@ -154,11 +154,13 @@ class CheckpointCallback(TrainCallback):
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}" self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
) )
extra = self.save_extra_fn(context) extra = self.save_extra_fn(context)
meta = context.config.to_dict()
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
state_dict=state_dict, state_dict=state_dict,
epoch=context.epoch, epoch=context.epoch,
iteration=context.iteration, iteration=context.iteration,
extra=extra, extra=extra,
meta=meta,
config=context.model_config, config=context.model_config,
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional, Self from typing import Any, Dict, Optional, Self
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -36,7 +36,7 @@ class TrainContext:
world_size: int = field(default=1) world_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)
kwargs: dict = field(default_factory=dict) kwargs: Dict[str, Any] = field(default_factory=dict)
class TrainContextBuilder: class TrainContextBuilder: