From d4451f6afbf084f6a369e0d2c8d3dcfdcee11cf8 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 29 May 2026 21:12:24 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20=E5=B9=B6=E8=A1=8C=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=20state=5Fdict=20=E6=94=B6=E9=9B=86=E4=B8=8E=E8=AE=AD=E7=BB=83?= =?UTF-8?q?/=E6=8E=A8=E7=90=86=E5=B9=B6=E5=8F=91=E7=BC=BA=E9=99=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - FSDPExecutor: unwrap_model 返回全量 state_dict (state_dict_type FULL);use_orig_params=True - DDPExecutor/BaseExecutor: unwrap_model 统一返回 model.module.state_dict() / model.state_dict() - CheckpointCallback: 走 executor.unwrap_model 拿完整 state_dict - strategy.py: 移除 FSDP/DDP 依赖;create_ref_model(model_fn, state_dict) 纯函数 - TrainContextBuilder: 传递 model_fn + executor 到 strategy - GRPOStrategy.sync_ref_model: 通过 executor.unwrap_model 获取完整权重 - TaskManager.wait_for_tasks: 锁内检查队列,消除 clear/set 竞态 - ProtocolHandler: stop token 不再计入 completion_tokens(流式/非流式) --- astrai/inference/api/protocol.py | 5 +++-- astrai/inference/core/task.py | 5 ++++- astrai/parallel/executor.py | 27 ++++++++++++++---------- astrai/trainer/strategy.py | 36 ++++++++++++-------------------- astrai/trainer/train_callback.py | 3 +-- astrai/trainer/train_context.py | 2 ++ tests/data/test_dataset.py | 2 -- 7 files changed, 39 insertions(+), 41 deletions(-) diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 45c51b5..d6c3769 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -138,13 +138,13 @@ class ProtocolHandler: yielded = "" matched = None async for token in agen: - ctx.completion_tokens += 1 body += token matched = checker.check(body) if matched: break + ctx.completion_tokens += 1 yield self.builder.format_chunk(token) yielded += token @@ -168,7 +168,6 @@ class ProtocolHandler: matched = None async for token in agen: - ctx.completion_tokens += 1 chunks.append(token) body += token @@ 
-176,6 +175,8 @@ class ProtocolHandler: if matched: break + ctx.completion_tokens += 1 + content = "".join(chunks) stop = StopInfo(matched=matched, body=body) return self.builder.format_response(ctx, content, stop) diff --git a/astrai/inference/core/task.py b/astrai/inference/core/task.py index 1b449c8..5fcf0a4 100644 --- a/astrai/inference/core/task.py +++ b/astrai/inference/core/task.py @@ -186,7 +186,10 @@ class TaskManager: return bool(self.active_tasks or self.waiting_queue) def wait_for_tasks(self, timeout: float = 1.0): - self._task_event.clear() + with self._lock: + if self.waiting_queue or self.active_tasks: + return + self._task_event.clear() self._task_event.wait(timeout=timeout) def get_active_tasks(self) -> List[Task]: diff --git a/astrai/parallel/executor.py b/astrai/parallel/executor.py index c1f2141..566d0d7 100644 --- a/astrai/parallel/executor.py +++ b/astrai/parallel/executor.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from torch.distributed.fsdp import FullStateDictConfig, StateDictType from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -115,8 +116,8 @@ class BaseExecutor: def backward(self, loss: torch.Tensor): loss.backward() - def unwrap_model(self, model: nn.Module) -> nn.Module: - return model + def unwrap_model(self, model: nn.Module): + return model.state_dict() @property def use_distributed(self) -> bool: @@ -195,10 +196,10 @@ class DDPExecutor(BaseExecutor): return model.no_sync() return contextlib.nullcontext() - def unwrap_model(self, model: nn.Module) -> nn.Module: + def unwrap_model(self, model: nn.Module): if isinstance(model, DDP): - return model.module - return model + return model.module.state_dict() + return model.state_dict() @ExecutorFactory.register("fsdp") @@ -259,9 +260,13 @@ class FSDPExecutor(BaseExecutor): return model.no_sync() return contextlib.nullcontext() - 
def unwrap_model(self, model: nn.Module) -> nn.Module: - if self._original_model is not None: - return self._original_model - if isinstance(model, FSDP): - return model._fsdp_wrapped_module - return model + def unwrap_model(self, model: nn.Module): + if isinstance(model, FSDP) and self.use_distributed: + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=False), + ): + return model.state_dict() + + return model.state_dict() diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index 37ee0a8..bdf1538 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -1,6 +1,5 @@ """Training strategy implementations with factory pattern.""" -import copy from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Union @@ -8,28 +7,14 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.nn.parallel import DistributedDataParallel as DDP from astrai.factory import BaseFactory -def unwrap_model(model: nn.Module) -> nn.Module: - if isinstance(model, DDP): - return model.module - if isinstance(model, FSDP): - return model._fsdp_wrapped_module - return model - - -def create_ref_model(model: nn.Module) -> nn.Module: - """Create a reference model for DPO/GRPO training. - - Handles DDP-wrapped models safely by unwrapping first, - then creating a deep copy with frozen gradients. 
- """ - original_model = unwrap_model(model) - ref_model = copy.deepcopy(original_model) +def create_ref_model(model_fn, state_dict: dict) -> nn.Module: + """Create a frozen reference model from model_fn + full state dict.""" + ref_model = model_fn() + ref_model.load_state_dict(state_dict) ref_model.requires_grad_(False) ref_model.eval() return ref_model @@ -91,6 +76,8 @@ class BaseStrategy(ABC): ): self.model = model self.device = device + self.executor = kwargs.pop("executor", None) + self.model_fn = kwargs.pop("model_fn", None) self.extra_kwargs = kwargs @abstractmethod @@ -230,7 +217,9 @@ class DPOStrategy(BaseStrategy): **kwargs, ): super().__init__(model, device, **kwargs) - self.ref_model = create_ref_model(model) + self.ref_model = create_ref_model( + self.model_fn, self.executor.unwrap_model(model) + ) self.beta = beta self.reduction = reduction @@ -284,7 +273,9 @@ class GRPOStrategy(BaseStrategy): **kwargs, ): super().__init__(model, device, **kwargs) - self.ref_model = create_ref_model(model) + self.ref_model = create_ref_model( + self.model_fn, self.executor.unwrap_model(model) + ) self.clip_eps = clip_eps self.kl_coef = kl_coef self.group_size = group_size @@ -294,8 +285,7 @@ class GRPOStrategy(BaseStrategy): def sync_ref_model(self): """Copy current model weights to ref model.""" - ref_state = self.model.state_dict() - self.ref_model.load_state_dict(ref_state) + self.ref_model.load_state_dict(self.executor.unwrap_model(self.model)) def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: self._step += 1 diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 31f2260..225e4d9 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -146,8 +146,7 @@ class CheckpointCallback(TrainCallback): self.last_ckpt_iter = 0 def _save_checkpoint(self, context: TrainContext): - unwrapped = context.executor.unwrap_model(context.model) - state_dict = unwrapped.state_dict() + state_dict = 
context.executor.unwrap_model(context.model) self.last_ckpt_iter = context.iteration if get_rank() == 0: diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 879830b..9d268e1 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -162,6 +162,8 @@ class TrainContextBuilder: model=context.model, train_type=cfg.strategy, device=device, + executor=executor, + model_fn=cfg.model_fn, **cfg.extra_kwargs, ) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index b9c8cff..21662c8 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,4 +1,3 @@ -import json import os import numpy as np @@ -8,7 +7,6 @@ import torch from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.storage import ( H5Store, - MmapStore, StoreFactory, detect_format, load_bin,