diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 1421219..d9086de 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field, fields -from typing import Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch.nn as nn from torch.optim import Optimizer @@ -40,7 +40,7 @@ class TrainConfig(BaseConfig): max_grad_norm: float = field( default=1.0, metadata={"help": "Maximum gradient norm."} ) - gradient_checkpointing_modules: list = field( + gradient_checkpointing_modules: List[str] = field( default_factory=list, metadata={"help": "Module types to enable activation checkpointing for."}, ) @@ -133,11 +133,11 @@ class TrainConfig(BaseConfig): metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."}, ) - executor_kwargs: dict = field( + executor_kwargs: Dict[str, Any] = field( default_factory=dict, metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."}, ) - extra_kwargs: dict = field( + extra_kwargs: Dict[str, Any] = field( default_factory=dict, metadata={"help": "Other arguments."} ) diff --git a/astrai/factory.py b/astrai/factory.py index 8c27955..f411010 100644 --- a/astrai/factory.py +++ b/astrai/factory.py @@ -4,12 +4,16 @@ import inspect import sys from abc import ABC from typing import ( + Any, Callable, Dict, ForwardRef, Generic, + List, + Optional, Type, TypeVar, + Union, ) from typing import get_args as _get_args from typing import get_origin as _get_origin @@ -17,7 +21,9 @@ from typing import get_origin as _get_origin T = TypeVar("T") -def _resolve_type(arg, factory_cls: type): +def _resolve_type( + arg: Union[Type, str, ForwardRef], factory_cls: type +) -> Optional[Type]: """Resolve a generic type-arg (str forward-ref, ForwardRef, or class).""" if not isinstance(arg, (str, ForwardRef)): return arg @@ -129,7 +135,7 @@ class BaseFactory(ABC, Generic[T]): return entry @classmethod - def list_registered(cls) -> list: + def list_registered(cls) -> List[str]: """List all registered component names.""" return sorted(cls._entries) diff --git a/astrai/inference/sample.py b/astrai/inference/sample.py index cb007df..b66099f 100644 --- a/astrai/inference/sample.py +++ b/astrai/inference/sample.py @@ -29,6 +29,7 @@ class BaseSamplingStrategy(ABC): Returns: Transformed logits tensor. """ + raise NotImplementedError class TemperatureStrategy(BaseSamplingStrategy): @@ -41,7 +42,7 @@ class TemperatureStrategy(BaseSamplingStrategy): def __init__(self, temperature: Union[float, Tensor] = 1.0): self.temperature = temperature - def apply(self, logits, filter_value=-float("inf")): + def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor: t = self.temperature if isinstance(t, Tensor): t = t.to(logits.device, non_blocking=True).view(-1, 1) @@ -63,7 +64,7 @@ class TopKStrategy(BaseSamplingStrategy): def __init__(self, top_k: Union[int, Tensor] = 0): self.top_k = top_k - def apply(self, logits, filter_value=-float("inf")): + def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor: tk = self.top_k if isinstance(tk, Tensor): tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0) @@ -100,7 +101,9 @@ class TopPStrategy(BaseSamplingStrategy): def __init__(self, top_p: Union[float, Tensor] = 1.0): self.top_p = top_p - def _apply(self, logits, top_p, filter_value): + def _apply( + self, logits: Tensor, top_p: Union[float, Tensor], filter_value: float + ) -> Tensor: sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) remove = cum_probs > top_p @@ -111,7 +114,7 @@ class TopPStrategy(BaseSamplingStrategy): logits[mask] = filter_value return logits - def apply(self, logits, filter_value=-float("inf")): + def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor: tp = self.top_p if isinstance(tp, Tensor): tp = tp.to(logits.device, non_blocking=True) @@ -142,7 +145,7 @@ class SamplingPipeline(BaseSamplingStrategy): def __init__(self, strategies: List[BaseSamplingStrategy]): self.strategies = strategies - def apply(self, logits, filter_value=-float("inf")): + def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor: for strategy in self.strategies: logits = strategy.apply(logits, filter_value) return logits diff --git a/astrai/preprocessing/packing.py b/astrai/preprocessing/packing.py index 7f1a663..035a546 100644 --- a/astrai/preprocessing/packing.py +++ b/astrai/preprocessing/packing.py @@ -12,7 +12,7 @@ from typing import Dict, List, Tuple from astrai.factory import BaseFactory -def _truncate(seq: list, max_len: int, mode: str) -> list: +def _truncate(seq: List[int], max_len: int, mode: str) -> List[int]: if len(seq) <= max_len: return seq if mode == "keep_end": @@ -26,10 +26,11 @@ class PackingStrategy(ABC): @abstractmethod def apply( self, - keys: Dict[str, List[list]], + keys: Dict[str, List[List[int]]], max_packed_len: int, truncation_mode: str, - ) -> Dict[str, List[list]]: ... + ) -> Dict[str, List[List[int]]]: + raise NotImplementedError class PackingStrategyFactory(BaseFactory["PackingStrategy"]): @@ -38,7 +39,12 @@ class PackingStrategyFactory(BaseFactory["PackingStrategy"]): @PackingStrategyFactory.register("simple") class SimplePacking(PackingStrategy): - def apply(self, keys, max_packed_len, truncation_mode): + def apply( + self, + keys: Dict[str, List[List[int]]], + max_packed_len: int, + truncation_mode: str, + ) -> Dict[str, List[List[int]]]: return { k: [_truncate(v, max_packed_len, truncation_mode) for v in vals] for k, vals in keys.items() @@ -47,7 +53,12 @@ class SimplePacking(PackingStrategy): @PackingStrategyFactory.register("bfd") class BFDPacking(PackingStrategy): - def apply(self, keys, max_packed_len, truncation_mode): + def apply( + self, + keys: Dict[str, List[List[int]]], + max_packed_len: int, + truncation_mode: str, + ) -> Dict[str, List[List[int]]]: sequences = keys.get("sequence", []) if not sequences: return keys @@ -61,7 +72,7 @@ class BFDPacking(PackingStrategy): return dict(reordered) @staticmethod - def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]: + def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]: n = len(sequences) order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True) bins: List[List[int]] = [] diff --git a/astrai/preprocessing/position_id.py b/astrai/preprocessing/position_id.py index 08d1ddf..4c9425a 100644 --- a/astrai/preprocessing/position_id.py +++ b/astrai/preprocessing/position_id.py @@ -16,7 +16,8 @@ class PositionIdStrategy(ABC): """Generate ``position_ids`` for packed sequences.""" @abstractmethod - def generate(self, sequences: List[list]) -> List[int]: ... + def generate(self, sequences: List[List[int]]) -> List[int]: + raise NotImplementedError class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]): @@ -25,13 +26,13 @@ class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]): @PositionIdStrategyFactory.register("none") class NoPositionId(PositionIdStrategy): - def generate(self, sequences): + def generate(self, sequences: List[List[int]]) -> List[int]: return [] @PositionIdStrategyFactory.register("doc_reset") class DocResetPositionId(PositionIdStrategy): - def generate(self, sequences): + def generate(self, sequences: List[List[int]]) -> List[int]: pos_ids = [] for seq in sequences: pos_ids.extend(range(len(seq))) @@ -40,6 +41,6 @@ class DocResetPositionId(PositionIdStrategy): @PositionIdStrategyFactory.register("continuous") class ContinuousPositionId(PositionIdStrategy): - def generate(self, sequences): + def generate(self, sequences: List[List[int]]) -> List[int]: total = sum(len(seq) for seq in sequences) return list(range(total)) diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index 1eb7e02..aca7e08 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -1,7 +1,7 @@ """Training strategy implementations with factory pattern.""" from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Union +from typing import Callable, Dict, Union import torch import torch.nn as nn @@ -11,7 +11,9 @@ from torch import Tensor from astrai.factory import BaseFactory -def create_ref_model(model_fn, state_dict: dict) -> nn.Module: +def create_ref_model( + model_fn: Callable[[], nn.Module], state_dict: Dict[str, Tensor] +) -> nn.Module: """Create a frozen reference model from model_fn + full state dict.""" ref_model = model_fn() ref_model.load_state_dict(state_dict) @@ -20,7 +22,7 @@ def create_ref_model(model_fn, state_dict: dict) -> nn.Module: return ref_model -def move_to_device(batch: Dict[str, Tensor], device: str) -> Any: +def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]: """Move batch tensors to specified device with non-blocking transfer.""" return {key: value.to(device, non_blocking=True) for key, value in batch.items()} @@ -30,7 +32,7 @@ def get_logprobs( input_ids: Tensor, mask: Tensor, reduction: str, -): +) -> Tensor: """Compute token-wise log probabilities from model outputs. Args: @@ -88,7 +90,10 @@ class BaseStrategy(ABC): """Abstract base class for training strategies.""" def __init__( - self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs + self, + model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], + device: str, + **kwargs, ): self.model = model self.device = device @@ -139,7 +144,13 @@ class SEQStrategy(BaseStrategy): Computes cross-entropy loss for next token prediction. """ - def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs): + def __init__( + self, + model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], + device: str, + label_smoothing: float = 0.0, + **kwargs, + ): super().__init__(model, device, **kwargs) self.label_smoothing = label_smoothing @@ -164,7 +175,13 @@ class SFTStrategy(BaseStrategy): Applies cross-entropy loss only to tokens where loss_mask is True. """ - def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs): + def __init__( + self, + model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], + device: str, + label_smoothing: float = 0.0, + **kwargs, + ): super().__init__(model, device, **kwargs) self.label_smoothing = label_smoothing diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 6aaad95..53895ca 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -154,11 +154,13 @@ class CheckpointCallback(TrainCallback): self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}" ) extra = self.save_extra_fn(context) + meta = context.config.to_dict() context.checkpoint = Checkpoint( state_dict=state_dict, epoch=context.epoch, iteration=context.iteration, extra=extra, + meta=meta, config=context.model_config, ) context.checkpoint.save(save_path) diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 71b4000..031ab51 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Self +from typing import Any, Dict, Optional, Self import torch import torch.nn as nn @@ -36,7 +36,7 @@ class TrainContext: world_size: int = field(default=1) rank: int = field(default=0) - kwargs: dict = field(default_factory=dict) + kwargs: Dict[str, Any] = field(default_factory=dict) class TrainContextBuilder: