fix : 并行训练 state_dict 收集与训练/推理并发缺陷
- FSDPExecutor: unwrap_model 返回全量 state_dict (state_dict_type FULL);use_orig_params=True - DDPExecutor/BaseExecutor: unwrap_model 统一返回 model.module.state_dict() / model.state_dict() - CheckpointCallback: 走 executor.unwrap_model 拿完整 state_dict - strategy.py: 移除 FSDP/DDp 依赖;create_ref_model(model_fn, state_dict) 纯函数 - TrainContextBuilder: 传递 model_fn + executor 到 strategy - GRPOStrategy.sync_ref_model: 通过 executor.unwrap_model 获取完整权重 - TaskManager.wait_for_tasks: 锁内检查队列,消除 clear/set 竞态 - ProtocolHandler: stop token 不再计入 completion_tokens(流式/非流式)
This commit is contained in:
parent
a3275423a4
commit
d4451f6afb
|
|
@ -138,13 +138,13 @@ class ProtocolHandler:
|
||||||
yielded = ""
|
yielded = ""
|
||||||
matched = None
|
matched = None
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
ctx.completion_tokens += 1
|
|
||||||
body += token
|
body += token
|
||||||
|
|
||||||
matched = checker.check(body)
|
matched = checker.check(body)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
ctx.completion_tokens += 1
|
||||||
yield self.builder.format_chunk(token)
|
yield self.builder.format_chunk(token)
|
||||||
yielded += token
|
yielded += token
|
||||||
|
|
||||||
|
|
@ -168,7 +168,6 @@ class ProtocolHandler:
|
||||||
matched = None
|
matched = None
|
||||||
|
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
ctx.completion_tokens += 1
|
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
body += token
|
body += token
|
||||||
|
|
||||||
|
|
@ -176,6 +175,8 @@ class ProtocolHandler:
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
ctx.completion_tokens += 1
|
||||||
|
|
||||||
content = "".join(chunks)
|
content = "".join(chunks)
|
||||||
stop = StopInfo(matched=matched, body=body)
|
stop = StopInfo(matched=matched, body=body)
|
||||||
return self.builder.format_response(ctx, content, stop)
|
return self.builder.format_response(ctx, content, stop)
|
||||||
|
|
|
||||||
|
|
@ -186,6 +186,9 @@ class TaskManager:
|
||||||
return bool(self.active_tasks or self.waiting_queue)
|
return bool(self.active_tasks or self.waiting_queue)
|
||||||
|
|
||||||
def wait_for_tasks(self, timeout: float = 1.0):
|
def wait_for_tasks(self, timeout: float = 1.0):
|
||||||
|
with self._lock:
|
||||||
|
if self.waiting_queue or self.active_tasks:
|
||||||
|
return
|
||||||
self._task_event.clear()
|
self._task_event.clear()
|
||||||
self._task_event.wait(timeout=timeout)
|
self._task_event.wait(timeout=timeout)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -115,8 +116,8 @@ class BaseExecutor:
|
||||||
def backward(self, loss: torch.Tensor):
|
def backward(self, loss: torch.Tensor):
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
def unwrap_model(self, model: nn.Module):
|
||||||
return model
|
return model.state_dict()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_distributed(self) -> bool:
|
def use_distributed(self) -> bool:
|
||||||
|
|
@ -195,10 +196,10 @@ class DDPExecutor(BaseExecutor):
|
||||||
return model.no_sync()
|
return model.no_sync()
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
def unwrap_model(self, model: nn.Module):
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
return model.module
|
return model.module.state_dict()
|
||||||
return model
|
return model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("fsdp")
|
@ExecutorFactory.register("fsdp")
|
||||||
|
|
@ -259,9 +260,13 @@ class FSDPExecutor(BaseExecutor):
|
||||||
return model.no_sync()
|
return model.no_sync()
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
def unwrap_model(self, model: nn.Module):
|
||||||
if self._original_model is not None:
|
if isinstance(model, FSDP) and self.use_distributed:
|
||||||
return self._original_model
|
with FSDP.state_dict_type(
|
||||||
if isinstance(model, FSDP):
|
model,
|
||||||
return model._fsdp_wrapped_module
|
StateDictType.FULL_STATE_DICT,
|
||||||
return model
|
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
||||||
|
):
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
return model.state_dict()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
"""Training strategy implementations with factory pattern."""
|
"""Training strategy implementations with factory pattern."""
|
||||||
|
|
||||||
import copy
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, Union
|
from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
|
|
@ -8,28 +7,14 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
||||||
if isinstance(model, DDP):
|
"""Create a frozen reference model from model_fn + full state dict."""
|
||||||
return model.module
|
ref_model = model_fn()
|
||||||
if isinstance(model, FSDP):
|
ref_model.load_state_dict(state_dict)
|
||||||
return model._fsdp_wrapped_module
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def create_ref_model(model: nn.Module) -> nn.Module:
|
|
||||||
"""Create a reference model for DPO/GRPO training.
|
|
||||||
|
|
||||||
Handles DDP-wrapped models safely by unwrapping first,
|
|
||||||
then creating a deep copy with frozen gradients.
|
|
||||||
"""
|
|
||||||
original_model = unwrap_model(model)
|
|
||||||
ref_model = copy.deepcopy(original_model)
|
|
||||||
ref_model.requires_grad_(False)
|
ref_model.requires_grad_(False)
|
||||||
ref_model.eval()
|
ref_model.eval()
|
||||||
return ref_model
|
return ref_model
|
||||||
|
|
@ -91,6 +76,8 @@ class BaseStrategy(ABC):
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.executor = kwargs.pop("executor", None)
|
||||||
|
self.model_fn = kwargs.pop("model_fn", None)
|
||||||
self.extra_kwargs = kwargs
|
self.extra_kwargs = kwargs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -230,7 +217,9 @@ class DPOStrategy(BaseStrategy):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(
|
||||||
|
self.model_fn, self.executor.unwrap_model(model)
|
||||||
|
)
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
|
||||||
|
|
@ -284,7 +273,9 @@ class GRPOStrategy(BaseStrategy):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(
|
||||||
|
self.model_fn, self.executor.unwrap_model(model)
|
||||||
|
)
|
||||||
self.clip_eps = clip_eps
|
self.clip_eps = clip_eps
|
||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
|
|
@ -294,8 +285,7 @@ class GRPOStrategy(BaseStrategy):
|
||||||
|
|
||||||
def sync_ref_model(self):
|
def sync_ref_model(self):
|
||||||
"""Copy current model weights to ref model."""
|
"""Copy current model weights to ref model."""
|
||||||
ref_state = self.model.state_dict()
|
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
|
||||||
self.ref_model.load_state_dict(ref_state)
|
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
self._step += 1
|
self._step += 1
|
||||||
|
|
|
||||||
|
|
@ -146,8 +146,7 @@ class CheckpointCallback(TrainCallback):
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
unwrapped = context.executor.unwrap_model(context.model)
|
state_dict = context.executor.unwrap_model(context.model)
|
||||||
state_dict = unwrapped.state_dict()
|
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
if get_rank() == 0:
|
if get_rank() == 0:
|
||||||
|
|
|
||||||
|
|
@ -162,6 +162,8 @@ class TrainContextBuilder:
|
||||||
model=context.model,
|
model=context.model,
|
||||||
train_type=cfg.strategy,
|
train_type=cfg.strategy,
|
||||||
device=device,
|
device=device,
|
||||||
|
executor=executor,
|
||||||
|
model_fn=cfg.model_fn,
|
||||||
**cfg.extra_kwargs,
|
**cfg.extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -8,7 +7,6 @@ import torch
|
||||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
H5Store,
|
H5Store,
|
||||||
MmapStore,
|
|
||||||
StoreFactory,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
load_bin,
|
load_bin,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue