diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 45c51b5..d6c3769 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -138,13 +138,13 @@ class ProtocolHandler: yielded = "" matched = None async for token in agen: - ctx.completion_tokens += 1 body += token matched = checker.check(body) if matched: break + ctx.completion_tokens += 1 yield self.builder.format_chunk(token) yielded += token @@ -168,7 +168,6 @@ class ProtocolHandler: matched = None async for token in agen: - ctx.completion_tokens += 1 chunks.append(token) body += token @@ -176,6 +175,8 @@ class ProtocolHandler: if matched: break + ctx.completion_tokens += 1 + content = "".join(chunks) stop = StopInfo(matched=matched, body=body) return self.builder.format_response(ctx, content, stop) diff --git a/astrai/inference/core/task.py b/astrai/inference/core/task.py index 1b449c8..5fcf0a4 100644 --- a/astrai/inference/core/task.py +++ b/astrai/inference/core/task.py @@ -186,7 +186,10 @@ class TaskManager: return bool(self.active_tasks or self.waiting_queue) def wait_for_tasks(self, timeout: float = 1.0): - self._task_event.clear() + with self._lock: + if self.waiting_queue or self.active_tasks: + return + self._task_event.clear() self._task_event.wait(timeout=timeout) def get_active_tasks(self) -> List[Task]: diff --git a/astrai/parallel/executor.py b/astrai/parallel/executor.py index c1f2141..566d0d7 100644 --- a/astrai/parallel/executor.py +++ b/astrai/parallel/executor.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from torch.distributed.fsdp import FullStateDictConfig, StateDictType from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -115,8 +116,8 @@ class BaseExecutor: def backward(self, loss: torch.Tensor): loss.backward() - def unwrap_model(self, model: nn.Module) -> nn.Module: - return model + def unwrap_model(self, model: nn.Module): + return model.state_dict() @property def use_distributed(self) -> bool: @@ -195,10 +196,10 @@ class DDPExecutor(BaseExecutor): return model.no_sync() return contextlib.nullcontext() - def unwrap_model(self, model: nn.Module) -> nn.Module: + def unwrap_model(self, model: nn.Module): if isinstance(model, DDP): - return model.module - return model + return model.module.state_dict() + return model.state_dict() @ExecutorFactory.register("fsdp") @@ -259,9 +260,13 @@ class FSDPExecutor(BaseExecutor): return model.no_sync() return contextlib.nullcontext() - def unwrap_model(self, model: nn.Module) -> nn.Module: - if self._original_model is not None: - return self._original_model - if isinstance(model, FSDP): - return model._fsdp_wrapped_module - return model + def unwrap_model(self, model: nn.Module): + if isinstance(model, FSDP) and self.use_distributed: + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=False), + ): + return model.state_dict() + + return model.state_dict() diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index 37ee0a8..bdf1538 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -1,6 +1,5 @@ """Training strategy implementations with factory pattern.""" -import copy from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Union @@ -8,28 +7,14 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.nn.parallel import DistributedDataParallel as DDP from astrai.factory import BaseFactory -def unwrap_model(model: nn.Module) -> nn.Module: - if isinstance(model, DDP): - return model.module - if isinstance(model, FSDP): - return model._fsdp_wrapped_module - return model - - -def create_ref_model(model: nn.Module) -> nn.Module: - """Create a reference model for DPO/GRPO training. - - Handles DDP-wrapped models safely by unwrapping first, - then creating a deep copy with frozen gradients. - """ - original_model = unwrap_model(model) - ref_model = copy.deepcopy(original_model) +def create_ref_model(model_fn, state_dict: dict) -> nn.Module: + """Create a frozen reference model from model_fn + full state dict.""" + ref_model = model_fn() + ref_model.load_state_dict(state_dict) ref_model.requires_grad_(False) ref_model.eval() return ref_model @@ -91,6 +76,8 @@ class BaseStrategy(ABC): ): self.model = model self.device = device + self.executor = kwargs.pop("executor", None) + self.model_fn = kwargs.pop("model_fn", None) self.extra_kwargs = kwargs @abstractmethod @@ -230,7 +217,9 @@ class DPOStrategy(BaseStrategy): **kwargs, ): super().__init__(model, device, **kwargs) - self.ref_model = create_ref_model(model) + self.ref_model = create_ref_model( + self.model_fn, self.executor.unwrap_model(model) + ) self.beta = beta self.reduction = reduction @@ -284,7 +273,9 @@ class GRPOStrategy(BaseStrategy): **kwargs, ): super().__init__(model, device, **kwargs) - self.ref_model = create_ref_model(model) + self.ref_model = create_ref_model( + self.model_fn, self.executor.unwrap_model(model) + ) self.clip_eps = clip_eps self.kl_coef = kl_coef self.group_size = group_size @@ -294,8 +285,7 @@ class GRPOStrategy(BaseStrategy): def sync_ref_model(self): """Copy current model weights to ref model.""" - ref_state = self.model.state_dict() - self.ref_model.load_state_dict(ref_state) + self.ref_model.load_state_dict(self.executor.unwrap_model(self.model)) def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: self._step += 1 diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 31f2260..225e4d9 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -146,8 +146,7 @@ class CheckpointCallback(TrainCallback): self.last_ckpt_iter = 0 def _save_checkpoint(self, context: TrainContext): - unwrapped = context.executor.unwrap_model(context.model) - state_dict = unwrapped.state_dict() + state_dict = context.executor.unwrap_model(context.model) self.last_ckpt_iter = context.iteration if get_rank() == 0: diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 879830b..9d268e1 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -162,6 +162,8 @@ class TrainContextBuilder: model=context.model, train_type=cfg.strategy, device=device, + executor=executor, + model_fn=cfg.model_fn, **cfg.extra_kwargs, ) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index b9c8cff..21662c8 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,4 +1,3 @@ -import json import os import numpy as np @@ -8,7 +7,6 @@ import torch from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.storage import ( H5Store, - MmapStore, StoreFactory, detect_format, load_bin,