fix : 修复策略相关文件的类型注解与抽象方法体
- 修复 strategy.py 单元素 Union 与缺失的参数/返回类型注解 - 修复 train_context.py 8 个 default=None 字段缺 Optional 标记 - 修复 sample.py/packing.py/position_id.py 方法缺参数及返回类型注解 - 修复 factory.py _resolve_type/list_registered 缺类型注解 - 修复 train_config.py 裸 dict/list 缺泛型参数 - abstractmethod body 从 ... 改为 raise NotImplementedError - feat : checkpoint meta.json 保存 TrainConfig 超参供人工查阅
This commit is contained in:
parent
a2512f8a5a
commit
fec376b0dd
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field, fields
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
|
@ -40,7 +40,7 @@ class TrainConfig(BaseConfig):
|
|||
max_grad_norm: float = field(
|
||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
gradient_checkpointing_modules: list = field(
|
||||
gradient_checkpointing_modules: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "Module types to enable activation checkpointing for."},
|
||||
)
|
||||
|
|
@ -133,11 +133,11 @@ class TrainConfig(BaseConfig):
|
|||
metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."},
|
||||
)
|
||||
|
||||
executor_kwargs: dict = field(
|
||||
executor_kwargs: Dict[str, Any] = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
||||
)
|
||||
extra_kwargs: dict = field(
|
||||
extra_kwargs: Dict[str, Any] = field(
|
||||
default_factory=dict, metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,16 @@ import inspect
|
|||
import sys
|
||||
from abc import ABC
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import get_args as _get_args
|
||||
from typing import get_origin as _get_origin
|
||||
|
|
@ -17,7 +21,9 @@ from typing import get_origin as _get_origin
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _resolve_type(arg, factory_cls: type):
|
||||
def _resolve_type(
|
||||
arg: Union[Type, str, ForwardRef], factory_cls: type
|
||||
) -> Optional[Type]:
|
||||
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
|
||||
if not isinstance(arg, (str, ForwardRef)):
|
||||
return arg
|
||||
|
|
@ -129,7 +135,7 @@ class BaseFactory(ABC, Generic[T]):
|
|||
return entry
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
def list_registered(cls) -> List[str]:
|
||||
"""List all registered component names."""
|
||||
return sorted(cls._entries)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class BaseSamplingStrategy(ABC):
|
|||
Returns:
|
||||
Transformed logits tensor.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TemperatureStrategy(BaseSamplingStrategy):
|
||||
|
|
@ -41,7 +42,7 @@ class TemperatureStrategy(BaseSamplingStrategy):
|
|||
def __init__(self, temperature: Union[float, Tensor] = 1.0):
|
||||
self.temperature = temperature
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||
t = self.temperature
|
||||
if isinstance(t, Tensor):
|
||||
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||
|
|
@ -63,7 +64,7 @@ class TopKStrategy(BaseSamplingStrategy):
|
|||
def __init__(self, top_k: Union[int, Tensor] = 0):
|
||||
self.top_k = top_k
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||
tk = self.top_k
|
||||
if isinstance(tk, Tensor):
|
||||
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
|
||||
|
|
@ -100,7 +101,9 @@ class TopPStrategy(BaseSamplingStrategy):
|
|||
def __init__(self, top_p: Union[float, Tensor] = 1.0):
|
||||
self.top_p = top_p
|
||||
|
||||
def _apply(self, logits, top_p, filter_value):
|
||||
def _apply(
|
||||
self, logits: Tensor, top_p: Union[float, Tensor], filter_value: float
|
||||
) -> Tensor:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
remove = cum_probs > top_p
|
||||
|
|
@ -111,7 +114,7 @@ class TopPStrategy(BaseSamplingStrategy):
|
|||
logits[mask] = filter_value
|
||||
return logits
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||
tp = self.top_p
|
||||
if isinstance(tp, Tensor):
|
||||
tp = tp.to(logits.device, non_blocking=True)
|
||||
|
|
@ -142,7 +145,7 @@ class SamplingPipeline(BaseSamplingStrategy):
|
|||
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
||||
self.strategies = strategies
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||
for strategy in self.strategies:
|
||||
logits = strategy.apply(logits, filter_value)
|
||||
return logits
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import Dict, List, Tuple
|
|||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
||||
def _truncate(seq: List[int], max_len: int, mode: str) -> List[int]:
|
||||
if len(seq) <= max_len:
|
||||
return seq
|
||||
if mode == "keep_end":
|
||||
|
|
@ -26,10 +26,11 @@ class PackingStrategy(ABC):
|
|||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
keys: Dict[str, List[list]],
|
||||
keys: Dict[str, List[List[int]]],
|
||||
max_packed_len: int,
|
||||
truncation_mode: str,
|
||||
) -> Dict[str, List[list]]: ...
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
||||
|
|
@ -38,7 +39,12 @@ class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
|||
|
||||
@PackingStrategyFactory.register("simple")
|
||||
class SimplePacking(PackingStrategy):
|
||||
def apply(self, keys, max_packed_len, truncation_mode):
|
||||
def apply(
|
||||
self,
|
||||
keys: Dict[str, List[List[int]]],
|
||||
max_packed_len: int,
|
||||
truncation_mode: str,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
return {
|
||||
k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
|
||||
for k, vals in keys.items()
|
||||
|
|
@ -47,7 +53,12 @@ class SimplePacking(PackingStrategy):
|
|||
|
||||
@PackingStrategyFactory.register("bfd")
|
||||
class BFDPacking(PackingStrategy):
|
||||
def apply(self, keys, max_packed_len, truncation_mode):
|
||||
def apply(
|
||||
self,
|
||||
keys: Dict[str, List[List[int]]],
|
||||
max_packed_len: int,
|
||||
truncation_mode: str,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
sequences = keys.get("sequence", [])
|
||||
if not sequences:
|
||||
return keys
|
||||
|
|
@ -61,7 +72,7 @@ class BFDPacking(PackingStrategy):
|
|||
return dict(reordered)
|
||||
|
||||
@staticmethod
|
||||
def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]:
|
||||
def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]:
|
||||
n = len(sequences)
|
||||
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
||||
bins: List[List[int]] = []
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ class PositionIdStrategy(ABC):
|
|||
"""Generate ``position_ids`` for packed sequences."""
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, sequences: List[list]) -> List[int]: ...
|
||||
def generate(self, sequences: List[List[int]]) -> List[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
|
||||
|
|
@ -25,13 +26,13 @@ class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
|
|||
|
||||
@PositionIdStrategyFactory.register("none")
|
||||
class NoPositionId(PositionIdStrategy):
|
||||
def generate(self, sequences):
|
||||
def generate(self, sequences: List[List[int]]) -> List[int]:
|
||||
return []
|
||||
|
||||
|
||||
@PositionIdStrategyFactory.register("doc_reset")
|
||||
class DocResetPositionId(PositionIdStrategy):
|
||||
def generate(self, sequences):
|
||||
def generate(self, sequences: List[List[int]]) -> List[int]:
|
||||
pos_ids = []
|
||||
for seq in sequences:
|
||||
pos_ids.extend(range(len(seq)))
|
||||
|
|
@ -40,6 +41,6 @@ class DocResetPositionId(PositionIdStrategy):
|
|||
|
||||
@PositionIdStrategyFactory.register("continuous")
|
||||
class ContinuousPositionId(PositionIdStrategy):
|
||||
def generate(self, sequences):
|
||||
def generate(self, sequences: List[List[int]]) -> List[int]:
|
||||
total = sum(len(seq) for seq in sequences)
|
||||
return list(range(total))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Training strategy implementations with factory pattern."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Union
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -11,7 +11,9 @@ from torch import Tensor
|
|||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
||||
def create_ref_model(
|
||||
model_fn: Callable[[], nn.Module], state_dict: Dict[str, Tensor]
|
||||
) -> nn.Module:
|
||||
"""Create a frozen reference model from model_fn + full state dict."""
|
||||
ref_model = model_fn()
|
||||
ref_model.load_state_dict(state_dict)
|
||||
|
|
@ -20,7 +22,7 @@ def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
|||
return ref_model
|
||||
|
||||
|
||||
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
|
||||
def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
|
||||
"""Move batch tensors to specified device with non-blocking transfer."""
|
||||
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||
|
||||
|
|
@ -30,7 +32,7 @@ def get_logprobs(
|
|||
input_ids: Tensor,
|
||||
mask: Tensor,
|
||||
reduction: str,
|
||||
):
|
||||
) -> Tensor:
|
||||
"""Compute token-wise log probabilities from model outputs.
|
||||
|
||||
Args:
|
||||
|
|
@ -88,7 +90,10 @@ class BaseStrategy(ABC):
|
|||
"""Abstract base class for training strategies."""
|
||||
|
||||
def __init__(
|
||||
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
|
||||
self,
|
||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||
device: str,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
|
@ -139,7 +144,13 @@ class SEQStrategy(BaseStrategy):
|
|||
Computes cross-entropy loss for next token prediction.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||
device: str,
|
||||
label_smoothing: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
|
|
@ -164,7 +175,13 @@ class SFTStrategy(BaseStrategy):
|
|||
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||
device: str,
|
||||
label_smoothing: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
|
|
|
|||
|
|
@ -154,11 +154,13 @@ class CheckpointCallback(TrainCallback):
|
|||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
extra = self.save_extra_fn(context)
|
||||
meta = context.config.to_dict()
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
meta=meta,
|
||||
config=context.model_config,
|
||||
)
|
||||
context.checkpoint.save(save_path)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, Self
|
||||
from typing import Any, Dict, Optional, Self
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -36,7 +36,7 @@ class TrainContext:
|
|||
|
||||
world_size: int = field(default=1)
|
||||
rank: int = field(default=0)
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class TrainContextBuilder:
|
||||
|
|
|
|||
Loading…
Reference in New Issue