fix: 并行训练 state_dict 收集与训练/推理并发缺陷

- FSDPExecutor: unwrap_model 返回全量 state_dict (StateDictType.FULL_STATE_DICT);use_orig_params=True
- DDPExecutor/BaseExecutor: unwrap_model 统一返回 model.module.state_dict() / model.state_dict()
- CheckpointCallback: 走 executor.unwrap_model 拿完整 state_dict
- strategy.py: 移除 FSDP/DDP 依赖;create_ref_model(model_fn, state_dict) 纯函数
- TrainContextBuilder: 传递 model_fn + executor 到 strategy
- GRPOStrategy.sync_ref_model: 通过 executor.unwrap_model 获取完整权重
- TaskManager.wait_for_tasks: 锁内检查队列,消除 clear/set 竞态
- ProtocolHandler: stop token 不再计入 completion_tokens(流式/非流式)
This commit is contained in:
ViperEkura 2026-05-29 21:12:24 +08:00
parent a3275423a4
commit d4451f6afb
7 changed files with 39 additions and 41 deletions

View File

@ -138,13 +138,13 @@ class ProtocolHandler:
yielded = ""
matched = None
async for token in agen:
ctx.completion_tokens += 1
body += token
matched = checker.check(body)
if matched:
break
ctx.completion_tokens += 1
yield self.builder.format_chunk(token)
yielded += token
@ -168,7 +168,6 @@ class ProtocolHandler:
matched = None
async for token in agen:
ctx.completion_tokens += 1
chunks.append(token)
body += token
@ -176,6 +175,8 @@ class ProtocolHandler:
if matched:
break
ctx.completion_tokens += 1
content = "".join(chunks)
stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop)

View File

@ -186,7 +186,10 @@ class TaskManager:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0):
self._task_event.clear()
with self._lock:
if self.waiting_queue or self.active_tasks:
return
self._task_event.clear()
self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]:

View File

@ -7,6 +7,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@ -115,8 +116,8 @@ class BaseExecutor:
def backward(self, loss: torch.Tensor):
loss.backward()
def unwrap_model(self, model: nn.Module) -> nn.Module:
return model
def unwrap_model(self, model: nn.Module):
return model.state_dict()
@property
def use_distributed(self) -> bool:
@ -195,10 +196,10 @@ class DDPExecutor(BaseExecutor):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module) -> nn.Module:
def unwrap_model(self, model: nn.Module):
if isinstance(model, DDP):
return model.module
return model
return model.module.state_dict()
return model.state_dict()
@ExecutorFactory.register("fsdp")
@ -259,9 +260,13 @@ class FSDPExecutor(BaseExecutor):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module) -> nn.Module:
if self._original_model is not None:
return self._original_model
if isinstance(model, FSDP):
return model._fsdp_wrapped_module
return model
def unwrap_model(self, model: nn.Module):
if isinstance(model, FSDP) and self.use_distributed:
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
return model.state_dict()
return model.state_dict()

View File

@ -1,6 +1,5 @@
"""Training strategy implementations with factory pattern."""
import copy
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
@ -8,28 +7,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module:
if isinstance(model, DDP):
return model.module
if isinstance(model, FSDP):
return model._fsdp_wrapped_module
return model
def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
"""Create a frozen reference model from model_fn + full state dict."""
ref_model = model_fn()
ref_model.load_state_dict(state_dict)
ref_model.requires_grad_(False)
ref_model.eval()
return ref_model
@ -91,6 +76,8 @@ class BaseStrategy(ABC):
):
self.model = model
self.device = device
self.executor = kwargs.pop("executor", None)
self.model_fn = kwargs.pop("model_fn", None)
self.extra_kwargs = kwargs
@abstractmethod
@ -230,7 +217,9 @@ class DPOStrategy(BaseStrategy):
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model)
self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
)
self.beta = beta
self.reduction = reduction
@ -284,7 +273,9 @@ class GRPOStrategy(BaseStrategy):
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model)
self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
)
self.clip_eps = clip_eps
self.kl_coef = kl_coef
self.group_size = group_size
@ -294,8 +285,7 @@ class GRPOStrategy(BaseStrategy):
def sync_ref_model(self):
"""Copy current model weights to ref model."""
ref_state = self.model.state_dict()
self.ref_model.load_state_dict(ref_state)
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1

View File

@ -146,8 +146,7 @@ class CheckpointCallback(TrainCallback):
self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext):
unwrapped = context.executor.unwrap_model(context.model)
state_dict = unwrapped.state_dict()
state_dict = context.executor.unwrap_model(context.model)
self.last_ckpt_iter = context.iteration
if get_rank() == 0:

View File

@ -162,6 +162,8 @@ class TrainContextBuilder:
model=context.model,
train_type=cfg.strategy,
device=device,
executor=executor,
model_fn=cfg.model_fn,
**cfg.extra_kwargs,
)

View File

@ -1,4 +1,3 @@
import json
import os
import numpy as np
@ -8,7 +7,6 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import (
H5Store,
MmapStore,
StoreFactory,
detect_format,
load_bin,