fix: 修复 remove_task 未释放 KV cache slot 导致第二轮对话死锁
- remove_task() 现在释放 KV cache slot 和 prefix cache 引用 - _refill_active_batch 中 alloc 失败时将剩余 task 推回 waiting_queue - 主循环增加 try/except 异常兜底,发送 _STOP 给所有 task - 重构:server.py 全局变量改为 ServerState 类;automodel.py 使用 Registry 替代裸 dict;合并 TrainContextBuilder 的 with_* 方法到 build()
This commit is contained in:
parent
ffff05b2c6
commit
a6f5ff3b37
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||
|
||||
|
|
@ -12,8 +11,6 @@ import torch.nn as nn
|
|||
from astrai.inference.scheduler import _STOP, InferenceScheduler
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Inference scheduler for single-GPU continuous batching."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
|
@ -12,6 +13,8 @@ from torch import Tensor
|
|||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STOP = object()
|
||||
|
||||
|
||||
|
|
@ -506,9 +509,20 @@ class InferenceScheduler:
|
|||
task_id: The task to remove.
|
||||
"""
|
||||
with self._lock:
|
||||
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
||||
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||
|
||||
for task in removed_active:
|
||||
if task.prefix_len > 0:
|
||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||
self.prefix_cache.release(prefix)
|
||||
if task.prefix_len < len(task.prompt_ids):
|
||||
self.prefix_cache.release(tuple(task.prompt_ids))
|
||||
if task.slot >= 0:
|
||||
self._free_slot(task.slot)
|
||||
task.slot = -1
|
||||
|
||||
def _remove_finished_tasks(self) -> None:
|
||||
"""Removes all finished tasks from the active batch.
|
||||
|
||||
|
|
@ -553,7 +567,7 @@ class InferenceScheduler:
|
|||
for _ in range(n):
|
||||
to_add.append(self.waiting_queue.pop(0))
|
||||
|
||||
for task in to_add:
|
||||
for i, task in enumerate(to_add):
|
||||
slot = -1
|
||||
reused = False
|
||||
if task.prefix_len > 0:
|
||||
|
|
@ -564,6 +578,8 @@ class InferenceScheduler:
|
|||
if slot < 0:
|
||||
slot = self._alloc_slot()
|
||||
if slot < 0:
|
||||
with self._lock:
|
||||
self.waiting_queue[:0] = to_add[i:]
|
||||
break
|
||||
task.slot = slot
|
||||
task.status = TaskStatus.RUNNING
|
||||
|
|
@ -712,6 +728,7 @@ class InferenceScheduler:
|
|||
Decode processes only the largest position group to ensure all
|
||||
batched tasks share the same KV cache write position.
|
||||
"""
|
||||
try:
|
||||
while self._running:
|
||||
self._remove_finished_tasks()
|
||||
self._refill_active_batch()
|
||||
|
|
@ -738,6 +755,15 @@ class InferenceScheduler:
|
|||
if not self.waiting_queue and len(self.active_tasks) <= 1:
|
||||
self._task_event.wait(timeout=0.005)
|
||||
self._task_event.clear()
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||
for task in self.active_tasks:
|
||||
if task.stream_callback:
|
||||
task.stream_callback(_STOP)
|
||||
for task in self.waiting_queue:
|
||||
if task.stream_callback:
|
||||
task.stream_callback(_STOP)
|
||||
raise
|
||||
|
||||
def start(self) -> None:
|
||||
"""Starts the background generation loop thread."""
|
||||
|
|
|
|||
|
|
@ -23,11 +23,22 @@ from astrai.tokenize import AutoTokenizer
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_engine: Optional[InferenceEngine] = None
|
||||
_model_param: Optional[Any] = None
|
||||
_project_root = Path(__file__).parent.parent.parent
|
||||
|
||||
_server_config: Dict[str, Any] = {
|
||||
|
||||
class ServerState:
|
||||
"""Encapsulates all server runtime state.
|
||||
|
||||
Attributes:
|
||||
engine: The inference engine instance.
|
||||
model_param: The loaded model.
|
||||
config: Server configuration dict.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.engine: Optional[InferenceEngine] = None
|
||||
self.model_param: Optional[Any] = None
|
||||
self.config: Dict[str, Any] = {
|
||||
"device": "cuda",
|
||||
"dtype": torch.bfloat16,
|
||||
"param_path": None,
|
||||
|
|
@ -35,34 +46,38 @@ _server_config: Dict[str, Any] = {
|
|||
}
|
||||
|
||||
|
||||
_state = ServerState()
|
||||
|
||||
|
||||
def configure_server(
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
_server_config["device"] = device
|
||||
_server_config["dtype"] = dtype
|
||||
_server_config["param_path"] = param_path
|
||||
_server_config["max_batch_size"] = max_batch_size
|
||||
_state.config.update(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
param_path=param_path,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global _model_param, _engine
|
||||
try:
|
||||
load_model(
|
||||
param_path=_server_config["param_path"],
|
||||
device=_server_config["device"],
|
||||
dtype=_server_config["dtype"],
|
||||
max_batch_size=_server_config["max_batch_size"],
|
||||
param_path=_state.config["param_path"],
|
||||
device=_state.config["device"],
|
||||
dtype=_state.config["dtype"],
|
||||
max_batch_size=_state.config["max_batch_size"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
if _engine:
|
||||
_engine.shutdown()
|
||||
if _state.engine:
|
||||
_state.engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
|
|
@ -75,25 +90,30 @@ def load_model(
|
|||
dtype: torch.dtype = torch.bfloat16,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
global _model_param, _engine
|
||||
if param_path is None:
|
||||
param_path = _project_root / "params"
|
||||
if not param_path.exists():
|
||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
_model_param = AutoModel.from_pretrained(param_path)
|
||||
_model_param.to(device=device, dtype=dtype)
|
||||
_state.model_param = AutoModel.from_pretrained(param_path)
|
||||
_state.model_param.to(device=device, dtype=dtype)
|
||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||
|
||||
_engine = InferenceEngine(
|
||||
model=_model_param,
|
||||
_state.engine = InferenceEngine(
|
||||
model=_state.model_param,
|
||||
tokenizer=tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||
|
||||
|
||||
def _get_engine() -> InferenceEngine:
|
||||
if _state.engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return _state.engine
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
|
@ -121,30 +141,27 @@ class CompletionResponse(BaseModel):
|
|||
async def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": _model_param is not None,
|
||||
"engine_ready": _engine is not None,
|
||||
"model_loaded": _state.model_param is not None,
|
||||
"engine_ready": _state.engine is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
if _engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return _engine.get_stats()
|
||||
return _get_engine().get_stats()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
if _engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
engine = _get_engine()
|
||||
|
||||
prompt = _engine.tokenizer.apply_chat_template(
|
||||
prompt = engine.tokenizer.apply_chat_template(
|
||||
[{"role": m.role, "content": m.content} for m in request.messages],
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
agen = _engine.generate_async(
|
||||
agen = engine.generate_async(
|
||||
prompt=prompt,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
|
|
@ -163,7 +180,7 @@ async def chat_completion(request: ChatCompletionRequest):
|
|||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
else:
|
||||
result = _engine.generate(
|
||||
result = engine.generate(
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
max_tokens=request.max_tokens,
|
||||
|
|
@ -198,8 +215,7 @@ async def generate(
|
|||
max_len: int = 2048,
|
||||
stream: bool = False,
|
||||
):
|
||||
if _engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
engine = _get_engine()
|
||||
|
||||
messages = []
|
||||
if history:
|
||||
|
|
@ -209,10 +225,10 @@ async def generate(
|
|||
messages.append({"role": "assistant", "content": h[1]})
|
||||
messages.append({"role": "user", "content": query})
|
||||
|
||||
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
if stream:
|
||||
agen = _engine.generate_async(
|
||||
agen = engine.generate_async(
|
||||
prompt=prompt,
|
||||
max_tokens=max_len,
|
||||
temperature=temperature,
|
||||
|
|
@ -226,7 +242,7 @@ async def generate(
|
|||
|
||||
return StreamingResponse(text_stream(), media_type="text/plain")
|
||||
else:
|
||||
result = _engine.generate(
|
||||
result = engine.generate(
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
max_tokens=max_len,
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ AutoModel base class for model loading and saving.
|
|||
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Self, Type, Union
|
||||
from typing import Self, Type, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.factory import Registry
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -44,8 +45,7 @@ class AutoModel(nn.Module):
|
|||
Provides model loading/saving and generation capabilities.
|
||||
"""
|
||||
|
||||
# Model registry - stored as class attribute
|
||||
_registry: Dict[str, Type["AutoModel"]] = {}
|
||||
_registry = Registry()
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
|
|
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
|
|||
"""
|
||||
|
||||
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
||||
cls._registry[model_type.lower()] = sub_cls
|
||||
cls._registry.register(model_type.lower(), sub_cls)
|
||||
return sub_cls
|
||||
|
||||
return decorator
|
||||
|
|
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
|
|||
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
||||
"""Get model class by model_type string."""
|
||||
model_type = model_type.lower()
|
||||
if model_type not in cls._registry:
|
||||
available = list(cls._registry.keys())
|
||||
if not cls._registry.contains(model_type):
|
||||
available = cls._registry.list_names()
|
||||
raise ValueError(
|
||||
f"Unknown model_type: {model_type}. Available: {available}"
|
||||
)
|
||||
return cls._registry[model_type]
|
||||
return cls._registry.get(model_type)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
|
|
@ -96,14 +96,8 @@ class AutoModel(nn.Module):
|
|||
else:
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
# If called from base class, use model_type to determine actual model class
|
||||
if cls is AutoModel:
|
||||
model_type = config.model_type or "transformer"
|
||||
actual_cls = cls.get_model_class(model_type)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call from_pretrained() on subclass {cls.__name__}"
|
||||
)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
model = actual_cls(config)
|
||||
|
|
|
|||
|
|
@ -34,66 +34,60 @@ class TrainContext:
|
|||
class TrainContextBuilder:
|
||||
def __init__(self, config: TrainConfig):
|
||||
self.config = config
|
||||
self._context = TrainContext(
|
||||
model=config.model,
|
||||
self._checkpoint: Optional[Checkpoint] = None
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._checkpoint = checkpoint
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
context = TrainContext(
|
||||
model=self.config.model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
)
|
||||
|
||||
device = get_current_device()
|
||||
self._context.model = self._context.model.to(device=device)
|
||||
context.model = context.model.to(device=device)
|
||||
|
||||
if self.config.nprocs > 1:
|
||||
fn = self.config.parallel_wrapper
|
||||
self._context.model = fn(self._context.model)
|
||||
if self.config.nprocs > 1 and self.config.parallel_wrapper:
|
||||
context.model = self.config.parallel_wrapper(context.model)
|
||||
|
||||
self._context.optimizer = self.config.optimizer_fn(self._context.model)
|
||||
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
if checkpoint is None:
|
||||
checkpoint = Checkpoint(
|
||||
state_dict=self._context.model.state_dict(),
|
||||
)
|
||||
if self._checkpoint is not None:
|
||||
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
|
||||
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
context.checkpoint = self._checkpoint
|
||||
else:
|
||||
# resume from the assigned checkpoint or assigned iteration
|
||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||
self._context.model.load_state_dict(checkpoint.state_dict)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=context.model.state_dict(),
|
||||
)
|
||||
|
||||
self._context.checkpoint = checkpoint
|
||||
return self
|
||||
context.optimizer = self.config.optimizer_fn(context.model)
|
||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||
|
||||
def with_dataloader(self) -> Self:
|
||||
# fix: change batch level iteration to sample level offset
|
||||
config = self.config
|
||||
sampler_offset = self._context.iteration * config.batch_size
|
||||
resumeable_sampler = ResumableDistributedSampler(
|
||||
data_source=config.dataset,
|
||||
start_epoch=self._context.epoch,
|
||||
cfg = self.config
|
||||
sampler_offset = context.iteration * cfg.batch_size
|
||||
sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.dataset,
|
||||
start_epoch=context.epoch,
|
||||
start_iter=sampler_offset,
|
||||
seed=config.random_seed,
|
||||
seed=cfg.random_seed,
|
||||
)
|
||||
context.dataloader = DataLoader(
|
||||
cfg.dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=cfg.pin_memory,
|
||||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
dataloader = DataLoader(
|
||||
config.dataset,
|
||||
batch_size=config.batch_size,
|
||||
sampler=resumeable_sampler,
|
||||
num_workers=config.num_workers,
|
||||
pin_memory=config.pin_memory,
|
||||
prefetch_factor=config.prefetch_factor,
|
||||
)
|
||||
self._context.dataloader = dataloader
|
||||
return self
|
||||
|
||||
def with_strategy(self) -> Self:
|
||||
self._context.strategy = StrategyFactory.create(
|
||||
model=self._context.model,
|
||||
context.strategy = StrategyFactory.create(
|
||||
model=context.model,
|
||||
train_type=self.config.strategy,
|
||||
device=get_current_device(),
|
||||
device=device,
|
||||
**self.config.extra_kwargs,
|
||||
)
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
return self._context
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -35,11 +35,7 @@ class Trainer:
|
|||
|
||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||
return (
|
||||
TrainContextBuilder(self.train_config)
|
||||
.with_checkpoint(checkpoint)
|
||||
.with_dataloader()
|
||||
.with_strategy()
|
||||
.build()
|
||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||
)
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
|
|
|
|||
|
|
@ -53,5 +53,5 @@ def mock_engine():
|
|||
@pytest.fixture
|
||||
def loaded_model(mock_model_param, monkeypatch):
|
||||
"""Simulate that the model is loaded."""
|
||||
monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param)
|
||||
monkeypatch.setattr("astrai.inference.server._state.model_param", mock_model_param)
|
||||
return mock_model_param
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import pytest
|
|||
|
||||
def test_health_no_model(client, monkeypatch):
|
||||
"""GET /health should return 200 even when model not loaded."""
|
||||
monkeypatch.setattr("astrai.inference.server._model_param", None)
|
||||
monkeypatch.setattr("astrai.inference.server._engine", None)
|
||||
monkeypatch.setattr("astrai.inference.server._state.model_param", None)
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", None)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -17,7 +17,7 @@ def test_health_no_model(client, monkeypatch):
|
|||
|
||||
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
|
||||
"""GET /health should return 200 when model is loaded."""
|
||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -28,7 +28,7 @@ def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
|
|||
|
||||
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||
"""POST /generate with stream=false should return JSON response."""
|
||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/generate",
|
||||
params={
|
||||
|
|
@ -54,7 +54,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
|
|||
yield "chunk2"
|
||||
|
||||
mock_engine.generate.return_value = stream_gen()
|
||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/generate",
|
||||
params={
|
||||
|
|
@ -78,7 +78,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
|
|||
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||
"""POST /v1/chat/completions with stream=false returns OpenAI‑style JSON."""
|
||||
mock_engine.generate.return_value = "Assistant reply"
|
||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
|
|
@ -106,7 +106,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch)
|
|||
yield "[DONE]"
|
||||
|
||||
mock_engine.generate_async.return_value = async_gen()
|
||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
|
|
@ -132,7 +132,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch)
|
|||
|
||||
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
|
||||
"""POST /generate with history parameter."""
|
||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/generate",
|
||||
params={
|
||||
|
|
|
|||
Loading…
Reference in New Issue