fix : 修复策略相关文件的类型注解与抽象方法体
- 修复 strategy.py 单元素 Union 与缺失的参数/返回类型注解
- 修复 train_context.py 8 个 default=None 字段缺 Optional 标记
- 修复 sample.py/packing.py/position_id.py 方法缺参数及返回类型注解
- 修复 factory.py _resolve_type/list_registered 缺类型注解
- 修复 train_config.py 裸 dict/list 缺泛型参数
- abstractmethod body 从 ... 改为 raise NotImplementedError
- feat : checkpoint meta.json 保存 TrainConfig 超参供人工查阅
This commit is contained in:
parent
a2512f8a5a
commit
fec376b0dd
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Callable, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -40,7 +40,7 @@ class TrainConfig(BaseConfig):
|
||||||
max_grad_norm: float = field(
|
max_grad_norm: float = field(
|
||||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||||
)
|
)
|
||||||
gradient_checkpointing_modules: list = field(
|
gradient_checkpointing_modules: List[str] = field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
metadata={"help": "Module types to enable activation checkpointing for."},
|
metadata={"help": "Module types to enable activation checkpointing for."},
|
||||||
)
|
)
|
||||||
|
|
@ -133,11 +133,11 @@ class TrainConfig(BaseConfig):
|
||||||
metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."},
|
metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."},
|
||||||
)
|
)
|
||||||
|
|
||||||
executor_kwargs: dict = field(
|
executor_kwargs: Dict[str, Any] = field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
||||||
)
|
)
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: Dict[str, Any] = field(
|
||||||
default_factory=dict, metadata={"help": "Other arguments."}
|
default_factory=dict, metadata={"help": "Other arguments."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,16 @@ import inspect
|
||||||
import sys
|
import sys
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
ForwardRef,
|
ForwardRef,
|
||||||
Generic,
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
from typing import get_args as _get_args
|
from typing import get_args as _get_args
|
||||||
from typing import get_origin as _get_origin
|
from typing import get_origin as _get_origin
|
||||||
|
|
@ -17,7 +21,9 @@ from typing import get_origin as _get_origin
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def _resolve_type(arg, factory_cls: type):
|
def _resolve_type(
|
||||||
|
arg: Union[Type, str, ForwardRef], factory_cls: type
|
||||||
|
) -> Optional[Type]:
|
||||||
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
|
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
|
||||||
if not isinstance(arg, (str, ForwardRef)):
|
if not isinstance(arg, (str, ForwardRef)):
|
||||||
return arg
|
return arg
|
||||||
|
|
@ -129,7 +135,7 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def list_registered(cls) -> list:
|
def list_registered(cls) -> List[str]:
|
||||||
"""List all registered component names."""
|
"""List all registered component names."""
|
||||||
return sorted(cls._entries)
|
return sorted(cls._entries)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ class BaseSamplingStrategy(ABC):
|
||||||
Returns:
|
Returns:
|
||||||
Transformed logits tensor.
|
Transformed logits tensor.
|
||||||
"""
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class TemperatureStrategy(BaseSamplingStrategy):
|
class TemperatureStrategy(BaseSamplingStrategy):
|
||||||
|
|
@ -41,7 +42,7 @@ class TemperatureStrategy(BaseSamplingStrategy):
|
||||||
def __init__(self, temperature: Union[float, Tensor] = 1.0):
|
def __init__(self, temperature: Union[float, Tensor] = 1.0):
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
|
||||||
def apply(self, logits, filter_value=-float("inf")):
|
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
t = self.temperature
|
t = self.temperature
|
||||||
if isinstance(t, Tensor):
|
if isinstance(t, Tensor):
|
||||||
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||||
|
|
@ -63,7 +64,7 @@ class TopKStrategy(BaseSamplingStrategy):
|
||||||
def __init__(self, top_k: Union[int, Tensor] = 0):
|
def __init__(self, top_k: Union[int, Tensor] = 0):
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
def apply(self, logits, filter_value=-float("inf")):
|
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
tk = self.top_k
|
tk = self.top_k
|
||||||
if isinstance(tk, Tensor):
|
if isinstance(tk, Tensor):
|
||||||
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
|
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
|
||||||
|
|
@ -100,7 +101,9 @@ class TopPStrategy(BaseSamplingStrategy):
|
||||||
def __init__(self, top_p: Union[float, Tensor] = 1.0):
|
def __init__(self, top_p: Union[float, Tensor] = 1.0):
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
|
|
||||||
def _apply(self, logits, top_p, filter_value):
|
def _apply(
|
||||||
|
self, logits: Tensor, top_p: Union[float, Tensor], filter_value: float
|
||||||
|
) -> Tensor:
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||||
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
remove = cum_probs > top_p
|
remove = cum_probs > top_p
|
||||||
|
|
@ -111,7 +114,7 @@ class TopPStrategy(BaseSamplingStrategy):
|
||||||
logits[mask] = filter_value
|
logits[mask] = filter_value
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def apply(self, logits, filter_value=-float("inf")):
|
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
tp = self.top_p
|
tp = self.top_p
|
||||||
if isinstance(tp, Tensor):
|
if isinstance(tp, Tensor):
|
||||||
tp = tp.to(logits.device, non_blocking=True)
|
tp = tp.to(logits.device, non_blocking=True)
|
||||||
|
|
@ -142,7 +145,7 @@ class SamplingPipeline(BaseSamplingStrategy):
|
||||||
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
||||||
self.strategies = strategies
|
self.strategies = strategies
|
||||||
|
|
||||||
def apply(self, logits, filter_value=-float("inf")):
|
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
for strategy in self.strategies:
|
for strategy in self.strategies:
|
||||||
logits = strategy.apply(logits, filter_value)
|
logits = strategy.apply(logits, filter_value)
|
||||||
return logits
|
return logits
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from typing import Dict, List, Tuple
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
def _truncate(seq: List[int], max_len: int, mode: str) -> List[int]:
|
||||||
if len(seq) <= max_len:
|
if len(seq) <= max_len:
|
||||||
return seq
|
return seq
|
||||||
if mode == "keep_end":
|
if mode == "keep_end":
|
||||||
|
|
@ -26,10 +26,11 @@ class PackingStrategy(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
keys: Dict[str, List[list]],
|
keys: Dict[str, List[List[int]]],
|
||||||
max_packed_len: int,
|
max_packed_len: int,
|
||||||
truncation_mode: str,
|
truncation_mode: str,
|
||||||
) -> Dict[str, List[list]]: ...
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
||||||
|
|
@ -38,7 +39,12 @@ class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
||||||
|
|
||||||
@PackingStrategyFactory.register("simple")
|
@PackingStrategyFactory.register("simple")
|
||||||
class SimplePacking(PackingStrategy):
|
class SimplePacking(PackingStrategy):
|
||||||
def apply(self, keys, max_packed_len, truncation_mode):
|
def apply(
|
||||||
|
self,
|
||||||
|
keys: Dict[str, List[List[int]]],
|
||||||
|
max_packed_len: int,
|
||||||
|
truncation_mode: str,
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
return {
|
return {
|
||||||
k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
|
k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
|
||||||
for k, vals in keys.items()
|
for k, vals in keys.items()
|
||||||
|
|
@ -47,7 +53,12 @@ class SimplePacking(PackingStrategy):
|
||||||
|
|
||||||
@PackingStrategyFactory.register("bfd")
|
@PackingStrategyFactory.register("bfd")
|
||||||
class BFDPacking(PackingStrategy):
|
class BFDPacking(PackingStrategy):
|
||||||
def apply(self, keys, max_packed_len, truncation_mode):
|
def apply(
|
||||||
|
self,
|
||||||
|
keys: Dict[str, List[List[int]]],
|
||||||
|
max_packed_len: int,
|
||||||
|
truncation_mode: str,
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
sequences = keys.get("sequence", [])
|
sequences = keys.get("sequence", [])
|
||||||
if not sequences:
|
if not sequences:
|
||||||
return keys
|
return keys
|
||||||
|
|
@ -61,7 +72,7 @@ class BFDPacking(PackingStrategy):
|
||||||
return dict(reordered)
|
return dict(reordered)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]:
|
def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]:
|
||||||
n = len(sequences)
|
n = len(sequences)
|
||||||
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
||||||
bins: List[List[int]] = []
|
bins: List[List[int]] = []
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,8 @@ class PositionIdStrategy(ABC):
|
||||||
"""Generate ``position_ids`` for packed sequences."""
|
"""Generate ``position_ids`` for packed sequences."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate(self, sequences: List[list]) -> List[int]: ...
|
def generate(self, sequences: List[List[int]]) -> List[int]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
|
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
|
||||||
|
|
@ -25,13 +26,13 @@ class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
|
||||||
|
|
||||||
@PositionIdStrategyFactory.register("none")
|
@PositionIdStrategyFactory.register("none")
|
||||||
class NoPositionId(PositionIdStrategy):
|
class NoPositionId(PositionIdStrategy):
|
||||||
def generate(self, sequences):
|
def generate(self, sequences: List[List[int]]) -> List[int]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@PositionIdStrategyFactory.register("doc_reset")
|
@PositionIdStrategyFactory.register("doc_reset")
|
||||||
class DocResetPositionId(PositionIdStrategy):
|
class DocResetPositionId(PositionIdStrategy):
|
||||||
def generate(self, sequences):
|
def generate(self, sequences: List[List[int]]) -> List[int]:
|
||||||
pos_ids = []
|
pos_ids = []
|
||||||
for seq in sequences:
|
for seq in sequences:
|
||||||
pos_ids.extend(range(len(seq)))
|
pos_ids.extend(range(len(seq)))
|
||||||
|
|
@ -40,6 +41,6 @@ class DocResetPositionId(PositionIdStrategy):
|
||||||
|
|
||||||
@PositionIdStrategyFactory.register("continuous")
|
@PositionIdStrategyFactory.register("continuous")
|
||||||
class ContinuousPositionId(PositionIdStrategy):
|
class ContinuousPositionId(PositionIdStrategy):
|
||||||
def generate(self, sequences):
|
def generate(self, sequences: List[List[int]]) -> List[int]:
|
||||||
total = sum(len(seq) for seq in sequences)
|
total = sum(len(seq) for seq in sequences)
|
||||||
return list(range(total))
|
return list(range(total))
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""Training strategy implementations with factory pattern."""
|
"""Training strategy implementations with factory pattern."""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, Union
|
from typing import Callable, Dict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -11,7 +11,9 @@ from torch import Tensor
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
def create_ref_model(
|
||||||
|
model_fn: Callable[[], nn.Module], state_dict: Dict[str, Tensor]
|
||||||
|
) -> nn.Module:
|
||||||
"""Create a frozen reference model from model_fn + full state dict."""
|
"""Create a frozen reference model from model_fn + full state dict."""
|
||||||
ref_model = model_fn()
|
ref_model = model_fn()
|
||||||
ref_model.load_state_dict(state_dict)
|
ref_model.load_state_dict(state_dict)
|
||||||
|
|
@ -20,7 +22,7 @@ def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
||||||
return ref_model
|
return ref_model
|
||||||
|
|
||||||
|
|
||||||
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
|
def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
|
||||||
"""Move batch tensors to specified device with non-blocking transfer."""
|
"""Move batch tensors to specified device with non-blocking transfer."""
|
||||||
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||||
|
|
||||||
|
|
@ -30,7 +32,7 @@ def get_logprobs(
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
mask: Tensor,
|
mask: Tensor,
|
||||||
reduction: str,
|
reduction: str,
|
||||||
):
|
) -> Tensor:
|
||||||
"""Compute token-wise log probabilities from model outputs.
|
"""Compute token-wise log probabilities from model outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -88,7 +90,10 @@ class BaseStrategy(ABC):
|
||||||
"""Abstract base class for training strategies."""
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
|
self,
|
||||||
|
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||||
|
device: str,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
@ -139,7 +144,13 @@ class SEQStrategy(BaseStrategy):
|
||||||
Computes cross-entropy loss for next token prediction.
|
Computes cross-entropy loss for next token prediction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||||
|
device: str,
|
||||||
|
label_smoothing: float = 0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
|
|
@ -164,7 +175,13 @@ class SFTStrategy(BaseStrategy):
|
||||||
Applies cross-entropy loss only to tokens where loss_mask is True.
|
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||||
|
device: str,
|
||||||
|
label_smoothing: float = 0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -154,11 +154,13 @@ class CheckpointCallback(TrainCallback):
|
||||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||||
)
|
)
|
||||||
extra = self.save_extra_fn(context)
|
extra = self.save_extra_fn(context)
|
||||||
|
meta = context.config.to_dict()
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=context.epoch,
|
epoch=context.epoch,
|
||||||
iteration=context.iteration,
|
iteration=context.iteration,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
|
meta=meta,
|
||||||
config=context.model_config,
|
config=context.model_config,
|
||||||
)
|
)
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Self
|
from typing import Any, Dict, Optional, Self
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -36,7 +36,7 @@ class TrainContext:
|
||||||
|
|
||||||
world_size: int = field(default=1)
|
world_size: int = field(default=1)
|
||||||
rank: int = field(default=0)
|
rank: int = field(default=0)
|
||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue