fix : 并行训练 state_dict 收集与训练/推理并发缺陷

- FSDPExecutor: unwrap_model 返回全量 state_dict (state_dict_type FULL);use_orig_params=True
- DDPExecutor/BaseExecutor: unwrap_model 统一返回 model.module.state_dict() / model.state_dict()
- CheckpointCallback: 走 executor.unwrap_model 拿完整 state_dict
- strategy.py: 移除 FSDP/DDp 依赖;create_ref_model(model_fn, state_dict) 纯函数
- TrainContextBuilder: 传递 model_fn + executor 到 strategy
- GRPOStrategy.sync_ref_model: 通过 executor.unwrap_model 获取完整权重
- TaskManager.wait_for_tasks: 锁内检查队列,消除 clear/set 竞态
- ProtocolHandler: stop token 不再计入 completion_tokens(流式/非流式)
This commit is contained in:
ViperEkura 2026-05-29 21:12:24 +08:00
parent a3275423a4
commit d4451f6afb
7 changed files with 39 additions and 41 deletions

View File

@ -138,13 +138,13 @@ class ProtocolHandler:
yielded = "" yielded = ""
matched = None matched = None
async for token in agen: async for token in agen:
ctx.completion_tokens += 1
body += token body += token
matched = checker.check(body) matched = checker.check(body)
if matched: if matched:
break break
ctx.completion_tokens += 1
yield self.builder.format_chunk(token) yield self.builder.format_chunk(token)
yielded += token yielded += token
@ -168,7 +168,6 @@ class ProtocolHandler:
matched = None matched = None
async for token in agen: async for token in agen:
ctx.completion_tokens += 1
chunks.append(token) chunks.append(token)
body += token body += token
@ -176,6 +175,8 @@ class ProtocolHandler:
if matched: if matched:
break break
ctx.completion_tokens += 1
content = "".join(chunks) content = "".join(chunks)
stop = StopInfo(matched=matched, body=body) stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop) return self.builder.format_response(ctx, content, stop)

View File

@ -186,7 +186,10 @@ class TaskManager:
return bool(self.active_tasks or self.waiting_queue) return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0): def wait_for_tasks(self, timeout: float = 1.0):
self._task_event.clear() with self._lock:
if self.waiting_queue or self.active_tasks:
return
self._task_event.clear()
self._task_event.wait(timeout=timeout) self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]: def get_active_tasks(self) -> List[Task]:

View File

@ -7,6 +7,7 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
@ -115,8 +116,8 @@ class BaseExecutor:
def backward(self, loss: torch.Tensor): def backward(self, loss: torch.Tensor):
loss.backward() loss.backward()
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module):
return model return model.state_dict()
@property @property
def use_distributed(self) -> bool: def use_distributed(self) -> bool:
@ -195,10 +196,10 @@ class DDPExecutor(BaseExecutor):
return model.no_sync() return model.no_sync()
return contextlib.nullcontext() return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module):
if isinstance(model, DDP): if isinstance(model, DDP):
return model.module return model.module.state_dict()
return model return model.state_dict()
@ExecutorFactory.register("fsdp") @ExecutorFactory.register("fsdp")
@ -259,9 +260,13 @@ class FSDPExecutor(BaseExecutor):
return model.no_sync() return model.no_sync()
return contextlib.nullcontext() return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module):
if self._original_model is not None: if isinstance(model, FSDP) and self.use_distributed:
return self._original_model with FSDP.state_dict_type(
if isinstance(model, FSDP): model,
return model._fsdp_wrapped_module StateDictType.FULL_STATE_DICT,
return model FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
return model.state_dict()
return model.state_dict()

View File

@ -1,6 +1,5 @@
"""Training strategy implementations with factory pattern.""" """Training strategy implementations with factory pattern."""
import copy
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
@ -8,28 +7,14 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module: def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
if isinstance(model, DDP): """Create a frozen reference model from model_fn + full state dict."""
return model.module ref_model = model_fn()
if isinstance(model, FSDP): ref_model.load_state_dict(state_dict)
return model._fsdp_wrapped_module
return model
def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
ref_model.requires_grad_(False) ref_model.requires_grad_(False)
ref_model.eval() ref_model.eval()
return ref_model return ref_model
@ -91,6 +76,8 @@ class BaseStrategy(ABC):
): ):
self.model = model self.model = model
self.device = device self.device = device
self.executor = kwargs.pop("executor", None)
self.model_fn = kwargs.pop("model_fn", None)
self.extra_kwargs = kwargs self.extra_kwargs = kwargs
@abstractmethod @abstractmethod
@ -230,7 +217,9 @@ class DPOStrategy(BaseStrategy):
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model) self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
)
self.beta = beta self.beta = beta
self.reduction = reduction self.reduction = reduction
@ -284,7 +273,9 @@ class GRPOStrategy(BaseStrategy):
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model) self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
)
self.clip_eps = clip_eps self.clip_eps = clip_eps
self.kl_coef = kl_coef self.kl_coef = kl_coef
self.group_size = group_size self.group_size = group_size
@ -294,8 +285,7 @@ class GRPOStrategy(BaseStrategy):
def sync_ref_model(self): def sync_ref_model(self):
"""Copy current model weights to ref model.""" """Copy current model weights to ref model."""
ref_state = self.model.state_dict() self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
self.ref_model.load_state_dict(ref_state)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1 self._step += 1

View File

@ -146,8 +146,7 @@ class CheckpointCallback(TrainCallback):
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
unwrapped = context.executor.unwrap_model(context.model) state_dict = context.executor.unwrap_model(context.model)
state_dict = unwrapped.state_dict()
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
if get_rank() == 0: if get_rank() == 0:

View File

@ -162,6 +162,8 @@ class TrainContextBuilder:
model=context.model, model=context.model,
train_type=cfg.strategy, train_type=cfg.strategy,
device=device, device=device,
executor=executor,
model_fn=cfg.model_fn,
**cfg.extra_kwargs, **cfg.extra_kwargs,
) )

View File

@ -1,4 +1,3 @@
import json
import os import os
import numpy as np import numpy as np
@ -8,7 +7,6 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
H5Store, H5Store,
MmapStore,
StoreFactory, StoreFactory,
detect_format, detect_format,
load_bin, load_bin,