fix: 使用 threading.Event 替代裸 bool,补全公共 API

- scheduler 停止信号改用 threading.Event,线程间的停止通知不再依赖特定解释器实现(如 CPython GIL)的行为
- 移除 _fatal_error 和 check_health,异常仅用 logger.error 记录
- 补全 astrai/__init__.py,暴露所有主要模块
This commit is contained in:
ViperEkura 2026-06-18 15:38:13 +08:00
parent 7a04b1f8ce
commit 3e234c46f6
2 changed files with 91 additions and 24 deletions

View File

@ -3,32 +3,98 @@ __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (
AutoRegressiveLMConfig, AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig, EncoderConfig,
PipelineConfig,
TrainConfig, TrainConfig,
) )
from astrai.dataset import DatasetFactory from astrai.dataset import (
BaseDataset,
DatasetFactory,
ResumableDistributedSampler,
Store,
StoreFactory,
)
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.inference import ( from astrai.inference import (
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
ProtocolHandler,
SamplingPipeline,
get_app,
run_server,
sample,
)
from astrai.model import (
AutoModel,
AutoRegressiveLM,
EmbeddingEncoder,
LoRAConfig,
inject_lora,
)
from astrai.parallel import (
ExecutorFactory,
get_rank,
get_world_size,
only_on_rank,
spawn_parallel_fn,
)
from astrai.preprocessing import Pipeline, filter_by_length
from astrai.serialization import Checkpoint
from astrai.tokenize import AutoTokenizer, ChatTemplate
from astrai.trainer import (
BaseScheduler,
BaseStrategy,
CallbackFactory,
Muon,
SchedulerFactory,
StrategyFactory,
TrainCallback,
Trainer,
) )
from astrai.model import AutoModel, AutoRegressiveLM
from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [ __all__ = [
"AutoRegressiveLM", "AutoRegressiveLM",
"AutoRegressiveLMConfig", "AutoRegressiveLMConfig",
"EncoderConfig", "AutoModel",
"TrainConfig",
"DatasetFactory",
"AutoTokenizer", "AutoTokenizer",
"BaseDataset",
"BaseFactory",
"BaseModelConfig",
"BaseScheduler",
"BaseStrategy",
"CallbackFactory",
"ChatTemplate",
"Checkpoint",
"ConfigFactory",
"DatasetFactory",
"EmbeddingEncoder",
"EncoderConfig",
"ExecutorFactory",
"GenerationRequest", "GenerationRequest",
"InferenceEngine", "InferenceEngine",
"Trainer", "LoRAConfig",
"CallbackFactory", "Muon",
"StrategyFactory", "Pipeline",
"PipelineConfig",
"ProtocolHandler",
"ResumableDistributedSampler",
"SamplingPipeline",
"SchedulerFactory", "SchedulerFactory",
"BaseFactory", "Store",
"AutoModel", "StoreFactory",
"StrategyFactory",
"TrainCallback",
"TrainConfig",
"Trainer",
"filter_by_length",
"get_app",
"get_rank",
"get_world_size",
"inject_lora",
"only_on_rank",
"run_server",
"sample",
"spawn_parallel_fn",
] ]

View File

@ -70,8 +70,8 @@ class InferenceScheduler:
dtype=self.dtype, dtype=self.dtype,
) )
self._running = False self._stop_event = threading.Event()
self._fatal_error: Optional[Exception] = None self._loop_thread: Optional[threading.Thread] = None
def add_task(self, prompt: str, **kwargs) -> str: def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs) return self._task_mgr.add_task(prompt, **kwargs)
@ -86,7 +86,7 @@ class InferenceScheduler:
def _run_generation_loop(self): def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids stop_ids = self._task_mgr.tokenizer.stop_ids
try: try:
while self._running: while not self._stop_event.is_set():
finished = self._task_mgr.remove_finished_tasks(stop_ids) finished = self._task_mgr.remove_finished_tasks(stop_ids)
for task in finished: for task in finished:
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
@ -176,8 +176,7 @@ class InferenceScheduler:
t.stream_callback(STOP) t.stream_callback(STOP)
except Exception as e: except Exception as e:
self._fatal_error = e self._stop_event.set()
self._running = False
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.get_active_tasks():
if task.stream_callback: if task.stream_callback:
@ -189,17 +188,19 @@ class InferenceScheduler:
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
def start(self): def start(self):
if not self._running: if self._loop_thread is not None and self._loop_thread.is_alive():
self._running = True return
t = threading.Thread(target=self._run_generation_loop, daemon=True) self._stop_event.clear()
t.start() t = threading.Thread(target=self._run_generation_loop, daemon=True)
self._loop_thread = t t.start()
self._loop_thread = t
def stop(self): def stop(self):
self._running = False self._stop_event.set()
self._task_mgr.wake() self._task_mgr.wake()
if hasattr(self, "_loop_thread"): if self._loop_thread is not None:
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=2.0)
self._loop_thread = None
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.get_active_tasks():
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)