fix: FSDP 优化器顺序、温度除零、调度器静默死亡、ref 模型设备

- executor: use_orig_params 硬编码 True,FSDP 不替换 Parameter 对象
- strategy: DPO/GRPO ref 模型创建后移到 device
- sample: TemperatureStrategy 将温度 clamp 到下限 1e-8 以避免除零,engine 的 temperature 校验由 >= 0 改为 > 0
- scheduler: 循环异常不再 re-raise(记录到 _fatal_error 并将 _running 置为 False),避免 daemon 线程静默死亡;stop() 也向 waiting 任务发送 STOP 回调
This commit is contained in:
ViperEkura 2026-05-29 21:57:44 +08:00
parent d4451f6afb
commit f521a30b22
5 changed files with 17 additions and 9 deletions

View File

@@ -71,6 +71,7 @@ class InferenceScheduler:
) )
self._running = False self._running = False
self._fatal_error: Optional[Exception] = None
def add_task(self, prompt: str, **kwargs) -> str: def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs) return self._task_mgr.add_task(prompt, **kwargs)
@@ -175,6 +176,8 @@ class InferenceScheduler:
t.stream_callback(STOP) t.stream_callback(STOP)
except Exception as e: except Exception as e:
self._fatal_error = e
self._running = False
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.get_active_tasks():
if task.stream_callback: if task.stream_callback:
@@ -184,7 +187,6 @@ class InferenceScheduler:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
raise
def start(self): def start(self):
if not self._running: if not self._running:
@@ -199,7 +201,12 @@ class InferenceScheduler:
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=2.0)
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@@ -79,8 +79,8 @@ class GenerationRequest:
raise ValueError("top_k must be a non-negative integer") raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0): if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0") raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0): if not (isinstance(temperature, (int, float)) and temperature > 0):
raise ValueError("temperature must be a non-negative number") raise ValueError("temperature must be a positive number")
self.messages = messages self.messages = messages
self.top_k = top_k self.top_k = top_k

View File

@@ -44,10 +44,12 @@ class TemperatureStrategy(BaseSamplingStrategy):
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits, filter_value=-float("inf")):
t = self.temperature t = self.temperature
if isinstance(t, Tensor): if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1)
t = torch.clamp(t, min=1e-8)
if (t != 1.0).any(): if (t != 1.0).any():
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0:
logits = logits / t logits = logits / t
elif t != 1.0:
logits = logits / max(t, 1e-8)
return logits return logits

View File

@@ -218,7 +218,6 @@ class FSDPExecutor(BaseExecutor):
sync_module_states: bool = False, sync_module_states: bool = False,
forward_prefetch: bool = False, forward_prefetch: bool = False,
limit_all_gathers: bool = True, limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states=None, ignored_states=None,
device_mesh=None, device_mesh=None,
): ):
@@ -237,7 +236,7 @@ class FSDPExecutor(BaseExecutor):
sync_module_states=sync_module_states, sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch, forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers, limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params, use_orig_params=True,
ignored_states=ignored_states, ignored_states=ignored_states,
device_mesh=device_mesh, device_mesh=device_mesh,
).items() ).items()

View File

@@ -219,7 +219,7 @@ class DPOStrategy(BaseStrategy):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model( self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model) self.model_fn, self.executor.unwrap_model(model)
) ).to(device=self.device)
self.beta = beta self.beta = beta
self.reduction = reduction self.reduction = reduction
@@ -275,7 +275,7 @@ class GRPOStrategy(BaseStrategy):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model( self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model) self.model_fn, self.executor.unwrap_model(model)
) ).to(device=self.device)
self.clip_eps = clip_eps self.clip_eps = clip_eps
self.kl_coef = kl_coef self.kl_coef = kl_coef
self.group_size = group_size self.group_size = group_size