fix: 使用 threading.Event 替代裸 bool,补全公共 API
- scheduler 停止信号改用 threading.Event,跨解释器安全 - 移除 _fatal_error 和 check_health,异常仅用 logger.error 记录 - 补全 astrai/__init__.py,暴露所有主要模块
This commit is contained in:
parent
7a04b1f8ce
commit
3e234c46f6
|
|
@ -3,32 +3,98 @@ __author__ = "ViperEkura"
|
|||
|
||||
from astrai.config import (
|
||||
AutoRegressiveLMConfig,
|
||||
BaseModelConfig,
|
||||
ConfigFactory,
|
||||
EncoderConfig,
|
||||
PipelineConfig,
|
||||
TrainConfig,
|
||||
)
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.dataset import (
|
||||
BaseDataset,
|
||||
DatasetFactory,
|
||||
ResumableDistributedSampler,
|
||||
Store,
|
||||
StoreFactory,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.inference import (
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
ProtocolHandler,
|
||||
SamplingPipeline,
|
||||
get_app,
|
||||
run_server,
|
||||
sample,
|
||||
)
|
||||
from astrai.model import (
|
||||
AutoModel,
|
||||
AutoRegressiveLM,
|
||||
EmbeddingEncoder,
|
||||
LoRAConfig,
|
||||
inject_lora,
|
||||
)
|
||||
from astrai.parallel import (
|
||||
ExecutorFactory,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
only_on_rank,
|
||||
spawn_parallel_fn,
|
||||
)
|
||||
from astrai.preprocessing import Pipeline, filter_by_length
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.tokenize import AutoTokenizer, ChatTemplate
|
||||
from astrai.trainer import (
|
||||
BaseScheduler,
|
||||
BaseStrategy,
|
||||
CallbackFactory,
|
||||
Muon,
|
||||
SchedulerFactory,
|
||||
StrategyFactory,
|
||||
TrainCallback,
|
||||
Trainer,
|
||||
)
|
||||
from astrai.model import AutoModel, AutoRegressiveLM
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
||||
|
||||
__all__ = [
|
||||
"AutoRegressiveLM",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"TrainConfig",
|
||||
"DatasetFactory",
|
||||
"AutoModel",
|
||||
"AutoTokenizer",
|
||||
"BaseDataset",
|
||||
"BaseFactory",
|
||||
"BaseModelConfig",
|
||||
"BaseScheduler",
|
||||
"BaseStrategy",
|
||||
"CallbackFactory",
|
||||
"ChatTemplate",
|
||||
"Checkpoint",
|
||||
"ConfigFactory",
|
||||
"DatasetFactory",
|
||||
"EmbeddingEncoder",
|
||||
"EncoderConfig",
|
||||
"ExecutorFactory",
|
||||
"GenerationRequest",
|
||||
"InferenceEngine",
|
||||
"Trainer",
|
||||
"CallbackFactory",
|
||||
"StrategyFactory",
|
||||
"LoRAConfig",
|
||||
"Muon",
|
||||
"Pipeline",
|
||||
"PipelineConfig",
|
||||
"ProtocolHandler",
|
||||
"ResumableDistributedSampler",
|
||||
"SamplingPipeline",
|
||||
"SchedulerFactory",
|
||||
"BaseFactory",
|
||||
"AutoModel",
|
||||
"Store",
|
||||
"StoreFactory",
|
||||
"StrategyFactory",
|
||||
"TrainCallback",
|
||||
"TrainConfig",
|
||||
"Trainer",
|
||||
"filter_by_length",
|
||||
"get_app",
|
||||
"get_rank",
|
||||
"get_world_size",
|
||||
"inject_lora",
|
||||
"only_on_rank",
|
||||
"run_server",
|
||||
"sample",
|
||||
"spawn_parallel_fn",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -70,8 +70,8 @@ class InferenceScheduler:
|
|||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self._running = False
|
||||
self._fatal_error: Optional[Exception] = None
|
||||
self._stop_event = threading.Event()
|
||||
self._loop_thread: Optional[threading.Thread] = None
|
||||
|
||||
def add_task(self, prompt: str, **kwargs) -> str:
|
||||
return self._task_mgr.add_task(prompt, **kwargs)
|
||||
|
|
@ -86,7 +86,7 @@ class InferenceScheduler:
|
|||
def _run_generation_loop(self):
|
||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||
try:
|
||||
while self._running:
|
||||
while not self._stop_event.is_set():
|
||||
finished = self._task_mgr.remove_finished_tasks(stop_ids)
|
||||
for task in finished:
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
|
@ -176,8 +176,7 @@ class InferenceScheduler:
|
|||
t.stream_callback(STOP)
|
||||
|
||||
except Exception as e:
|
||||
self._fatal_error = e
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
if task.stream_callback:
|
||||
|
|
@ -189,17 +188,19 @@ class InferenceScheduler:
|
|||
self._task_mgr.clear_queues()
|
||||
|
||||
def start(self):
|
||||
if not self._running:
|
||||
self._running = True
|
||||
if self._loop_thread is not None and self._loop_thread.is_alive():
|
||||
return
|
||||
self._stop_event.clear()
|
||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||
t.start()
|
||||
self._loop_thread = t
|
||||
|
||||
def stop(self):
|
||||
self._running = False
|
||||
self._stop_event.set()
|
||||
self._task_mgr.wake()
|
||||
if hasattr(self, "_loop_thread"):
|
||||
if self._loop_thread is not None:
|
||||
self._loop_thread.join(timeout=2.0)
|
||||
self._loop_thread = None
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
|
|
|
|||
Loading…
Reference in New Issue