diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 95c878f..e7a2d48 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -2,7 +2,6 @@ import asyncio import gc -import logging import threading from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union @@ -12,8 +11,6 @@ import torch.nn as nn from astrai.inference.scheduler import _STOP, InferenceScheduler from astrai.tokenize import AutoTokenizer -logger = logging.getLogger(__name__) - class GenerationRequest: """Request parameters for text generation. diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index e32bb4a..c9b690e 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -1,5 +1,6 @@ """Inference scheduler for single-GPU continuous batching.""" +import logging import threading import time import uuid @@ -12,6 +13,8 @@ from torch import Tensor from astrai.model.automodel import AutoModel from astrai.tokenize import AutoTokenizer +logger = logging.getLogger(__name__) + _STOP = object() @@ -506,9 +509,20 @@ class InferenceScheduler: task_id: The task to remove. """ with self._lock: + removed_active = [t for t in self.active_tasks if t.task_id == task_id] self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] + for task in removed_active: + if task.prefix_len > 0: + prefix = tuple(task.prompt_ids[: task.prefix_len]) + self.prefix_cache.release(prefix) + if task.prefix_len < len(task.prompt_ids): + self.prefix_cache.release(tuple(task.prompt_ids)) + if task.slot >= 0: + self._free_slot(task.slot) + task.slot = -1 + def _remove_finished_tasks(self) -> None: """Removes all finished tasks from the active batch. @@ -553,7 +567,7 @@ class InferenceScheduler: for _ in range(n): to_add.append(self.waiting_queue.pop(0)) - for task in to_add: + for i, task in enumerate(to_add): slot = -1 reused = False if task.prefix_len > 0: @@ -564,6 +578,8 @@ class InferenceScheduler: if slot < 0: slot = self._alloc_slot() if slot < 0: + with self._lock: + self.waiting_queue[:0] = to_add[i:] break task.slot = slot task.status = TaskStatus.RUNNING @@ -712,32 +728,42 @@ class InferenceScheduler: Decode processes only the largest position group to ensure all batched tasks share the same KV cache write position. """ - while self._running: - self._remove_finished_tasks() - self._refill_active_batch() + try: + while self._running: + self._remove_finished_tasks() + self._refill_active_batch() - with self._lock: - if not self.active_tasks and not self.waiting_queue: + with self._lock: + if not self.active_tasks and not self.waiting_queue: + self._task_event.clear() + self._task_event.wait(timeout=0.01) + continue + tasks = self.active_tasks[:] + + to_prefill = [t for t in tasks if t.output_tokens == 0] + if to_prefill: + self._execute_prefill(to_prefill) + + pos_groups: Dict[int, List[Task]] = {} + for t in self.active_tasks: + pos_groups.setdefault(t.next_pos, []).append(t) + + if pos_groups: + best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) + self._execute_decode(pos_groups[best_pos], best_pos) + + if not self.waiting_queue and len(self.active_tasks) <= 1: + self._task_event.wait(timeout=0.005) self._task_event.clear() - self._task_event.wait(timeout=0.01) - continue - tasks = self.active_tasks[:] - - to_prefill = [t for t in tasks if t.output_tokens == 0] - if to_prefill: - self._execute_prefill(to_prefill) - - pos_groups: Dict[int, List[Task]] = {} - for t in self.active_tasks: - pos_groups.setdefault(t.next_pos, []).append(t) - - if pos_groups: - best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) - self._execute_decode(pos_groups[best_pos], best_pos) - - if not self.waiting_queue and len(self.active_tasks) <= 1: - self._task_event.wait(timeout=0.005) - self._task_event.clear() + except Exception as e: + logger.error(f"Scheduler loop crashed: {e}", exc_info=True) + for task in self.active_tasks: + if task.stream_callback: + task.stream_callback(_STOP) + for task in self.waiting_queue: + if task.stream_callback: + task.stream_callback(_STOP) + raise def start(self) -> None: """Starts the background generation loop thread.""" diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 0cd41bc..71f2739 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -23,16 +23,30 @@ from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) -_engine: Optional[InferenceEngine] = None -_model_param: Optional[Any] = None _project_root = Path(__file__).parent.parent.parent -_server_config: Dict[str, Any] = { - "device": "cuda", - "dtype": torch.bfloat16, - "param_path": None, - "max_batch_size": 16, -} + +class ServerState: + """Encapsulates all server runtime state. + + Attributes: + engine: The inference engine instance. + model_param: The loaded model. + config: Server configuration dict. + """ + + def __init__(self): + self.engine: Optional[InferenceEngine] = None + self.model_param: Optional[Any] = None + self.config: Dict[str, Any] = { + "device": "cuda", + "dtype": torch.bfloat16, + "param_path": None, + "max_batch_size": 16, + } + + +_state = ServerState() def configure_server( @@ -41,28 +55,29 @@ def configure_server( param_path: Optional[Path] = None, max_batch_size: int = 16, ): - _server_config["device"] = device - _server_config["dtype"] = dtype - _server_config["param_path"] = param_path - _server_config["max_batch_size"] = max_batch_size + _state.config.update( + device=device, + dtype=dtype, + param_path=param_path, + max_batch_size=max_batch_size, + ) @asynccontextmanager async def lifespan(app: FastAPI): - global _model_param, _engine try: load_model( - param_path=_server_config["param_path"], - device=_server_config["device"], - dtype=_server_config["dtype"], - max_batch_size=_server_config["max_batch_size"], + param_path=_state.config["param_path"], + device=_state.config["device"], + dtype=_state.config["dtype"], + max_batch_size=_state.config["max_batch_size"], ) except Exception as e: logger.error(f"Failed to load model: {e}") raise yield - if _engine: - _engine.shutdown() + if _state.engine: + _state.engine.shutdown() logger.info("Inference engine shutdown complete") @@ -75,25 +90,30 @@ def load_model( dtype: torch.dtype = torch.bfloat16, max_batch_size: int = 16, ): - global _model_param, _engine if param_path is None: param_path = _project_root / "params" if not param_path.exists(): raise FileNotFoundError(f"Parameter directory not found: {param_path}") tokenizer = AutoTokenizer.from_pretrained(param_path) - _model_param = AutoModel.from_pretrained(param_path) - _model_param.to(device=device, dtype=dtype) + _state.model_param = AutoModel.from_pretrained(param_path) + _state.model_param.to(device=device, dtype=dtype) logger.info(f"Model loaded on {device} with dtype {dtype}") - _engine = InferenceEngine( - model=_model_param, + _state.engine = InferenceEngine( + model=_state.model_param, tokenizer=tokenizer, max_batch_size=max_batch_size, ) logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") +def _get_engine() -> InferenceEngine: + if _state.engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + return _state.engine + + class ChatMessage(BaseModel): role: str content: str @@ -121,30 +141,27 @@ class CompletionResponse(BaseModel): async def health(): return { "status": "ok", - "model_loaded": _model_param is not None, - "engine_ready": _engine is not None, + "model_loaded": _state.model_param is not None, + "engine_ready": _state.engine is not None, } @app.get("/stats") async def get_stats(): - if _engine is None: - raise HTTPException(status_code=503, detail="Engine not initialized") - return _engine.get_stats() + return _get_engine().get_stats() @app.post("/v1/chat/completions", response_model=CompletionResponse) async def chat_completion(request: ChatCompletionRequest): - if _engine is None: - raise HTTPException(status_code=503, detail="Engine not initialized") + engine = _get_engine() - prompt = _engine.tokenizer.apply_chat_template( + prompt = engine.tokenizer.apply_chat_template( [{"role": m.role, "content": m.content} for m in request.messages], tokenize=False, ) if request.stream: - agen = _engine.generate_async( + agen = engine.generate_async( prompt=prompt, max_tokens=request.max_tokens, temperature=request.temperature, @@ -163,7 +180,7 @@ async def chat_completion(request: ChatCompletionRequest): headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) else: - result = _engine.generate( + result = engine.generate( prompt=prompt, stream=False, max_tokens=request.max_tokens, @@ -198,8 +215,7 @@ async def generate( max_len: int = 2048, stream: bool = False, ): - if _engine is None: - raise HTTPException(status_code=503, detail="Engine not initialized") + engine = _get_engine() messages = [] if history: @@ -209,10 +225,10 @@ async def generate( messages.append({"role": "assistant", "content": h[1]}) messages.append({"role": "user", "content": query}) - prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False) + prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) if stream: - agen = _engine.generate_async( + agen = engine.generate_async( prompt=prompt, max_tokens=max_len, temperature=temperature, @@ -226,7 +242,7 @@ async def generate( return StreamingResponse(text_stream(), media_type="text/plain") else: - result = _engine.generate( + result = engine.generate( prompt=prompt, stream=False, max_tokens=max_len, diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index 8e4d5e9..6fa86b8 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -4,12 +4,13 @@ AutoModel base class for model loading and saving. from contextlib import contextmanager from pathlib import Path -from typing import Dict, Self, Type, Union +from typing import Self, Type, Union import safetensors.torch as st import torch.nn as nn from astrai.config import ModelConfig +from astrai.factory import Registry @contextmanager @@ -44,8 +45,7 @@ class AutoModel(nn.Module): Provides model loading/saving and generation capabilities. """ - # Model registry - stored as class attribute - _registry: Dict[str, Type["AutoModel"]] = {} + _registry = Registry() def __init__(self, config: ModelConfig): super().__init__() @@ -63,7 +63,7 @@ class AutoModel(nn.Module): """ def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]: - cls._registry[model_type.lower()] = sub_cls + cls._registry.register(model_type.lower(), sub_cls) return sub_cls return decorator @@ -72,12 +72,12 @@ class AutoModel(nn.Module): def get_model_class(cls, model_type: str) -> Type["AutoModel"]: """Get model class by model_type string.""" model_type = model_type.lower() - if model_type not in cls._registry: - available = list(cls._registry.keys()) + if not cls._registry.contains(model_type): + available = cls._registry.list_names() raise ValueError( f"Unknown model_type: {model_type}. Available: {available}" ) - return cls._registry[model_type] + return cls._registry.get(model_type) @classmethod def from_pretrained( @@ -96,14 +96,8 @@ class AutoModel(nn.Module): else: raise FileNotFoundError(f"Config file not found: {config_path}") - # If called from base class, use model_type to determine actual model class - if cls is AutoModel: - model_type = config.model_type or "transformer" - actual_cls = cls.get_model_class(model_type) - else: - raise ValueError( - f"Cannot call from_pretrained() on subclass {cls.__name__}" - ) + model_type = config.model_type or "transformer" + actual_cls = cls.get_model_class(model_type) with _disable_random_init(enable=disable_random_init): model = actual_cls(config) diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 54d7319..04eeb27 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -34,66 +34,60 @@ class TrainContext: class TrainContextBuilder: def __init__(self, config: TrainConfig): self.config = config - self._context = TrainContext( - model=config.model, + self._checkpoint: Optional[Checkpoint] = None + + def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: + self._checkpoint = checkpoint + return self + + def build(self) -> TrainContext: + context = TrainContext( + model=self.config.model, world_size=get_world_size(), rank=get_rank(), ) device = get_current_device() - self._context.model = self._context.model.to(device=device) + context.model = context.model.to(device=device) - if self.config.nprocs > 1: - fn = self.config.parallel_wrapper - self._context.model = fn(self._context.model) + if self.config.nprocs > 1 and self.config.parallel_wrapper: + context.model = self.config.parallel_wrapper(context.model) - self._context.optimizer = self.config.optimizer_fn(self._context.model) - self._context.scheduler = self.config.scheduler_fn(self._context.optimizer) - - def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: - if checkpoint is None: - checkpoint = Checkpoint( - state_dict=self._context.model.state_dict(), - ) + if self._checkpoint is not None: + context.epoch = max(self._checkpoint.epoch, self.config.start_epoch) + context.iteration = max(self._checkpoint.iteration, self.config.start_batch) + context.model.load_state_dict(self._checkpoint.state_dict) + context.checkpoint = self._checkpoint else: - # resume from the assigned checkpoint or assigned iteration - self._context.epoch = max(checkpoint.epoch, self.config.start_epoch) - self._context.iteration = max(checkpoint.iteration, self.config.start_batch) - self._context.model.load_state_dict(checkpoint.state_dict) + context.checkpoint = Checkpoint( + state_dict=context.model.state_dict(), + ) - self._context.checkpoint = checkpoint - return self + context.optimizer = self.config.optimizer_fn(context.model) + context.scheduler = self.config.scheduler_fn(context.optimizer) - def with_dataloader(self) -> Self: - # fix: change batch level iteration to sample level offset - config = self.config - sampler_offset = self._context.iteration * config.batch_size - resumeable_sampler = ResumableDistributedSampler( - data_source=config.dataset, - start_epoch=self._context.epoch, + cfg = self.config + sampler_offset = context.iteration * cfg.batch_size + sampler = ResumableDistributedSampler( + data_source=cfg.dataset, + start_epoch=context.epoch, start_iter=sampler_offset, - seed=config.random_seed, + seed=cfg.random_seed, + ) + context.dataloader = DataLoader( + cfg.dataset, + batch_size=cfg.batch_size, + sampler=sampler, + num_workers=cfg.num_workers, + pin_memory=cfg.pin_memory, + prefetch_factor=cfg.prefetch_factor, ) - dataloader = DataLoader( - config.dataset, - batch_size=config.batch_size, - sampler=resumeable_sampler, - num_workers=config.num_workers, - pin_memory=config.pin_memory, - prefetch_factor=config.prefetch_factor, - ) - self._context.dataloader = dataloader - return self - - def with_strategy(self) -> Self: - self._context.strategy = StrategyFactory.create( - model=self._context.model, + context.strategy = StrategyFactory.create( + model=context.model, train_type=self.config.strategy, - device=get_current_device(), + device=device, **self.config.extra_kwargs, ) - return self - def build(self) -> TrainContext: - return self._context + return context diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 9831545..b7f2361 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -35,11 +35,7 @@ class Trainer: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: return ( - TrainContextBuilder(self.train_config) - .with_checkpoint(checkpoint) - .with_dataloader() - .with_strategy() - .build() + TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build() ) def _call_callbacks(self, method_name: str, context: TrainContext): diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py index d17d21b..6dc2d8e 100644 --- a/tests/inference/conftest.py +++ b/tests/inference/conftest.py @@ -53,5 +53,5 @@ def mock_engine(): @pytest.fixture def loaded_model(mock_model_param, monkeypatch): """Simulate that the model is loaded.""" - monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param) + monkeypatch.setattr("astrai.inference.server._state.model_param", mock_model_param) return mock_model_param diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index 5bdcace..45f0895 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -5,8 +5,8 @@ import pytest def test_health_no_model(client, monkeypatch): """GET /health should return 200 even when model not loaded.""" - monkeypatch.setattr("astrai.inference.server._model_param", None) - monkeypatch.setattr("astrai.inference.server._engine", None) + monkeypatch.setattr("astrai.inference.server._state.model_param", None) + monkeypatch.setattr("astrai.inference.server._state.engine", None) response = client.get("/health") assert response.status_code == 200 data = response.json() @@ -17,7 +17,7 @@ def test_health_no_model(client, monkeypatch): def test_health_with_model(client, loaded_model, mock_engine, monkeypatch): """GET /health should return 200 when model is loaded.""" - monkeypatch.setattr("astrai.inference.server._engine", mock_engine) + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) response = client.get("/health") assert response.status_code == 200 data = response.json() @@ -28,7 +28,7 @@ def test_health_with_model(client, loaded_model, mock_engine, monkeypatch): def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch): """POST /generate with stream=false should return JSON response.""" - monkeypatch.setattr("astrai.inference.server._engine", mock_engine) + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) response = client.post( "/generate", params={ @@ -54,7 +54,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch): yield "chunk2" mock_engine.generate.return_value = stream_gen() - monkeypatch.setattr("astrai.inference.server._engine", mock_engine) + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) response = client.post( "/generate", params={ @@ -78,7 +78,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch): def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch): """POST /v1/chat/completions with stream=false returns OpenAI‑style JSON.""" mock_engine.generate.return_value = "Assistant reply" - monkeypatch.setattr("astrai.inference.server._engine", mock_engine) + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) response = client.post( "/v1/chat/completions", json={ @@ -106,7 +106,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch) yield "[DONE]" mock_engine.generate_async.return_value = async_gen() - monkeypatch.setattr("astrai.inference.server._engine", mock_engine) + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) response = client.post( "/v1/chat/completions", json={ @@ -132,7 +132,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch) def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch): """POST /generate with history parameter.""" - monkeypatch.setattr("astrai.inference.server._engine", mock_engine) + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) response = client.post( "/generate", params={