fix: 使用 threading.Event 替代裸 bool,补全公共 API
- scheduler 停止信号改用 threading.Event,跨解释器安全
- 移除 _fatal_error 和 check_health,异常仅用 logger.error 记录
- 补全 astrai/__init__.py,暴露所有主要模块
This commit is contained in:
parent
7a04b1f8ce
commit
3e234c46f6
|
|
@@ -3,32 +3,98 @@ __author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
AutoRegressiveLMConfig,
|
AutoRegressiveLMConfig,
|
||||||
|
BaseModelConfig,
|
||||||
|
ConfigFactory,
|
||||||
EncoderConfig,
|
EncoderConfig,
|
||||||
|
PipelineConfig,
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import (
|
||||||
|
BaseDataset,
|
||||||
|
DatasetFactory,
|
||||||
|
ResumableDistributedSampler,
|
||||||
|
Store,
|
||||||
|
StoreFactory,
|
||||||
|
)
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.inference import (
|
from astrai.inference import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
|
ProtocolHandler,
|
||||||
|
SamplingPipeline,
|
||||||
|
get_app,
|
||||||
|
run_server,
|
||||||
|
sample,
|
||||||
|
)
|
||||||
|
from astrai.model import (
|
||||||
|
AutoModel,
|
||||||
|
AutoRegressiveLM,
|
||||||
|
EmbeddingEncoder,
|
||||||
|
LoRAConfig,
|
||||||
|
inject_lora,
|
||||||
|
)
|
||||||
|
from astrai.parallel import (
|
||||||
|
ExecutorFactory,
|
||||||
|
get_rank,
|
||||||
|
get_world_size,
|
||||||
|
only_on_rank,
|
||||||
|
spawn_parallel_fn,
|
||||||
|
)
|
||||||
|
from astrai.preprocessing import Pipeline, filter_by_length
|
||||||
|
from astrai.serialization import Checkpoint
|
||||||
|
from astrai.tokenize import AutoTokenizer, ChatTemplate
|
||||||
|
from astrai.trainer import (
|
||||||
|
BaseScheduler,
|
||||||
|
BaseStrategy,
|
||||||
|
CallbackFactory,
|
||||||
|
Muon,
|
||||||
|
SchedulerFactory,
|
||||||
|
StrategyFactory,
|
||||||
|
TrainCallback,
|
||||||
|
Trainer,
|
||||||
)
|
)
|
||||||
from astrai.model import AutoModel, AutoRegressiveLM
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AutoRegressiveLM",
|
"AutoRegressiveLM",
|
||||||
"AutoRegressiveLMConfig",
|
"AutoRegressiveLMConfig",
|
||||||
"EncoderConfig",
|
"AutoModel",
|
||||||
"TrainConfig",
|
|
||||||
"DatasetFactory",
|
|
||||||
"AutoTokenizer",
|
"AutoTokenizer",
|
||||||
|
"BaseDataset",
|
||||||
|
"BaseFactory",
|
||||||
|
"BaseModelConfig",
|
||||||
|
"BaseScheduler",
|
||||||
|
"BaseStrategy",
|
||||||
|
"CallbackFactory",
|
||||||
|
"ChatTemplate",
|
||||||
|
"Checkpoint",
|
||||||
|
"ConfigFactory",
|
||||||
|
"DatasetFactory",
|
||||||
|
"EmbeddingEncoder",
|
||||||
|
"EncoderConfig",
|
||||||
|
"ExecutorFactory",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"Trainer",
|
"LoRAConfig",
|
||||||
"CallbackFactory",
|
"Muon",
|
||||||
"StrategyFactory",
|
"Pipeline",
|
||||||
|
"PipelineConfig",
|
||||||
|
"ProtocolHandler",
|
||||||
|
"ResumableDistributedSampler",
|
||||||
|
"SamplingPipeline",
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
"BaseFactory",
|
"Store",
|
||||||
"AutoModel",
|
"StoreFactory",
|
||||||
|
"StrategyFactory",
|
||||||
|
"TrainCallback",
|
||||||
|
"TrainConfig",
|
||||||
|
"Trainer",
|
||||||
|
"filter_by_length",
|
||||||
|
"get_app",
|
||||||
|
"get_rank",
|
||||||
|
"get_world_size",
|
||||||
|
"inject_lora",
|
||||||
|
"only_on_rank",
|
||||||
|
"run_server",
|
||||||
|
"sample",
|
||||||
|
"spawn_parallel_fn",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@@ -70,8 +70,8 @@ class InferenceScheduler:
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running = False
|
self._stop_event = threading.Event()
|
||||||
self._fatal_error: Optional[Exception] = None
|
self._loop_thread: Optional[threading.Thread] = None
|
||||||
|
|
||||||
def add_task(self, prompt: str, **kwargs) -> str:
|
def add_task(self, prompt: str, **kwargs) -> str:
|
||||||
return self._task_mgr.add_task(prompt, **kwargs)
|
return self._task_mgr.add_task(prompt, **kwargs)
|
||||||
|
|
@@ -86,7 +86,7 @@ class InferenceScheduler:
|
||||||
def _run_generation_loop(self):
|
def _run_generation_loop(self):
|
||||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||||
try:
|
try:
|
||||||
while self._running:
|
while not self._stop_event.is_set():
|
||||||
finished = self._task_mgr.remove_finished_tasks(stop_ids)
|
finished = self._task_mgr.remove_finished_tasks(stop_ids)
|
||||||
for task in finished:
|
for task in finished:
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
|
|
@@ -176,8 +176,7 @@ class InferenceScheduler:
|
||||||
t.stream_callback(STOP)
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._fatal_error = e
|
self._stop_event.set()
|
||||||
self._running = False
|
|
||||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
for task in self._task_mgr.get_active_tasks():
|
for task in self._task_mgr.get_active_tasks():
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
|
|
@@ -189,17 +188,19 @@ class InferenceScheduler:
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
if not self._running:
|
if self._loop_thread is not None and self._loop_thread.is_alive():
|
||||||
self._running = True
|
return
|
||||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
self._stop_event.clear()
|
||||||
t.start()
|
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||||
self._loop_thread = t
|
t.start()
|
||||||
|
self._loop_thread = t
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self._running = False
|
self._stop_event.set()
|
||||||
self._task_mgr.wake()
|
self._task_mgr.wake()
|
||||||
if hasattr(self, "_loop_thread"):
|
if self._loop_thread is not None:
|
||||||
self._loop_thread.join(timeout=2.0)
|
self._loop_thread.join(timeout=2.0)
|
||||||
|
self._loop_thread = None
|
||||||
for task in self._task_mgr.get_active_tasks():
|
for task in self._task_mgr.get_active_tasks():
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
task.stream_callback(STOP)
|
task.stream_callback(STOP)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue