diff --git a/astrai/inference/core/scheduler.py b/astrai/inference/core/scheduler.py index 3e76f77..1c1ca44 100644 --- a/astrai/inference/core/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -71,6 +71,7 @@ class InferenceScheduler: ) self._running = False + self._fatal_error: Optional[Exception] = None def add_task(self, prompt: str, **kwargs) -> str: return self._task_mgr.add_task(prompt, **kwargs) @@ -175,6 +176,8 @@ class InferenceScheduler: t.stream_callback(STOP) except Exception as e: + self._fatal_error = e + self._running = False logger.error(f"Scheduler loop crashed: {e}", exc_info=True) for task in self._task_mgr.get_active_tasks(): if task.stream_callback: @@ -184,7 +187,6 @@ class InferenceScheduler: if task.stream_callback: task.stream_callback(STOP) self._task_mgr.clear_queues() - raise def start(self): if not self._running: @@ -199,7 +201,12 @@ class InferenceScheduler: if hasattr(self, "_loop_thread"): self._loop_thread.join(timeout=2.0) for task in self._task_mgr.get_active_tasks(): + if task.stream_callback: + task.stream_callback(STOP) self._page_cache.task_free(task.task_id) + for task in self._task_mgr.get_waiting_tasks(): + if task.stream_callback: + task.stream_callback(STOP) self._task_mgr.clear_queues() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 63b28ed..c3a0c00 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -79,8 +79,8 @@ class GenerationRequest: raise ValueError("top_k must be a non-negative integer") if not (0.0 <= top_p <= 1.0): raise ValueError("top_p must be a float between 0.0 and 1.0") - if not (isinstance(temperature, (int, float)) and temperature >= 0): - raise ValueError("temperature must be a non-negative number") + if not (isinstance(temperature, (int, float)) and temperature > 0): + raise ValueError("temperature must be a positive number") self.messages = messages self.top_k = top_k diff --git a/astrai/inference/sample.py b/astrai/inference/sample.py index 45949ac..cb007df 100644 --- a/astrai/inference/sample.py +++ b/astrai/inference/sample.py @@ -44,10 +44,12 @@ class TemperatureStrategy(BaseSamplingStrategy): def apply(self, logits, filter_value=-float("inf")): t = self.temperature if isinstance(t, Tensor): + t = t.to(logits.device, non_blocking=True).view(-1, 1) + t = torch.clamp(t, min=1e-8) if (t != 1.0).any(): - logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1) + logits = logits / t elif t != 1.0: - logits = logits / t + logits = logits / max(t, 1e-8) return logits diff --git a/astrai/parallel/executor.py b/astrai/parallel/executor.py index 566d0d7..ce2d935 100644 --- a/astrai/parallel/executor.py +++ b/astrai/parallel/executor.py @@ -218,7 +218,6 @@ class FSDPExecutor(BaseExecutor): sync_module_states: bool = False, forward_prefetch: bool = False, limit_all_gathers: bool = True, - use_orig_params: bool = False, ignored_states=None, device_mesh=None, ): @@ -237,7 +236,7 @@ class FSDPExecutor(BaseExecutor): sync_module_states=sync_module_states, forward_prefetch=forward_prefetch, limit_all_gathers=limit_all_gathers, - use_orig_params=use_orig_params, + use_orig_params=True, ignored_states=ignored_states, device_mesh=device_mesh, ).items() diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index bdf1538..529edd1 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -219,7 +219,7 @@ class DPOStrategy(BaseStrategy): super().__init__(model, device, **kwargs) self.ref_model = create_ref_model( self.model_fn, self.executor.unwrap_model(model) - ) + ).to(device=self.device) self.beta = beta self.reduction = reduction @@ -275,7 +275,7 @@ class GRPOStrategy(BaseStrategy): super().__init__(model, device, **kwargs) self.ref_model = create_ref_model( self.model_fn, self.executor.unwrap_model(model) - ) + ).to(device=self.device) self.clip_eps = clip_eps self.kl_coef = kl_coef self.group_size = group_size