diff --git a/astrai/__init__.py b/astrai/__init__.py index 46d5b22..4e5b30d 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -3,32 +3,98 @@ __author__ = "ViperEkura" from astrai.config import ( AutoRegressiveLMConfig, + BaseModelConfig, + ConfigFactory, EncoderConfig, + PipelineConfig, TrainConfig, ) -from astrai.dataset import DatasetFactory +from astrai.dataset import ( + BaseDataset, + DatasetFactory, + ResumableDistributedSampler, + Store, + StoreFactory, +) from astrai.factory import BaseFactory from astrai.inference import ( GenerationRequest, InferenceEngine, + ProtocolHandler, + SamplingPipeline, + get_app, + run_server, + sample, +) +from astrai.model import ( + AutoModel, + AutoRegressiveLM, + EmbeddingEncoder, + LoRAConfig, + inject_lora, +) +from astrai.parallel import ( + ExecutorFactory, + get_rank, + get_world_size, + only_on_rank, + spawn_parallel_fn, +) +from astrai.preprocessing import Pipeline, filter_by_length +from astrai.serialization import Checkpoint +from astrai.tokenize import AutoTokenizer, ChatTemplate +from astrai.trainer import ( + BaseScheduler, + BaseStrategy, + CallbackFactory, + Muon, + SchedulerFactory, + StrategyFactory, + TrainCallback, + Trainer, ) -from astrai.model import AutoModel, AutoRegressiveLM -from astrai.tokenize import AutoTokenizer -from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer __all__ = [ "AutoRegressiveLM", "AutoRegressiveLMConfig", - "EncoderConfig", - "TrainConfig", - "DatasetFactory", + "AutoModel", "AutoTokenizer", + "BaseDataset", + "BaseFactory", + "BaseModelConfig", + "BaseScheduler", + "BaseStrategy", + "CallbackFactory", + "ChatTemplate", + "Checkpoint", + "ConfigFactory", + "DatasetFactory", + "EmbeddingEncoder", + "EncoderConfig", + "ExecutorFactory", "GenerationRequest", "InferenceEngine", - "Trainer", - "CallbackFactory", - "StrategyFactory", + "LoRAConfig", + "Muon", + "Pipeline", + "PipelineConfig", + "ProtocolHandler", + "ResumableDistributedSampler", + "SamplingPipeline", "SchedulerFactory", - "BaseFactory", - "AutoModel", + "Store", + "StoreFactory", + "StrategyFactory", + "TrainCallback", + "TrainConfig", + "Trainer", + "filter_by_length", + "get_app", + "get_rank", + "get_world_size", + "inject_lora", + "only_on_rank", + "run_server", + "sample", + "spawn_parallel_fn", ] diff --git a/astrai/inference/core/scheduler.py b/astrai/inference/core/scheduler.py index 1c1ca44..c06f3cb 100644 --- a/astrai/inference/core/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -70,8 +70,8 @@ class InferenceScheduler: dtype=self.dtype, ) - self._running = False - self._fatal_error: Optional[Exception] = None + self._stop_event = threading.Event() + self._loop_thread: Optional[threading.Thread] = None def add_task(self, prompt: str, **kwargs) -> str: return self._task_mgr.add_task(prompt, **kwargs) @@ -86,7 +86,7 @@ class InferenceScheduler: def _run_generation_loop(self): stop_ids = self._task_mgr.tokenizer.stop_ids try: - while self._running: + while not self._stop_event.is_set(): finished = self._task_mgr.remove_finished_tasks(stop_ids) for task in finished: self._page_cache.task_free(task.task_id) @@ -176,8 +176,7 @@ class InferenceScheduler: t.stream_callback(STOP) except Exception as e: - self._fatal_error = e - self._running = False + self._stop_event.set() logger.error(f"Scheduler loop crashed: {e}", exc_info=True) for task in self._task_mgr.get_active_tasks(): if task.stream_callback: @@ -189,17 +188,19 @@ class InferenceScheduler: self._task_mgr.clear_queues() def start(self): - if not self._running: - self._running = True - t = threading.Thread(target=self._run_generation_loop, daemon=True) - t.start() - self._loop_thread = t + if self._loop_thread is not None and self._loop_thread.is_alive(): + return + self._stop_event.clear() + t = threading.Thread(target=self._run_generation_loop, daemon=True) + t.start() + self._loop_thread = t def stop(self): - self._running = False + self._stop_event.set() self._task_mgr.wake() - if hasattr(self, "_loop_thread"): + if self._loop_thread is not None: self._loop_thread.join(timeout=2.0) + self._loop_thread = None for task in self._task_mgr.get_active_tasks(): if task.stream_callback: task.stream_callback(STOP)