fix: 使用 threading.Event 替代裸 bool,补全公共 API

- scheduler 停止信号改用 threading.Event,线程安全且不依赖特定解释器实现
- 移除 _fatal_error 和 check_health,异常仅用 logger.error 记录
- 补全 astrai/__init__.py 的导入与 __all__,导出各主要模块的公共接口
This commit is contained in:
ViperEkura 2026-06-18 15:38:13 +08:00
parent 7a04b1f8ce
commit 3e234c46f6
2 changed files with 91 additions and 24 deletions

View File

@ -3,32 +3,98 @@ __author__ = "ViperEkura"
from astrai.config import (
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
PipelineConfig,
TrainConfig,
)
from astrai.dataset import DatasetFactory
from astrai.dataset import (
BaseDataset,
DatasetFactory,
ResumableDistributedSampler,
Store,
StoreFactory,
)
from astrai.factory import BaseFactory
from astrai.inference import (
GenerationRequest,
InferenceEngine,
ProtocolHandler,
SamplingPipeline,
get_app,
run_server,
sample,
)
from astrai.model import (
AutoModel,
AutoRegressiveLM,
EmbeddingEncoder,
LoRAConfig,
inject_lora,
)
from astrai.parallel import (
ExecutorFactory,
get_rank,
get_world_size,
only_on_rank,
spawn_parallel_fn,
)
from astrai.preprocessing import Pipeline, filter_by_length
from astrai.serialization import Checkpoint
from astrai.tokenize import AutoTokenizer, ChatTemplate
from astrai.trainer import (
BaseScheduler,
BaseStrategy,
CallbackFactory,
Muon,
SchedulerFactory,
StrategyFactory,
TrainCallback,
Trainer,
)
from astrai.model import AutoModel, AutoRegressiveLM
from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [
"AutoRegressiveLM",
"AutoRegressiveLMConfig",
"EncoderConfig",
"TrainConfig",
"DatasetFactory",
"AutoModel",
"AutoTokenizer",
"BaseDataset",
"BaseFactory",
"BaseModelConfig",
"BaseScheduler",
"BaseStrategy",
"CallbackFactory",
"ChatTemplate",
"Checkpoint",
"ConfigFactory",
"DatasetFactory",
"EmbeddingEncoder",
"EncoderConfig",
"ExecutorFactory",
"GenerationRequest",
"InferenceEngine",
"Trainer",
"CallbackFactory",
"StrategyFactory",
"LoRAConfig",
"Muon",
"Pipeline",
"PipelineConfig",
"ProtocolHandler",
"ResumableDistributedSampler",
"SamplingPipeline",
"SchedulerFactory",
"BaseFactory",
"AutoModel",
"Store",
"StoreFactory",
"StrategyFactory",
"TrainCallback",
"TrainConfig",
"Trainer",
"filter_by_length",
"get_app",
"get_rank",
"get_world_size",
"inject_lora",
"only_on_rank",
"run_server",
"sample",
"spawn_parallel_fn",
]

View File

@ -70,8 +70,8 @@ class InferenceScheduler:
dtype=self.dtype,
)
self._running = False
self._fatal_error: Optional[Exception] = None
self._stop_event = threading.Event()
self._loop_thread: Optional[threading.Thread] = None
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
@ -86,7 +86,7 @@ class InferenceScheduler:
def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
while not self._stop_event.is_set():
finished = self._task_mgr.remove_finished_tasks(stop_ids)
for task in finished:
self._page_cache.task_free(task.task_id)
@ -176,8 +176,7 @@ class InferenceScheduler:
t.stream_callback(STOP)
except Exception as e:
self._fatal_error = e
self._running = False
self._stop_event.set()
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
@ -189,17 +188,19 @@ class InferenceScheduler:
self._task_mgr.clear_queues()
def start(self):
if not self._running:
self._running = True
if self._loop_thread is not None and self._loop_thread.is_alive():
return
self._stop_event.clear()
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self):
self._running = False
self._stop_event.set()
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):
if self._loop_thread is not None:
self._loop_thread.join(timeout=2.0)
self._loop_thread = None
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)