fix : FSDP 优化器顺序、温度除零、调度器静默死亡、ref模型设备
- executor: use_orig_params 硬编码 True,FSDP 不替换 Parameter 对象 - strategy: DPO/GRPO ref 模型创建后移到 device - sample: TemperatureStrategy clamp 1e-8,engine 验证改为 >0 - scheduler: 异常不 re-raise 避免 daemon 静默死亡,stop() 发回调给 waiting 任务
This commit is contained in:
parent
d4451f6afb
commit
f521a30b22
|
|
@ -71,6 +71,7 @@ class InferenceScheduler:
|
|||
)
|
||||
|
||||
self._running = False
|
||||
self._fatal_error: Optional[Exception] = None
|
||||
|
||||
def add_task(self, prompt: str, **kwargs) -> str:
|
||||
return self._task_mgr.add_task(prompt, **kwargs)
|
||||
|
|
@ -175,6 +176,8 @@ class InferenceScheduler:
|
|||
t.stream_callback(STOP)
|
||||
|
||||
except Exception as e:
|
||||
self._fatal_error = e
|
||||
self._running = False
|
||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
if task.stream_callback:
|
||||
|
|
@ -184,7 +187,6 @@ class InferenceScheduler:
|
|||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._task_mgr.clear_queues()
|
||||
raise
|
||||
|
||||
def start(self):
|
||||
if not self._running:
|
||||
|
|
@ -199,7 +201,12 @@ class InferenceScheduler:
|
|||
if hasattr(self, "_loop_thread"):
|
||||
self._loop_thread.join(timeout=2.0)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
for task in self._task_mgr.get_waiting_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._task_mgr.clear_queues()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -79,8 +79,8 @@ class GenerationRequest:
|
|||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
if not (isinstance(temperature, (int, float)) and temperature > 0):
|
||||
raise ValueError("temperature must be a positive number")
|
||||
|
||||
self.messages = messages
|
||||
self.top_k = top_k
|
||||
|
|
|
|||
|
|
@ -44,10 +44,12 @@ class TemperatureStrategy(BaseSamplingStrategy):
|
|||
def apply(self, logits, filter_value=-float("inf")):
|
||||
t = self.temperature
|
||||
if isinstance(t, Tensor):
|
||||
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||
t = torch.clamp(t, min=1e-8)
|
||||
if (t != 1.0).any():
|
||||
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||
logits = logits / t
|
||||
elif t != 1.0:
|
||||
logits = logits / t
|
||||
logits = logits / max(t, 1e-8)
|
||||
return logits
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -218,7 +218,6 @@ class FSDPExecutor(BaseExecutor):
|
|||
sync_module_states: bool = False,
|
||||
forward_prefetch: bool = False,
|
||||
limit_all_gathers: bool = True,
|
||||
use_orig_params: bool = False,
|
||||
ignored_states=None,
|
||||
device_mesh=None,
|
||||
):
|
||||
|
|
@ -237,7 +236,7 @@ class FSDPExecutor(BaseExecutor):
|
|||
sync_module_states=sync_module_states,
|
||||
forward_prefetch=forward_prefetch,
|
||||
limit_all_gathers=limit_all_gathers,
|
||||
use_orig_params=use_orig_params,
|
||||
use_orig_params=True,
|
||||
ignored_states=ignored_states,
|
||||
device_mesh=device_mesh,
|
||||
).items()
|
||||
|
|
|
|||
|
|
@ -219,7 +219,7 @@ class DPOStrategy(BaseStrategy):
|
|||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(
|
||||
self.model_fn, self.executor.unwrap_model(model)
|
||||
)
|
||||
).to(device=self.device)
|
||||
self.beta = beta
|
||||
self.reduction = reduction
|
||||
|
||||
|
|
@ -275,7 +275,7 @@ class GRPOStrategy(BaseStrategy):
|
|||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(
|
||||
self.model_fn, self.executor.unwrap_model(model)
|
||||
)
|
||||
).to(device=self.device)
|
||||
self.clip_eps = clip_eps
|
||||
self.kl_coef = kl_coef
|
||||
self.group_size = group_size
|
||||
|
|
|
|||
Loading…
Reference in New Issue