fix : 并行训练 state_dict 收集与训练/推理并发缺陷
- FSDPExecutor: unwrap_model 返回全量 state_dict (state_dict_type FULL);use_orig_params=True - DDPExecutor/BaseExecutor: unwrap_model 统一返回 model.module.state_dict() / model.state_dict() - CheckpointCallback: 走 executor.unwrap_model 拿完整 state_dict - strategy.py: 移除 FSDP/DDp 依赖;create_ref_model(model_fn, state_dict) 纯函数 - TrainContextBuilder: 传递 model_fn + executor 到 strategy - GRPOStrategy.sync_ref_model: 通过 executor.unwrap_model 获取完整权重 - TaskManager.wait_for_tasks: 锁内检查队列,消除 clear/set 竞态 - ProtocolHandler: stop token 不再计入 completion_tokens(流式/非流式)
This commit is contained in:
parent
a3275423a4
commit
d4451f6afb
|
|
@ -138,13 +138,13 @@ class ProtocolHandler:
|
|||
yielded = ""
|
||||
matched = None
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
body += token
|
||||
|
||||
matched = checker.check(body)
|
||||
if matched:
|
||||
break
|
||||
|
||||
ctx.completion_tokens += 1
|
||||
yield self.builder.format_chunk(token)
|
||||
yielded += token
|
||||
|
||||
|
|
@ -168,7 +168,6 @@ class ProtocolHandler:
|
|||
matched = None
|
||||
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
chunks.append(token)
|
||||
body += token
|
||||
|
||||
|
|
@ -176,6 +175,8 @@ class ProtocolHandler:
|
|||
if matched:
|
||||
break
|
||||
|
||||
ctx.completion_tokens += 1
|
||||
|
||||
content = "".join(chunks)
|
||||
stop = StopInfo(matched=matched, body=body)
|
||||
return self.builder.format_response(ctx, content, stop)
|
||||
|
|
|
|||
|
|
@ -186,6 +186,9 @@ class TaskManager:
|
|||
return bool(self.active_tasks or self.waiting_queue)
|
||||
|
||||
def wait_for_tasks(self, timeout: float = 1.0):
|
||||
with self._lock:
|
||||
if self.waiting_queue or self.active_tasks:
|
||||
return
|
||||
self._task_event.clear()
|
||||
self._task_event.wait(timeout=timeout)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import Optional, Tuple
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
|
|
@ -115,8 +116,8 @@ class BaseExecutor:
|
|||
def backward(self, loss: torch.Tensor):
|
||||
loss.backward()
|
||||
|
||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||
return model
|
||||
def unwrap_model(self, model: nn.Module):
|
||||
return model.state_dict()
|
||||
|
||||
@property
|
||||
def use_distributed(self) -> bool:
|
||||
|
|
@ -195,10 +196,10 @@ class DDPExecutor(BaseExecutor):
|
|||
return model.no_sync()
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||
def unwrap_model(self, model: nn.Module):
|
||||
if isinstance(model, DDP):
|
||||
return model.module
|
||||
return model
|
||||
return model.module.state_dict()
|
||||
return model.state_dict()
|
||||
|
||||
|
||||
@ExecutorFactory.register("fsdp")
|
||||
|
|
@ -259,9 +260,13 @@ class FSDPExecutor(BaseExecutor):
|
|||
return model.no_sync()
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||
if self._original_model is not None:
|
||||
return self._original_model
|
||||
if isinstance(model, FSDP):
|
||||
return model._fsdp_wrapped_module
|
||||
return model
|
||||
def unwrap_model(self, model: nn.Module):
|
||||
if isinstance(model, FSDP) and self.use_distributed:
|
||||
with FSDP.state_dict_type(
|
||||
model,
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
||||
):
|
||||
return model.state_dict()
|
||||
|
||||
return model.state_dict()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Training strategy implementations with factory pattern."""
|
||||
|
||||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
|
|
@ -8,28 +7,14 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||
if isinstance(model, DDP):
|
||||
return model.module
|
||||
if isinstance(model, FSDP):
|
||||
return model._fsdp_wrapped_module
|
||||
return model
|
||||
|
||||
|
||||
def create_ref_model(model: nn.Module) -> nn.Module:
|
||||
"""Create a reference model for DPO/GRPO training.
|
||||
|
||||
Handles DDP-wrapped models safely by unwrapping first,
|
||||
then creating a deep copy with frozen gradients.
|
||||
"""
|
||||
original_model = unwrap_model(model)
|
||||
ref_model = copy.deepcopy(original_model)
|
||||
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
||||
"""Create a frozen reference model from model_fn + full state dict."""
|
||||
ref_model = model_fn()
|
||||
ref_model.load_state_dict(state_dict)
|
||||
ref_model.requires_grad_(False)
|
||||
ref_model.eval()
|
||||
return ref_model
|
||||
|
|
@ -91,6 +76,8 @@ class BaseStrategy(ABC):
|
|||
):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.executor = kwargs.pop("executor", None)
|
||||
self.model_fn = kwargs.pop("model_fn", None)
|
||||
self.extra_kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -230,7 +217,9 @@ class DPOStrategy(BaseStrategy):
|
|||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(model)
|
||||
self.ref_model = create_ref_model(
|
||||
self.model_fn, self.executor.unwrap_model(model)
|
||||
)
|
||||
self.beta = beta
|
||||
self.reduction = reduction
|
||||
|
||||
|
|
@ -284,7 +273,9 @@ class GRPOStrategy(BaseStrategy):
|
|||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(model)
|
||||
self.ref_model = create_ref_model(
|
||||
self.model_fn, self.executor.unwrap_model(model)
|
||||
)
|
||||
self.clip_eps = clip_eps
|
||||
self.kl_coef = kl_coef
|
||||
self.group_size = group_size
|
||||
|
|
@ -294,8 +285,7 @@ class GRPOStrategy(BaseStrategy):
|
|||
|
||||
def sync_ref_model(self):
|
||||
"""Copy current model weights to ref model."""
|
||||
ref_state = self.model.state_dict()
|
||||
self.ref_model.load_state_dict(ref_state)
|
||||
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
self._step += 1
|
||||
|
|
|
|||
|
|
@ -146,8 +146,7 @@ class CheckpointCallback(TrainCallback):
|
|||
self.last_ckpt_iter = 0
|
||||
|
||||
def _save_checkpoint(self, context: TrainContext):
|
||||
unwrapped = context.executor.unwrap_model(context.model)
|
||||
state_dict = unwrapped.state_dict()
|
||||
state_dict = context.executor.unwrap_model(context.model)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
if get_rank() == 0:
|
||||
|
|
|
|||
|
|
@ -162,6 +162,8 @@ class TrainContextBuilder:
|
|||
model=context.model,
|
||||
train_type=cfg.strategy,
|
||||
device=device,
|
||||
executor=executor,
|
||||
model_fn=cfg.model_fn,
|
||||
**cfg.extra_kwargs,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -8,7 +7,6 @@ import torch
|
|||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||
from astrai.dataset.storage import (
|
||||
H5Store,
|
||||
MmapStore,
|
||||
StoreFactory,
|
||||
detect_format,
|
||||
load_bin,
|
||||
|
|
|
|||
Loading…
Reference in New Issue